sseung0703 commited on
Commit
e8c4ed3
1 Parent(s): 8efbff1
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,5 +1,5 @@
1
  # Putting NeRF on a Diet: Semantically Consistent Few-Shot View Synthesis Implementation
2
- ###** WARNING : it is not the completed REAME (Until Thursday)**
3
  <p align="center"><img width="450" alt="Screenshot 2021-07-04 4.11.51 PM" src="https://user-images.githubusercontent.com/77657524/126361638-4aad58e8-4efb-4fc5-bf78-f53d03799e1e.png"></p>
4
 
5
  The PyTorch and JAX/Flax based implementation of the paper [Putting NeRF on a Diet: Ajay Jain, Matthew Tancik, Pieter Abbeel, arXiv: https://arxiv.org/abs/2104.00677]
@@ -94,57 +94,38 @@ python -m train \
94
  ```
95
  You can toggle the semantic loss with the “use_semantic_loss” flag in the configuration files.
96
 
97
- ## 💎 Performance
98
-
99
- ### ❗️ Performance Tables
100
- #### 4 Shot Blender Dataset PSNR Result
101
-
102
- | Scene | Chair | Drums | Ficus | Hotdog | Lego | Materials | Mic | Ship | Mean |
103
- |---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
104
- | NeRF | 33.00 | 25.01 | 30.13 | 36.18 | 32.54 | 29.62 | 32.91 | 28.65 | 31.01 |
105
- | DietNeRF | **34.08** | **25.03** | **30.43** | **36.92** | **33.28** | **29.91** | **34.53** | **29.36** | **31.69** |
106
 
107
- #### Loss Graph Comparison btw NeRF vs DietNeRF in Drum Scene
108
 
109
- <p align="center"><img width="400" alt="스크린샷 2021-07-04 오후 4 11 51" src="https://user-images.githubusercontent.com/77657524/126384510-423b9070-a3e5-4e18-8b4e-30c15c5b39c6.png">
110
- </p>
111
 
112
 
113
- ### ❗ Rendering GIF images by 8-shot learned Diet-NeRF
114
 
115
  DietNeRF has a strong capacity to generalise to novel and challenging views with EXTREMELY SMALL TRAINING SAMPLES!
116
  The animations below show the performance difference between DietNeRF (left) vs. NeRF (right) with only 4 training images:
117
 
118
- #### SHIP
119
- ![Text](./assets/ship-dietnerf.gif) ![Alt Text](./assets/ship-nerf.gif)
120
-
121
- #### LEGO
122
- ![Text](./assets/ship-dietnerf.gif) ![Alt Text](./assets/ship-nerf.gif)
123
-
124
- #### HOTDOG
125
- ![Text](./assets/ship-dietnerf.gif) ![Alt Text](./assets/ship-nerf.gif)
126
 
127
 
128
- ### ❗ Rendered Rendering images by 4-shot learned Diet-NeRF vs Vanilla-NeRF
129
 
130
  #### SHIP
131
- @ will be filled
132
-
133
- #### LEGO
134
- @ will be filled
135
-
136
- #### HOTDOG
137
- @ will be filled
138
-
139
- ### ❗ Rendered examples by occluded 14-shot learned NeRF and Diet-NeRF
140
- This result is on the quite initial state and expected to be improved.
141
-
142
- #### Training poses
143
- <img width="1400" src="https://user-images.githubusercontent.com/26036843/126111980-4f332c87-a7f0-42e0-a355-8e77621bbca4.png">
144
-
145
- #### Rendered novel poses
146
- <img width="800" src="https://user-images.githubusercontent.com/26036843/126113080-a6a48f3d-2629-4efc-a740-fe908ca6b5c3.png">
147
-
148
 
149
  ## 🤩 Demo
150
 
 
1
  # Putting NeRF on a Diet: Semantically Consistent Few-Shot View Synthesis Implementation
2
+
3
  <p align="center"><img width="450" alt="Screenshot 2021-07-04 4.11.51 PM" src="https://user-images.githubusercontent.com/77657524/126361638-4aad58e8-4efb-4fc5-bf78-f53d03799e1e.png"></p>
4
 
5
  The PyTorch and JAX/Flax based implementation of the paper [Putting NeRF on a Diet: Ajay Jain, Matthew Tancik, Pieter Abbeel, arXiv: https://arxiv.org/abs/2104.00677]
 
94
  ```
95
  You can toggle the semantic loss with the “use_semantic_loss” flag in the configuration files.
96
 
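For reference, these are the semantic-consistency keys as they appear in configs/diet_nerf_tpu_vm_4shot.yaml added in this commit (the inline comments are editorial glosses, not part of the config file):

```yaml
use_semantic_loss: true                         # master switch for the CLIP semantic loss
clip_model_name: openai/clip-vit-base-patch32   # CLIP backbone used to embed rendered views
clip_output_dtype: float16
sc_loss_every: 16                               # presumably: apply the semantic loss every 16 steps
few_shot: 4                                     # number of training views
```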
97
+ ## 💎 Experiment Results
98
 
 
99
 
100
+ ### Rendered images from 8-shot learned DietNeRF (200,000 iterations)
101
+ ### CHAIR / HOTDOG / DRUM
102
 
103
+ <p align="center">
104
+ <table>
105
+ <tr>
106
+ <td><img alt="" src="./assets/chair.png" width="300"/></td><td><img alt="" src="./assets/hotdog.png" width="300"/></td><td><img alt="" src="./assets/drum.png" width="300"/></td>
107
+ </tr>
108
+ </table></p>
109
 
110
+ ### ❗ Rendered GIFs from 4-shot learned DietNeRF and NeRF (50,000 iterations)
111
 
112
  DietNeRF has a strong capacity to generalise to novel and challenging views with EXTREMELY SMALL TRAINING SAMPLES!
113
  The animations below show the performance difference between DietNeRF (left) vs. NeRF (right) with only 4 training images:
114
 
115
 
116
 
117
+ ### ❗ Rendered GIFs from occluded 14-shot learned NeRF and DietNeRF (100,000 iterations)
118
+ We added artificial occlusion to the right side of each training image.
119
+ This experiment compares reconstruction quality under occlusion.
120
+ DietNeRF shows better quality than the original NeRF when the input is occluded.
121
 
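The occlusion code itself is not included in this commit, so the following is only a hypothetical sketch of what masking the right-hand side of a training image could look like, assuming H x W x 3 float images in [0, 1] and a white background (white_bkgd: true); the helper name is made up:

```python
import numpy as np

def occlude_right(image: np.ndarray, fraction: float = 0.5) -> np.ndarray:
    """Hypothetical helper: paint the right `fraction` of an HxWx3 image white."""
    occluded = image.copy()
    w = image.shape[1]
    occluded[:, int(w * (1.0 - fraction)):, :] = 1.0  # white fill, matching white_bkgd
    return occluded
```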
122
  #### SHIP
123
+ <p align="center">
124
+ <table>
125
+ <tr>
126
+ <td><img alt="" src="./assets/ship-dietnerf.gif" width="300"/></td><td><img alt="" src="./assets/ship-nerf.gif" width="300"/></td>
127
+ </tr>
128
+ </table></p>
129
 
130
  ## 🤩 Demo
131
 
__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
assets/chair.png ADDED
assets/drawing.png ADDED
assets/drum.png ADDED
assets/hotdog.png ADDED
configs/blender.yaml ADDED
@@ -0,0 +1,9 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 4096
9
+ randomized: true
configs/demo.yaml ADDED
@@ -0,0 +1,10 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 50000
configs/diet_nerf_tpu_vm_4shot.yaml ADDED
@@ -0,0 +1,18 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 200000
11
+ print_every: 100
12
+ render_every: 500
13
+ save_every: 5000
14
+ use_semantic_loss: true
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float16
17
+ sc_loss_every: 16
18
+ few_shot: 4
configs/diet_nerf_tpu_vm_few_shot.yaml ADDED
@@ -0,0 +1,18 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 200000
11
+ print_every: 100
12
+ render_every: 500
13
+ save_every: 5000
14
+ use_semantic_loss: true
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float16
17
+ sc_loss_every: 16
18
+ few_shot: 8
configs/diet_nerf_tpu_vm_test.yaml ADDED
@@ -0,0 +1,18 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 64
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1026
9
+ randomized: true
10
+ max_steps: 200000
11
+ print_every: 100
12
+ render_every: 1000
13
+ save_every: 5000
14
+ use_semantic_loss: true
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float16
17
+ sc_loss_every: 16
18
+ few_shot: -1
configs/eval_diet_nerf_tpu_vm_few_shot.yaml ADDED
@@ -0,0 +1,22 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 200000
11
+ print_every: 100
12
+ render_every: 5000
13
+ save_every: 5000
14
+ use_semantic_loss: true
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float32
17
+ sc_loss_factor: 4
18
+ sc_loss_every: 16
19
+ sc_loss_mult: 10
20
+ few_shot: 8
21
+ spherify: True
22
+ lindisp: True
configs/nerf_tpu_vm_4shot.yaml ADDED
@@ -0,0 +1,18 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 200000
11
+ print_every: 100
12
+ render_every: 500
13
+ save_every: 5000
14
+ use_semantic_loss: false
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float32
17
+ sc_loss_every: 16
18
+ few_shot: 4
configs/nerf_tpu_vm_few_shot.yaml ADDED
@@ -0,0 +1,18 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 200000
11
+ print_every: 100
12
+ render_every: 500
13
+ save_every: 5000
14
+ use_semantic_loss: false
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float32
17
+ sc_loss_every: 16
18
+ few_shot: 8
configs/orig_nerf_tpu_vm_full.yaml ADDED
@@ -0,0 +1,13 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 200000
11
+ print_every: 1000
12
+ render_every: 5000
13
+ save_every: 5000
configs/orig_nerf_tpu_vm_test.yaml ADDED
@@ -0,0 +1,13 @@
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 200000
11
+ print_every: 100
12
+ render_every: 500
13
+ save_every: 500
eval.py ADDED
@@ -0,0 +1,241 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Evaluation script for Nerf."""
18
+ import math
19
+ import glob
20
+ import os
21
+ from os import path
22
+ import functools
23
+
24
+ from absl import app
25
+ from absl import flags
26
+ import flax
27
+ from flax.metrics import tensorboard
28
+ from flax.training import checkpoints
29
+ import jax
30
+ from jax import random
31
+ import tensorflow as tf
32
+
33
+ from tqdm import tqdm
34
+ import cv2
35
+ import numpy as np
36
+ from PIL import Image
37
+
38
+ from nerf import datasets
39
+ from nerf import models
40
+ from nerf import utils
41
+ from nerf import clip_utils
42
+
43
+ FLAGS = flags.FLAGS
44
+ utils.define_flags()
45
+
46
+
47
+
48
+ def compute_lpips(image1, image2, model):
49
+ """Compute the LPIPS metric."""
50
+ # The LPIPS model expects a batch dimension.
51
+ return model(
52
+ tf.convert_to_tensor(image1[None, Ellipsis]),
53
+ tf.convert_to_tensor(image2[None, Ellipsis]))[0]
54
+
55
+
56
+ def predict_to_image(pred_out):
57
+ image_arr = np.array(np.clip(pred_out, 0., 1.) * 255.).astype(np.uint8)
58
+ return Image.fromarray(image_arr)
59
+
60
+
61
+ def main(unused_argv):
62
+ # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
63
+ # LPIPS computation or dataset loading.
64
+ tf.config.experimental.set_visible_devices([], "GPU")
65
+ tf.config.experimental.set_visible_devices([], "TPU")
66
+
67
+ #wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
68
+
69
+ rng = random.PRNGKey(20200823)
70
+
71
+ if FLAGS.config is not None:
72
+ utils.update_flags(FLAGS)
73
+ if FLAGS.train_dir is None:
74
+ raise ValueError("train_dir must be set. None set now.")
75
+ if FLAGS.data_dir is None:
76
+ raise ValueError("data_dir must be set. None set now.")
77
+
78
+ dataset = datasets.get_dataset("test", FLAGS)
79
+ rng, key = random.split(rng)
80
+ model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
81
+ optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
82
+ state = utils.TrainState(optimizer=optimizer)
83
+ del optimizer, init_variables
84
+ state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
85
+
86
+ # Rendering is forced to be deterministic even if training was randomized, as
87
+ # this eliminates "speckle" artifacts.
88
+ def render_fn(variables, key_0, key_1, rays):
89
+ return model.apply(variables, key_0, key_1, rays, False)
90
+
91
+ # pmap over only the data input.
92
+ render_pfn = jax.pmap(
93
+ render_fn,
94
+ in_axes=(None, None, None, 0),
95
+ donate_argnums=3,
96
+ axis_name="batch",
97
+ )
98
+
99
+ # Compiling to the CPU because it's faster and more accurate.
100
+ ssim_fn = jax.jit(
101
+ functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")
102
+
103
+ last_step = 0
104
+ out_dir = path.join(FLAGS.train_dir, "path_renders" if FLAGS.render_path else "test_preds")
105
+ os.makedirs(out_dir, exist_ok=True)
106
+ if FLAGS.save_output:
107
+ print(f'eval output will be saved: {out_dir}')
108
+ else:
109
+ print(f'eval output will not be saved')
110
+
111
+ if not FLAGS.eval_once:
112
+ summary_writer = tensorboard.SummaryWriter(
113
+ path.join(FLAGS.train_dir, "eval"))
114
+
115
+ def generate_spinning_gif(radius, phi, gif_fn, frame_n):
116
+ _rng = random.PRNGKey(0)
117
+ partial_render_fn = functools.partial(render_pfn, state.optimizer.target)
118
+ gif_images = []
119
+ for theta in tqdm(np.linspace(-math.pi, math.pi, frame_n)):
120
+ camtoworld = np.array(clip_utils.pose_spherical(radius, theta, phi))
121
+ rays = dataset.camtoworld_matrix_to_rays(camtoworld, downsample=4)
122
+ _rng, key0, key1 = random.split(_rng, 3)
123
+ color, _, _ = utils.render_image(partial_render_fn, rays,
124
+ _rng, False, chunk=4096)
125
+ image = predict_to_image(color)
126
+ gif_images.append(image)
127
+ gif_images[0].save(gif_fn, save_all=True,
128
+ append_images=gif_images[1:],  # frame 0 is already the base image
129
+ duration=100, loop=0)
130
+ return gif_images
131
+
132
+ if FLAGS.generate_gif_only:
133
+ print('generate GIF file only')
134
+ _radius = 4.
135
+ _phi = (30 * math.pi) / 180
136
+ _gif_fn = os.path.join(out_dir, 'spinning.gif')
137
+ generate_spinning_gif(_radius, _phi, _gif_fn, frame_n=30)
138
+ print(f'GIF file for spinning views written: {_gif_fn}')
139
+ return
140
+ else:
141
+ print('generate GIF file AND evaluate model performance')
142
+
143
+ is_gif_written = False
144
+ while True:
145
+ step = int(state.optimizer.state.step)
146
+ if step <= last_step:
147
+ continue
148
+ if FLAGS.save_output and (not utils.isdir(out_dir)):
149
+ utils.makedirs(out_dir)
150
+ psnr_values = []
151
+ ssim_values = []
152
+ #lpips_values = []
153
+ if not FLAGS.eval_once:
154
+ showcase_index = np.random.randint(0, dataset.size)
155
+ for idx in range(dataset.size):
156
+ print(f"Evaluating {idx + 1}/{dataset.size}")
157
+ batch = next(dataset)
158
+ pred_color, pred_disp, pred_acc = utils.render_image(
159
+ functools.partial(render_pfn, state.optimizer.target),
160
+ batch["rays"],
161
+ rng,
162
+ FLAGS.dataset == "llff",
163
+ chunk=FLAGS.chunk)
164
+ if jax.host_id() != 0: # Only record via host 0.
165
+ continue
166
+ if not FLAGS.eval_once and idx == showcase_index:
167
+ showcase_color = pred_color
168
+ showcase_disp = pred_disp
169
+ showcase_acc = pred_acc
170
+ if not FLAGS.render_path:
171
+ showcase_gt = batch["pixels"]
172
+ if not FLAGS.render_path:
173
+ psnr = utils.compute_psnr(((pred_color - batch["pixels"]) ** 2).mean())
174
+ ssim = ssim_fn(pred_color, batch["pixels"])
175
+ #lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
176
+ print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
177
+ psnr_values.append(float(psnr))
178
+ ssim_values.append(float(ssim))
179
+ #lpips_values.append(float(lpips))
180
+ if FLAGS.save_output:
181
+ utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
182
+ utils.save_img(pred_disp[Ellipsis, 0],
183
+ path.join(out_dir, "disp_{:03d}.png".format(idx)))
184
+ if (not FLAGS.eval_once) and (jax.host_id() == 0):
185
+ summary_writer.image("pred_color", showcase_color, step)
186
+ summary_writer.image("pred_disp", showcase_disp, step)
187
+ summary_writer.image("pred_acc", showcase_acc, step)
188
+ if not FLAGS.render_path:
189
+ summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
190
+ summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
191
+ #summary_writer.scalar("lpips", np.mean(np.array(lpips_values)), step)
192
+ summary_writer.image("target", showcase_gt, step)
193
+
194
+ if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
195
+ with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
196
+ f.write(" ".join([str(v) for v in psnr_values]))
197
+ with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
198
+ f.write(" ".join([str(v) for v in ssim_values]))
199
+ #with utils.open_file(path.join(out_dir, f"lpips_{step}.txt"), "w") as f:
200
+ #f.write(" ".join([str(v) for v in lpips_values]))
201
+ with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
202
+ f.write("{}".format(np.mean(np.array(psnr_values))))
203
+ with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
204
+ f.write("{}".format(np.mean(np.array(ssim_values))))
205
+ #with utils.open_file(path.join(out_dir, "lpips.txt"), "w") as f:
206
+ #f.write("{}".format(np.mean(np.array(lpips_values))))
207
+ print(f'performance metrics written as txt files: {out_dir}')
208
+
209
+ imglist = glob.glob(os.path.join(out_dir, "[0-9][0-9][0-9].png"))
210
+ sorted_files = sorted(imglist, key=lambda x: int(x.split('/')[-1].split('.')[0]))
211
+ fourcc = cv2.VideoWriter_fourcc(*'MP4V')
212
+ fps = 10.0
213
+ img = cv2.imread(sorted_files[0], cv2.IMREAD_COLOR)
214
+ video_fn = os.path.join(out_dir, "rendering_video.mp4")
215
+ out = cv2.VideoWriter(video_fn, fourcc, fps,
216
+ (img.shape[1], img.shape[0]))
217
+
218
+ for i in range(len(sorted_files)):
219
+ img = cv2.imread(sorted_files[i], cv2.IMREAD_COLOR)
220
+ out.write(img)
221
+ out.release()
222
+ print(f'video file written: {video_fn}')
223
+
224
+ # write gif file for spinning views of a scene
225
+ if not is_gif_written:
226
+ _radius = 4.
227
+ _phi = (30 * math.pi) / 180
228
+ _gif_fn = os.path.join(out_dir, 'spinning.gif')
229
+ generate_spinning_gif(_radius, _phi, _gif_fn, frame_n=30)
230
+ print(f'GIF file for spinning views written: {_gif_fn}')
231
+ is_gif_written = True
232
+
233
+ if FLAGS.eval_once:
234
+ break
235
+ if int(step) >= FLAGS.max_steps:
236
+ break
237
+ last_step = step
238
+
239
+
240
+ if __name__ == "__main__":
241
+ app.run(main)
eval.sh ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright 2021 The Google Research Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ #!/bin/bash
16
+ CONFIG=$1
17
+ DATA_ROOT=$2
18
+ ROOT_DIR=/tmp/jaxnerf/"$CONFIG"
19
+ if [ $CONFIG == "llff" ]
20
+ then
21
+ SCENES="room fern leaves fortress orchids flower trex horns"
22
+ DATA_FOLDER="nerf_llff_data"
23
+ else
24
+ SCENES="lego chair drums ficus hotdog materials mic ship"
25
+ DATA_FOLDER="nerf_synthetic"
26
+ fi
27
+
28
+ # launch evaluation jobs for all scenes.
29
+ for scene in $SCENES; do
30
+ python -m jaxnerf.eval \
31
+ --data_dir="$DATA_ROOT"/"$DATA_FOLDER"/"$scene" \
32
+ --train_dir="$ROOT_DIR"/"$scene" \
33
+ --chunk=4096 \
34
+ --config=configs/"$CONFIG"
35
+ done
36
+
37
+ # collect PSNR of all scenes.
38
+ touch "$ROOT_DIR"/psnr.txt
39
+ for scene in $SCENES; do
40
+ printf "${scene}: " >> "$ROOT_DIR"/psnr.txt
41
+ cat "$ROOT_DIR"/"$scene"/test_preds/psnr.txt >> \
42
+ "$ROOT_DIR"/psnr.txt
43
+ printf $'\n' >> "$ROOT_DIR"/psnr.txt
44
+ done
example_data/imgs/r_0.png ADDED
example_data/transforms_test.json ADDED
@@ -0,0 +1 @@
1
+ {"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}
example_data/transforms_train.json ADDED
@@ -0,0 +1 @@
1
+ {"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}
fork-of-first-touch-of-nerf-in-jax.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
nerf/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
nerf/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (137 Bytes). View file
 
nerf/__pycache__/clip_utils.cpython-37.pyc ADDED
Binary file (5.16 kB). View file
 
nerf/__pycache__/datasets.cpython-37.pyc ADDED
Binary file (18.3 kB). View file
 
nerf/__pycache__/model_utils.cpython-37.pyc ADDED
Binary file (10 kB). View file
 
nerf/__pycache__/models.cpython-37.pyc ADDED
Binary file (5.08 kB). View file
 
nerf/__pycache__/utils.cpython-37.pyc ADDED
Binary file (15.8 kB). View file
 
nerf/clip_utils.py ADDED
@@ -0,0 +1,125 @@
1
+ import math
2
+ from typing import Optional
3
+ from absl import flags
4
+ from functools import partial
5
+
6
+ import jax
7
+ from jax import random
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ from transformers import FlaxCLIPModel
11
+
12
+ from nerf import utils
13
+
14
+ FLAGS = flags.FLAGS
15
+
16
+ @partial(jax.jit, static_argnums=[0])
17
+ def semantic_loss(clip_model, src_image, target_embedding):
18
+ #c_image = utils.unshard(src_image[0])
19
+ f_image = utils.unshard(src_image[-1])
20
+
21
+ w = int(math.sqrt(src_image[-1].size//3))
22
+ #c_image = c_image.reshape([w, w, 3])
23
+ f_image = f_image.reshape([w, w, 3])
24
+
25
+ src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.expand_dims(f_image,0).transpose(0, 3, 1, 2)))
26
+ #src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
27
+ src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
28
+ sc_loss = 1 - jnp.sum(src_embedding * target_embedding)
29
+ return sc_loss, f_image
30
+
31
+ def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
32
+ random_rays = jax.tree_map(lambda x: utils.shard(x).astype(jnp.float16), batch["random_rays"])
33
+ target_embedding = batch["embedding"].astype(jnp.float16)
34
+ rng, key_0, key_1 = random.split(rng,3)
35
+
36
+ def loss_fn(variables):
37
+ src_image = render_pfn(variables, key_0, key_1, random_rays)
38
+ sc_loss, f_image = semantic_loss(clip_model, src_image, target_embedding)
39
+ return sc_loss * FLAGS.sc_loss_mult, f_image
40
+ (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
41
+ return sc_loss, grad, src_image
42
+
43
+ @partial(jax.jit, static_argnums=[0, 1])
44
+ def semantic_step_single(model, clip_model, rng, state, batch, lr):
45
+ batch = jax.tree_map(lambda x: x.astype(jnp.float16), batch)
46
+ # the batch is without shard
47
+ random_rays = batch["random_rays"]
48
+ rng, key_0, key_1 = random.split(rng,3)
49
+
50
+ def semantic_loss(variables):
51
+ c_image, f_image = model.apply(variables, key_0, key_1, random_rays, False, rgb_only = True)
52
+ # reshape flat pixel to an image (assume 3 channels & square shape)
53
+ w = int(math.sqrt(f_image.shape[0]))
54
+ # c_image = c_image.reshape([w, w, 3])
55
+ f_image = f_image.reshape([w, w, 3])
56
+
57
+ src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.expand_dims(f_image,0).transpose(0, 3, 1, 2)))
58
+ # src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
59
+ src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
60
+ target_embedding = batch["embedding"]
61
+ sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding)**2)
62
+ return sc_loss * FLAGS.sc_loss_mult, f_image
63
+ (sc_loss, src_image), grad = jax.value_and_grad(semantic_loss, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
64
+ return sc_loss, grad, src_image
65
+
66
+ def trans_t(t):
67
+ return jnp.array([
68
+ [1, 0, 0, 0],
69
+ [0, 1, 0, 0],
70
+ [0, 0, 1, t],
71
+ [0, 0, 0, 1]], dtype=jnp.float32)
72
+
73
+ def rot_phi(phi):
74
+ return jnp.array([
75
+ [1, 0, 0, 0],
76
+ [0, jnp.cos(phi), jnp.sin(phi), 0],
77
+ [0,-jnp.sin(phi), jnp.cos(phi), 0],
78
+ [0, 0, 0, 1]], dtype=jnp.float32)
79
+
80
+ def rot_theta(th):
81
+ return jnp.array([
82
+ [jnp.cos(th), 0,-jnp.sin(th), 0],
83
+ [0, 1, 0, 0],
84
+ [jnp.sin(th), 0, jnp.cos(th), 0],
85
+ [0, 0, 0, 1]], dtype=jnp.float32)
86
+
87
+ def pose_spherical(radius, theta, phi):
88
+ c2w = trans_t(radius)
89
+ c2w = rot_phi(phi) @ c2w
90
+ c2w = rot_theta(theta) @ c2w
91
+ c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
92
+ return c2w
93
+
94
+ def random_pose(rng, bds):
95
+ rng, *rng_inputs = jax.random.split(rng, 3)
96
+ radius = random.uniform(rng_inputs[1], minval=bds[0], maxval=bds[1])
97
+ theta = random.uniform(rng_inputs[1], minval=-jnp.pi, maxval=jnp.pi)
98
+ phi = random.uniform(rng_inputs[1], minval=0, maxval=jnp.pi/2)
99
+ return pose_spherical(radius, theta, phi)
100
+
101
+ def preprocess_for_CLIP(image):
102
+ """
103
+ jax-based preprocessing for CLIP
104
+ image [B, 3, H, W]: batch image
105
+ return [B, 3, 224, 224]: pre-processed image for CLIP
106
+ """
107
+ B, D, H, W = image.shape
108
+ mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
109
+ std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
110
+ image = jax.image.resize(image, (B, D, 224, 224), 'bicubic') # assumes square input images; non-square inputs are distorted by the resize.
111
+ image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
112
+ return image
113
+
114
+ def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
115
+ if dtype == 'float16':
116
+ dtype = jnp.float16
117
+ elif dtype == 'float32':
118
+ dtype = jnp.float32
119
+ else:
120
+ raise ValueError
121
+
122
+ if model_name is None:
123
+ model_name = 'openai/clip-vit-base-patch32'
124
+
125
+ return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)
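As a rough usage sketch of how these helpers compose (assuming the repo's nerf package and the transformers library are importable; the random array below is only a stand-in for a rendered view, not code from this commit):

```python
import jax.numpy as jnp
import numpy as np

from nerf import clip_utils

# Stand-in for a rendered view: H x W x 3 floats in [0, 1].
image = np.random.rand(168, 168, 3).astype(np.float32)

clip_model = clip_utils.init_CLIP(dtype="float32",
                                  model_name="openai/clip-vit-base-patch32")

# CLIP expects NCHW input; preprocess_for_CLIP resizes to 224x224 and normalizes.
pixels = clip_utils.preprocess_for_CLIP(
    jnp.expand_dims(jnp.asarray(image), 0).transpose(0, 3, 1, 2))
emb = clip_model.get_image_features(pixel_values=pixels)
emb = emb / jnp.linalg.norm(emb, axis=-1, keepdims=True)

# Semantic-consistency loss against a target embedding (here the same view, so ~0).
sc_loss = 1.0 - jnp.sum(emb * emb)
```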
nerf/datasets.py ADDED
@@ -0,0 +1,558 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Different datasets implementation plus a general port for all the datasets."""
18
+ INTERNAL = False # pylint: disable=g-statement-before-imports
19
+ import json
20
+ import os, time
21
+ from os import path
22
+ import queue
23
+ import threading
24
+
25
+ if not INTERNAL:
26
+ import cv2 # pylint: disable=g-import-not-at-top
27
+ import jax
28
+ import numpy as np
29
+ from PIL import Image
30
+
31
+ from nerf import utils
32
+ from nerf import clip_utils
33
+
34
+ def get_dataset(split, args, clip_model = None):
35
+ return dataset_dict[args.dataset](split, args, clip_model)
36
+
37
+
38
+ def convert_to_ndc(origins, directions, focal, w, h, near=1.):
39
+ """Convert a set of rays to NDC coordinates."""
40
+ # Shift ray origins to near plane
41
+ t = -(near + origins[..., 2]) / directions[..., 2]
42
+ origins = origins + t[..., None] * directions
43
+
44
+ dx, dy, dz = tuple(np.moveaxis(directions, -1, 0))
45
+ ox, oy, oz = tuple(np.moveaxis(origins, -1, 0))
46
+
47
+ # Projection
48
+ o0 = -((2 * focal) / w) * (ox / oz)
49
+ o1 = -((2 * focal) / h) * (oy / oz)
50
+ o2 = 1 + 2 * near / oz
51
+
52
+ d0 = -((2 * focal) / w) * (dx / dz - ox / oz)
53
+ d1 = -((2 * focal) / h) * (dy / dz - oy / oz)
54
+ d2 = -2 * near / oz
55
+
56
+ origins = np.stack([o0, o1, o2], -1)
57
+ directions = np.stack([d0, d1, d2], -1)
58
+ return origins, directions
59
+
60
+
61
+ class Dataset(threading.Thread):
62
+ """Dataset Base Class."""
63
+
64
+ def __init__(self, split, flags, clip_model):
65
+ super(Dataset, self).__init__()
66
+ self.queue = queue.Queue(3) # Set prefetch buffer to 3 batches.
67
+ self.daemon = True
68
+ self.use_pixel_centers = flags.use_pixel_centers
69
+ self.split = split
70
+
71
+ if split == "train":
72
+ self._train_init(flags, clip_model)
73
+ elif split == "test":
74
+ self._test_init(flags)
75
+ else:
76
+ raise ValueError(
77
+ "the split argument should be either \"train\" or \"test\", set"
78
+ "to {} here.".format(split))
79
+ self.batch_size = flags.batch_size // jax.process_count()
80
+ self.batching = flags.batching
81
+ self.render_path = flags.render_path
82
+ self.far = flags.far
83
+ self.near = flags.near
84
+ self.max_steps = flags.max_steps
85
+ self.start()
86
+
87
+ def __iter__(self):
88
+ return self
89
+
90
+ def __next__(self):
91
+ """Get the next training batch or test example.
92
+
93
+ Returns:
94
+ batch: dict, has "pixels" and "rays".
95
+ """
96
+ x = self.queue.get()
97
+ if self.split == "train":
98
+ return utils.shard(x)
99
+ else:
100
+ return utils.to_device(x)
101
+
102
+ def peek(self):
103
+ """Peek at the next training batch or test example without dequeuing it.
104
+
105
+ Returns:
106
+ batch: dict, has "pixels" and "rays".
107
+ """
108
+ x = self.queue.queue[0].copy() # Make a copy of the front of the queue.
109
+ if self.split == "train":
110
+ return utils.shard(x)
111
+ else:
112
+ return utils.to_device(x)
113
+
114
+ def run(self):
115
+ if self.split == "train":
116
+ next_func = self._next_train
117
+ else:
118
+ next_func = self._next_test
119
+ while True:
120
+ self.queue.put(next_func())
121
+
122
+ @property
123
+ def size(self):
124
+ return self.n_examples
125
+
126
+ def _train_init(self, flags, clip_model):
127
+ """Initialize training."""
128
+ self._load_renderings(flags, clip_model)
129
+ self._generate_rays()
130
+
131
+ if flags.batching == "all_images":
132
+ # flatten the ray and image dimension together.
133
+ self.images = self.images.reshape([-1, 3])
134
+ self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
135
+ self.rays)
136
+ elif flags.batching == "single_image":
137
+ self.images = self.images.reshape([-1, self.resolution, 3])
138
+ self.rays = utils.namedtuple_map(
139
+ lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
140
+ else:
141
+ raise NotImplementedError(
142
+ f"{flags.batching} batching strategy is not implemented.")
143
+
144
+ def _test_init(self, flags):
145
+ self._load_renderings(flags, clip_model = None)
146
+ self._generate_rays()
147
+ self.it = 0
148
+
149
+ def _next_train(self):
150
+ """Sample next training batch."""
151
+
152
+ if self.batching == "all_images":
153
+ ray_indices = np.random.randint(0, self.rays[0].shape[0],
154
+ (self.batch_size,))
155
+ batch_pixels = self.images[ray_indices]
156
+ batch_rays = utils.namedtuple_map(lambda r: r[ray_indices], self.rays)
157
+ raise NotImplementedError("image_index not implemented for batching=all_images")
158
+
159
+ elif self.batching == "single_image":
160
+ image_index = np.random.randint(0, self.n_examples, ())
161
+ ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
162
+ (self.batch_size,))
163
+ batch_pixels = self.images[image_index][ray_indices]
164
+ batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
165
+ self.rays)
166
+ else:
167
+ raise NotImplementedError(
168
+ f"{self.batching} batching strategy is not implemented.")
169
+ return {"pixels": batch_pixels, "rays": batch_rays, "image_index": image_index}
170
+
171
+ def _next_test(self):
172
+ """Sample next test example."""
173
+ idx = self.it
174
+ self.it = (self.it + 1) % self.n_examples
175
+
176
+ if self.render_path:
177
+ return {"rays": utils.namedtuple_map(lambda r: r[idx], self.render_rays)}
178
+ else:
179
+ return {"pixels": self.images[idx],
180
+ "rays": utils.namedtuple_map(lambda r: r[idx], self.rays),
181
+ "image_index": idx}
182
+
183
+ # TODO(bydeng): Swap this function with a more flexible camera model.
184
+ def _generate_rays(self):
185
+ """Generating rays for all images."""
186
+ pixel_center = 0.5 if self.use_pixel_centers else 0.0
187
+ x, y = np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking
188
+ np.arange(self.w, dtype=np.float32) + pixel_center, # X-Axis (columns)
189
+ np.arange(self.h, dtype=np.float32) + pixel_center, # Y-Axis (rows)
190
+ indexing="xy")
191
+ camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
192
+ -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
193
+ axis=-1)
194
+ directions = ((camera_dirs[None, ..., None, :] *
195
+ self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1))
196
+ origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1],
197
+ directions.shape)
198
+ viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
199
+ self.rays = utils.Rays(
200
+ origins=origins, directions=directions, viewdirs=viewdirs)
201
+
202
+ def camtoworld_matrix_to_rays(self, camtoworld, downsample = 1):
203
+ """ render one instance of rays given a camera to world matrix (4, 4) """
204
+ pixel_center = 0.5 if self.use_pixel_centers else 0.0
205
+ # TODO @Alex: apply mesh downsampling here
206
+ x, y = np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking
207
+ np.arange(self.w, step = downsample, dtype=np.float32) + pixel_center, # X-Axis (columns)
208
+ np.arange(self.h, step = downsample, dtype=np.float32) + pixel_center, # Y-Axis (rows)
209
+ indexing="xy")
210
+ camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
211
+ -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
212
+ axis=-1)
213
+ directions = (camera_dirs[..., None, :] * camtoworld[None, None, :3, :3]).sum(axis=-1)
214
+ origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
215
+ viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
216
+ return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
217
+
218
+ class Blender(Dataset):
219
+ """Blender Dataset."""
220
+
221
+ def _load_renderings(self, flags, clip_model = None):
222
+ """Load images from disk."""
223
+ if flags.render_path:
224
+ raise ValueError("render_path cannot be used for the blender dataset.")
225
+ cams, images, meta = self.load_files(flags.data_dir, self.split, flags.factor, flags.few_shot)
226
+
227
+ self.images = np.stack(images, axis=0)
228
+ if flags.white_bkgd:
229
+ self.images = (self.images[..., :3] * self.images[..., -1:] +
230
+ (1. - self.images[..., -1:]))
231
+ else:
232
+ self.images = self.images[..., :3]
233
+ self.h, self.w = self.images.shape[1:3]
234
+ self.resolution = self.h * self.w
235
+ self.camtoworlds = np.stack(cams, axis=0)
236
+ camera_angle_x = float(meta["camera_angle_x"])
237
+ self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
238
+ self.n_examples = self.images.shape[0]
239
+
240
+ if flags.use_semantic_loss and clip_model is not None:
241
+ embs = []
242
+ for img in self.images:
243
+ img = np.expand_dims(np.transpose(img,[2,0,1]), 0)
244
+ emb = clip_model.get_image_features(pixel_values = clip_utils.preprocess_for_CLIP(img))
245
+ embs.append( emb/np.linalg.norm(emb) )
246
+ self.embeddings = np.concatenate(embs, 0)
247
+
248
+ self.image_idx = np.arange(self.images.shape[0])
249
+ np.random.shuffle(self.image_idx)
250
+ self.image_idx = self.image_idx.tolist()
251
+
252
+ @staticmethod
253
+ def load_files(data_dir, split, factor, few_shot):
254
+ with utils.open_file(path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
255
+ meta = json.load(fp)
256
+ images = []
257
+ cams = []
258
+
259
+ frames = np.arange(len(meta["frames"]))
260
+ if few_shot > 0 and split == 'train':
261
+ np.random.seed(0)
262
+ np.random.shuffle(frames)
263
+ frames = frames[:few_shot]
264
+
265
+ # if split == 'train':
266
+ # frames = [2,5,10,40,52,53,69,78,83,85,90,94,96,97]
267
+
268
+ for i in frames:
269
+ frame = meta["frames"][i]
270
+ fname = os.path.join(data_dir, frame["file_path"] + ".png")
271
+ with utils.open_file(fname, "rb") as imgin:
272
+ image = np.array(Image.open(imgin)).astype(np.float32) / 255.
273
+ if factor == 2:
274
+ [halfres_h, halfres_w] = [hw // 2 for hw in image.shape[:2]]
275
+ image = cv2.resize(image, (halfres_w, halfres_h),
276
+ interpolation=cv2.INTER_AREA)
277
+ elif factor == 4:
278
+ [halfres_h, halfres_w] = [hw // 4 for hw in image.shape[:2]]
279
+ image = cv2.resize(image, (halfres_w, halfres_h),
280
+ interpolation=cv2.INTER_AREA)
281
+ elif factor > 0:
282
+ raise ValueError("Blender dataset only supports factor=0 or 2 or 4, {} "
283
+ "set.".format(factor))
284
+ cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
285
+ images.append(image)
286
+
287
+ print(f'No. of samples: {len(frames)}')
288
+ return cams, images, meta
289
+
290
+ def _next_train(self):
291
+ batch_dict = super(Blender, self)._next_train()
292
+ if self.batching == "single_image":
293
+ image_index = batch_dict.pop("image_index")
294
+ else:
295
+ raise NotImplementedError
296
+ return batch_dict
297
+
298
+ def get_clip_data(self):
299
+ if len(self.image_idx) == 0:
300
+ self.image_idx = np.arange(self.images.shape[0])
301
+ np.random.shuffle(self.image_idx)
302
+ self.image_idx = self.image_idx.tolist()
303
+ image_index = self.image_idx.pop()
304
+
305
+ batch_dict = {}
306
+ batch_dict["embedding"] = self.embeddings[image_index]
307
+
308
+ src_seed = int(time.time())
309
+ src_rng = jax.random.PRNGKey(src_seed)
310
+ src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
311
+ random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample = 4)
312
+ cx = np.random.randint(80, 120)
313
+ cy = np.random.randint(80, 120)
314
+ d = 70
315
+ random_rays = jax.tree_map(lambda x: x[cy-d:cy+d,cx-d:cx+d], random_rays)
316
+ w = random_rays[0].shape[0] - random_rays[0].shape[0]%jax.local_device_count()
317
+ random_rays = jax.tree_map(lambda x: x[:w,:w].reshape(-1,3), random_rays)
318
+ batch_dict["random_rays"] = random_rays
319
+ return batch_dict
320
+
321
+ class LLFF(Dataset):
322
+ """LLFF Dataset."""
323
+
324
+ def _load_renderings(self, flags):
325
+ """Load images from disk."""
326
+ # Load images.
327
+ imgdir_suffix = ""
328
+ if flags.factor > 0:
329
+ imgdir_suffix = "_{}".format(flags.factor)
330
+ factor = flags.factor
331
+ else:
332
+ factor = 1
333
+ imgdir = path.join(flags.data_dir, "images" + imgdir_suffix)
334
+ if not utils.file_exists(imgdir):
335
+ raise ValueError("Image folder {} doesn't exist.".format(imgdir))
336
+ imgfiles = [
337
+ path.join(imgdir, f)
338
+ for f in sorted(utils.listdir(imgdir))
339
+ if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
340
+ ]
341
+ images = []
342
+ for imgfile in imgfiles:
343
+ with utils.open_file(imgfile, "rb") as imgin:
344
+ image = np.array(Image.open(imgin), dtype=np.float32) / 255.
345
+ images.append(image)
346
+ images = np.stack(images, axis=-1)
347
+
348
+ # Load poses and bds.
349
+ with utils.open_file(path.join(flags.data_dir, "poses_bounds.npy"),
350
+ "rb") as fp:
351
+ poses_arr = np.load(fp)
352
+ poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
353
+ bds = poses_arr[:, -2:].transpose([1, 0])
354
+ if poses.shape[-1] != images.shape[-1]:
355
+ raise RuntimeError("Mismatch between imgs {} and poses {}".format(
356
+ images.shape[-1], poses.shape[-1]))
357
+
358
+ # Update poses according to downsampling.
359
+ poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
360
+ poses[2, 4, :] = poses[2, 4, :] * 1. / factor
361
+
362
+ # Correct rotation matrix ordering and move variable dim to axis 0.
363
+ poses = np.concatenate(
364
+ [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
365
+ poses = np.moveaxis(poses, -1, 0).astype(np.float32)
366
+ images = np.moveaxis(images, -1, 0)
367
+ bds = np.moveaxis(bds, -1, 0).astype(np.float32)
368
+
369
+ # Rescale according to a default bd factor.
370
+ scale = 1. / (bds.min() * .75)
371
+ poses[:, :3, 3] *= scale
372
+ bds *= scale
373
+
374
+ # Recenter poses.
375
+ poses = self._recenter_poses(poses)
376
+
377
+ # Generate a spiral/spherical ray path for rendering videos.
378
+ if flags.spherify:
379
+ poses = self._generate_spherical_poses(poses, bds)
380
+ self.spherify = True
381
+ else:
382
+ self.spherify = False
383
+ if not flags.spherify and self.split == "test":
384
+ self._generate_spiral_poses(poses, bds)
385
+
386
+ # Select the split.
387
+ i_test = np.arange(images.shape[0])[::flags.llffhold]
388
+ i_train = np.array(
389
+ [i for i in np.arange(int(images.shape[0])) if i not in i_test])
390
+ if self.split == "train":
391
+ indices = i_train
392
+ else:
393
+ indices = i_test
394
+ images = images[indices]
395
+ poses = poses[indices]
396
+
397
+ self.images = images
398
+ self.camtoworlds = poses[:, :3, :4]
399
+ self.focal = poses[0, -1, -1]
400
+ self.h, self.w = images.shape[1:3]
401
+ self.resolution = self.h * self.w
402
+ if flags.render_path:
403
+ self.n_examples = self.render_poses.shape[0]
404
+ else:
405
+ self.n_examples = images.shape[0]
406
+
407
+ def _generate_rays(self):
408
+ """Generate normalized device coordinate rays for llff."""
409
+ if self.split == "test":
410
+ n_render_poses = self.render_poses.shape[0]
411
+ self.camtoworlds = np.concatenate([self.render_poses, self.camtoworlds],
412
+ axis=0)
413
+
414
+ super()._generate_rays()
415
+
416
+ if not self.spherify:
417
+ ndc_origins, ndc_directions = convert_to_ndc(self.rays.origins,
418
+ self.rays.directions,
419
+ self.focal, self.w, self.h)
420
+ self.rays = utils.Rays(
421
+ origins=ndc_origins,
422
+ directions=ndc_directions,
423
+ viewdirs=self.rays.viewdirs)
424
+
425
+ # Split poses from the dataset and generated poses
426
+ if self.split == "test":
427
+ self.camtoworlds = self.camtoworlds[n_render_poses:]
428
+ split = [np.split(r, [n_render_poses], 0) for r in self.rays]
429
+ split0, split1 = zip(*split)
430
+ self.render_rays = utils.Rays(*split0)
431
+ self.rays = utils.Rays(*split1)
432
+
433
+ def _recenter_poses(self, poses):
434
+ """Recenter poses according to the original NeRF code."""
435
+ poses_ = poses.copy()
436
+ bottom = np.reshape([0, 0, 0, 1.], [1, 4])
437
+ c2w = self._poses_avg(poses)
438
+ c2w = np.concatenate([c2w[:3, :4], bottom], -2)
439
+ bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
440
+ poses = np.concatenate([poses[:, :3, :4], bottom], -2)
441
+ poses = np.linalg.inv(c2w) @ poses
442
+ poses_[:, :3, :4] = poses[:, :3, :4]
443
+ poses = poses_
444
+ return poses
445
+
446
+ def _poses_avg(self, poses):
447
+ """Average poses according to the original NeRF code."""
448
+ hwf = poses[0, :3, -1:]
449
+ center = poses[:, :3, 3].mean(0)
450
+ vec2 = self._normalize(poses[:, :3, 2].sum(0))
451
+ up = poses[:, :3, 1].sum(0)
452
+ c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1)
453
+ return c2w
454
+
455
+ def _viewmatrix(self, z, up, pos):
456
+ """Construct lookat view matrix."""
457
+ vec2 = self._normalize(z)
458
+ vec1_avg = up
459
+ vec0 = self._normalize(np.cross(vec1_avg, vec2))
460
+ vec1 = self._normalize(np.cross(vec2, vec0))
461
+ m = np.stack([vec0, vec1, vec2, pos], 1)
462
+ return m
463
+
464
+ def _normalize(self, x):
465
+ """Normalization helper function."""
466
+ return x / np.linalg.norm(x)
467
+
468
+ def _generate_spiral_poses(self, poses, bds):
469
+ """Generate a spiral path for rendering."""
470
+ c2w = self._poses_avg(poses)
471
+ # Get average pose.
472
+ up = self._normalize(poses[:, :3, 1].sum(0))
473
+ # Find a reasonable "focus depth" for this dataset.
474
+ close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
475
+ dt = .75
476
+ mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
477
+ focal = mean_dz
478
+ # Get radii for spiral path.
479
+ tt = poses[:, :3, 3]
480
+ rads = np.percentile(np.abs(tt), 90, 0)
481
+ c2w_path = c2w
482
+ n_views = 120
483
+ n_rots = 2
484
+ # Generate poses for spiral path.
485
+ render_poses = []
486
+ rads = np.array(list(rads) + [1.])
487
+ hwf = c2w_path[:, 4:5]
488
+ zrate = .5
489
+ for theta in np.linspace(0., 2. * np.pi * n_rots, n_views + 1)[:-1]:
490
+ c = np.dot(c2w[:3, :4], (np.array(
491
+ [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads))
492
+ z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
493
+ render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1))
494
+ self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4]
495
+
496
+ def _generate_spherical_poses(self, poses, bds):
497
+ """Generate a 360 degree spherical path for rendering."""
498
+ # pylint: disable=g-long-lambda
499
+ p34_to_44 = lambda p: np.concatenate([
500
+ p,
501
+ np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
502
+ ], 1)
503
+ rays_d = poses[:, :3, 2:3]
504
+ rays_o = poses[:, :3, 3:4]
505
+
506
+ def min_line_dist(rays_o, rays_d):
507
+ a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
508
+ b_i = -a_i @ rays_o
509
+ pt_mindist = np.squeeze(-np.linalg.inv(
510
+ (np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) @ (b_i).mean(0))
511
+ return pt_mindist
512
+
513
+ pt_mindist = min_line_dist(rays_o, rays_d)
514
+ center = pt_mindist
515
+ up = (poses[:, :3, 3] - center).mean(0)
516
+ vec0 = self._normalize(up)
517
+ vec1 = self._normalize(np.cross([.1, .2, .3], vec0))
518
+ vec2 = self._normalize(np.cross(vec0, vec1))
519
+ pos = center
520
+ c2w = np.stack([vec1, vec2, vec0, pos], 1)
521
+ poses_reset = (
522
+ np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]))
523
+ rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
524
+ sc = 1. / rad
525
+ poses_reset[:, :3, 3] *= sc
526
+ bds *= sc
527
+ rad *= sc
528
+ centroid = np.mean(poses_reset[:, :3, 3], 0)
529
+ zh = centroid[2]
530
+ radcircle = np.sqrt(rad ** 2 - zh ** 2)
531
+ new_poses = []
532
+
533
+ for th in np.linspace(0., 2. * np.pi, 120):
534
+ camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
535
+ up = np.array([0, 0, -1.])
536
+ vec2 = self._normalize(camorigin)
537
+ vec0 = self._normalize(np.cross(vec2, up))
538
+ vec1 = self._normalize(np.cross(vec2, vec0))
539
+ pos = camorigin
540
+ p = np.stack([vec0, vec1, vec2, pos], 1)
541
+ new_poses.append(p)
542
+
543
+ new_poses = np.stack(new_poses, 0)
544
+ new_poses = np.concatenate([
545
+ new_poses,
546
+ np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)
547
+ ], -1)
548
+ poses_reset = np.concatenate([
549
+ poses_reset[:, :3, :4],
550
+ np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)
551
+ ], -1)
552
+ if self.split == "test":
553
+ self.render_poses = new_poses[:, :3, :4]
554
+ return poses_reset
555
+
556
+
557
+ dataset_dict = {"blender": Blender,
558
+ "llff": LLFF}
nerf/model_utils.py ADDED
@@ -0,0 +1,334 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Helper functions/classes for model definition."""
18
+
19
+ import functools
20
+ from typing import Any, Callable
21
+
22
+ from flax import linen as nn
23
+ import jax
24
+ from jax import lax
25
+ from jax import random
26
+ import jax.numpy as jnp
27
+
28
+
29
+ class MLP(nn.Module):
30
+ """A simple MLP."""
31
+ net_depth: int = 8 # The depth of the first part of MLP.
32
+ net_width: int = 256 # The width of the first part of MLP.
33
+ net_depth_condition: int = 1 # The depth of the second part of MLP.
34
+ net_width_condition: int = 128 # The width of the second part of MLP.
35
+ net_activation: Callable[..., Any] = nn.relu # The activation function.
36
+ skip_layer: int = 4 # The layer to add skip layers to.
37
+ num_rgb_channels: int = 3 # The number of RGB channels.
38
+ num_sigma_channels: int = 1 # The number of sigma channels.
39
+
40
+ @nn.compact
41
+ def __call__(self, x, condition=None):
42
+ """
43
+ Evaluate the MLP.
44
+
45
+ Args:
46
+ x: jnp.ndarray(float32), [batch, num_samples, feature], points.
47
+ condition: jnp.ndarray(float32), [batch, feature], if not None, this
48
+ variable will be part of the input to the second part of the MLP
49
+ concatenated with the output vector of the first part of the MLP. If
50
+ None, only the first part of the MLP will be used with input x. In the
51
+ original paper, this variable is the view direction.
52
+
53
+ Returns:
54
+ raw_rgb: jnp.ndarray(float32), with a shape of
55
+ [batch, num_samples, num_rgb_channels].
56
+ raw_sigma: jnp.ndarray(float32), with a shape of
57
+ [batch, num_samples, num_sigma_channels].
58
+ """
59
+ feature_dim = x.shape[-1]
60
+ num_samples = x.shape[1]
61
+ x = x.reshape([-1, feature_dim])
62
+ dense_layer = functools.partial(
63
+ nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform())
64
+ inputs = x
65
+
66
+ dtype = x.dtype
67
+ for i in range(self.net_depth):
68
+ x = dense_layer(self.net_width, dtype = dtype)(x)
69
+ x = self.net_activation(x)
70
+ if i % self.skip_layer == 0 and i > 0:
71
+ x = jnp.concatenate([x, inputs], axis=-1)
72
+ raw_sigma = dense_layer(self.num_sigma_channels, dtype = dtype)(x).reshape(
73
+ [-1, num_samples, self.num_sigma_channels])
74
+
75
+ if condition is not None:
76
+ # Output of the first part of MLP.
77
+ bottleneck = dense_layer(self.net_width, dtype = dtype)(x)
78
+ # Broadcast condition from [batch, feature] to
79
+ # [batch, num_samples, feature] since all the samples along the same ray
80
+ # have the same viewdir.
81
+ condition = jnp.tile(condition[:, None, :], (1, num_samples, 1))
82
+ # Collapse the [batch, num_samples, feature] tensor to
83
+ # [batch * num_samples, feature] so that it can be fed into nn.Dense.
84
+ condition = condition.reshape([-1, condition.shape[-1]])
85
+ x = jnp.concatenate([bottleneck, condition], axis=-1)
86
+ # Here use 1 extra layer to align with the original nerf model.
87
+ for i in range(self.net_depth_condition):
88
+ x = dense_layer(self.net_width_condition, dtype = dtype)(x)
89
+ x = self.net_activation(x)
90
+ raw_rgb = dense_layer(self.num_rgb_channels, dtype = dtype)(x).reshape(
91
+ [-1, num_samples, self.num_rgb_channels])
92
+ return raw_rgb, raw_sigma
93
+
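
A minimal usage sketch for the MLP above (illustrative only; it assumes the `nerf` package from this commit is importable and the feature sizes match the default positional encodings):

    import jax
    import jax.numpy as jnp
    from nerf import model_utils

    mlp = model_utils.MLP()
    points = jnp.ones((4, 64, 63))    # [batch, num_samples, encoded point features]
    viewdirs = jnp.ones((4, 27))      # [batch, encoded view directions]
    params = mlp.init(jax.random.PRNGKey(0), points, viewdirs)
    raw_rgb, raw_sigma = mlp.apply(params, points, viewdirs)
    # raw_rgb: (4, 64, 3), raw_sigma: (4, 64, 1)
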
94
+
95
+ def cast_rays(z_vals, origins, directions):
96
+ return origins[..., None, :] + z_vals[..., None] * directions[..., None, :]
97
+
98
+
99
+ def sample_along_rays(key, origins, directions, num_samples, near, far,
100
+ randomized, lindisp):
101
+ """
102
+ Stratified sampling along the rays.
103
+
104
+ Args:
105
+ key: jnp.ndarray, random generator key.
106
+ origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
107
+ directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
108
+ num_samples: int.
109
+ near: float, near clip.
110
+ far: float, far clip.
111
+ randomized: bool, use randomized stratified sampling.
112
+ lindisp: bool, sampling linearly in disparity rather than depth.
113
+
114
+ Returns:
115
+ z_vals: jnp.ndarray, [batch_size, num_samples], sampled z values.
116
+ points: jnp.ndarray, [batch_size, num_samples, 3], sampled points.
117
+ """
118
+ batch_size = origins.shape[0]
119
+
120
+ dtype = origins.dtype
121
+
122
+ t_vals = jnp.linspace(0., 1., num_samples, dtype = dtype)
123
+ if lindisp:
124
+ z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)
125
+ else:
126
+ z_vals = near * (1. - t_vals) + far * t_vals
127
+
128
+ if randomized:
129
+ mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
130
+ upper = jnp.concatenate([mids, z_vals[..., -1:]], -1)
131
+ lower = jnp.concatenate([z_vals[..., :1], mids], -1)
132
+ t_rand = random.uniform(key, [batch_size, num_samples])
133
+ z_vals = lower + (upper - lower) * t_rand
134
+ else:
135
+ # Broadcast z_vals to make the returned shape consistent.
136
+ z_vals = jnp.broadcast_to(z_vals[None, ...], [batch_size, num_samples]).astype(dtype)
137
+
138
+ coords = cast_rays(z_vals, origins, directions)
139
+ return z_vals, coords
140
+
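
A small sketch of stratified sampling over a toy batch of rays (assuming the `nerf` package is on the path; shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from nerf import model_utils

    origins = jnp.zeros((8, 3))
    directions = jnp.tile(jnp.array([[0., 0., 1.]]), (8, 1))
    z_vals, points = model_utils.sample_along_rays(
        jax.random.PRNGKey(0), origins, directions, num_samples=64,
        near=2., far=6., randomized=True, lindisp=False)
    # z_vals: (8, 64); points: (8, 64, 3), each point = origin + z * direction.
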
141
+
142
+ def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
143
+ """
144
+ Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
145
+
146
+ Instead of computing [sin(x), cos(x)], we use the trig identity
147
+ cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
148
+
149
+ Args:
150
+ x: jnp.ndarray, variables to be encoded. Note that x should be in [-pi, pi].
151
+ min_deg: int, the minimum (inclusive) degree of the encoding.
152
+ max_deg: int, the maximum (exclusive) degree of the encoding.
153
+ legacy_posenc_order: bool, keep the same ordering as the original tf code.
154
+
155
+ Returns:
156
+ encoded: jnp.ndarray, encoded variables.
157
+ """
158
+ if min_deg == max_deg:
159
+ return x
160
+
161
+ dtype = x.dtype
162
+
163
+ scales = jnp.array([2 ** i for i in range(min_deg, max_deg)], dtype = dtype)
164
+ if legacy_posenc_order:
165
+ xb = x[..., None, :] * scales[:, None]
166
+ four_feat = jnp.reshape(
167
+ jnp.sin(jnp.stack([xb, xb + 0.5 * jnp.pi], -2)),
168
+ list(x.shape[:-1]) + [-1])
169
+ else:
170
+ xb = jnp.reshape((x[..., None, :] * scales[:, None]),
171
+ list(x.shape[:-1]) + [-1])
172
+ four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))
173
+ return jnp.concatenate([x] + [four_feat], axis=-1)
174
+
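
A quick sketch of the output size: the encoding concatenates the input with sin/cos features, so a 3-D input with min_deg=0 and max_deg=10 yields 3 + 2*3*10 = 63 dimensions (illustrative, assuming the `nerf` package is importable):

    import jax.numpy as jnp
    from nerf import model_utils

    x = jnp.zeros((1024, 3))
    enc = model_utils.posenc(x, min_deg=0, max_deg=10)
    assert enc.shape == (1024, 63)
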
175
+
176
+ def volumetric_rendering(rgb, sigma, z_vals, dirs, white_bkgd):
177
+ """
178
+ Volumetric Rendering Function.
179
+
180
+ Args:
181
+ rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
182
+ sigma: jnp.ndarray(float32), density, [batch_size, num_samples, 1].
183
+ z_vals: jnp.ndarray(float32), [batch_size, num_samples].
184
+ dirs: jnp.ndarray(float32), [batch_size, 3].
185
+ white_bkgd: bool.
186
+
187
+ Returns:
188
+ comp_rgb: jnp.ndarray(float32), [batch_size, 3].
189
+ disp: jnp.ndarray(float32), [batch_size].
190
+ acc: jnp.ndarray(float32), [batch_size].
191
+ weights: jnp.ndarray(float32), [batch_size, num_samples]
192
+ """
193
+ dtype = rgb.dtype
194
+
195
+ eps = jnp.array(1e-10, dtype = dtype)
196
+ dists = jnp.concatenate([
197
+ z_vals[..., 1:] - z_vals[..., :-1],
198
+ jnp.broadcast_to(jnp.array([1e10]),#, dtype = dtype),
199
+ z_vals[..., :1].shape)
200
+ ], -1)
201
+ dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
202
+ # Note that we're quietly turning sigma from [..., 0] to [...].
203
+ alpha = 1.0 - jnp.exp(-sigma[..., 0] * dists)
204
+ accum_prod = jnp.concatenate([
205
+ jnp.ones_like(alpha[..., :1], alpha.dtype),
206
+ jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1)
207
+ ],
208
+ axis=-1)
209
+ weights = alpha * accum_prod
210
+ weights = weights.astype(dtype)
211
+
212
+ comp_rgb = (weights[..., None] * rgb).sum(axis=-2)
213
+ depth = (weights * z_vals).sum(axis=-1)
214
+ acc = weights.sum(axis=-1)
215
+ # Equivalent to (but slightly more efficient and stable than):
216
+ # disp = 1 / max(eps, where(acc > eps, depth / acc, 0))
217
+ inv_eps = 1 / eps
218
+ disp = acc / depth
219
+ disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp, inv_eps)
220
+ if white_bkgd:
221
+ comp_rgb = comp_rgb + (1. - acc[..., None])
222
+ return comp_rgb, disp, acc, weights
223
+
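
A sketch exercising the renderer with dummy inputs shaped as in the docstring above (illustrative values, assuming the `nerf` package is importable):

    import jax.numpy as jnp
    from nerf import model_utils

    batch, n = 2, 64
    rgb = 0.5 * jnp.ones((batch, n, 3))
    sigma = jnp.ones((batch, n, 1))
    z_vals = jnp.broadcast_to(jnp.linspace(2., 6., n), (batch, n))
    dirs = jnp.tile(jnp.array([[0., 0., 1.]]), (batch, 1))
    comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
        rgb, sigma, z_vals, dirs, white_bkgd=True)
    # weights sum to acc along the sample axis; with white_bkgd the gray color
    # is blended with the white background by (1 - acc).
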
224
+
225
+ def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
226
+ """
227
+ Piecewise-Constant PDF sampling.
228
+
229
+ Args:
230
+ key: jnp.ndarray(float32), [2,], random number generator.
231
+ bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
232
+ weights: jnp.ndarray(float32), [batch_size, num_bins].
233
+ num_samples: int, the number of samples.
234
+ randomized: bool, use randomized samples.
235
+
236
+ Returns:
237
+ z_samples: jnp.ndarray(float32), [batch_size, num_samples].
238
+ """
239
+ # Pad each weight vector (only if necessary) to bring its sum to `eps`. This
240
+ # avoids NaNs when the input is zeros or small, but has no effect otherwise.
241
+ dtype = bins.dtype
242
+
243
+ eps = 1e-5
244
+ weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
245
+ padding = jnp.maximum(0, eps - weight_sum)
246
+ weights += padding / weights.shape[-1]
247
+ weight_sum += padding
248
+
249
+ # Compute the PDF and CDF for each weight vector, while ensuring that the CDF
250
+ # starts with exactly 0 and ends with exactly 1.
251
+ pdf = weights / weight_sum
252
+ cdf = jnp.minimum(1, jnp.cumsum(pdf[..., :-1], axis=-1))
253
+ cdf = jnp.concatenate([
254
+ jnp.zeros(list(cdf.shape[:-1]) + [1], dtype = dtype), cdf,
255
+ jnp.ones(list(cdf.shape[:-1]) + [1], dtype = dtype)
256
+ ],
257
+ axis=-1)
258
+
259
+ # Draw uniform samples.
260
+ if randomized:
261
+ # Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.
262
+ u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
263
+ else:
264
+ # Match the behavior of random.uniform() by spanning [0, 1-eps].
265
+ u = jnp.linspace(0., 1. - jnp.finfo(dtype).eps, num_samples, dtype = dtype)
266
+ u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])
267
+
268
+ # Identify the location in `cdf` that corresponds to a random sample.
269
+ # The final `True` index in `mask` will be the start of the sampled interval.
270
+ mask = u[..., None, :] >= cdf[..., :, None]
271
+
272
+ def find_interval(x):
273
+ # Grab the value where `mask` switches from True to False, and vice versa.
274
+ # This approach takes advantage of the fact that `x` is sorted.
275
+ x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)
276
+ x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
277
+ return x0, x1
278
+
279
+ bins_g0, bins_g1 = find_interval(bins)
280
+ cdf_g0, cdf_g1 = find_interval(cdf)
281
+
282
+ t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
283
+ samples = bins_g0 + t * (bins_g1 - bins_g0)
284
+
285
+ # Prevent gradient from backprop-ing through `samples`.
286
+ return lax.stop_gradient(samples)
287
+
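
A sketch of the inverse-CDF sampler above: with uniform weights it returns roughly evenly spaced samples inside the bins, deterministically when randomized=False (illustrative, assuming the `nerf` package is importable):

    import jax
    import jax.numpy as jnp
    from nerf import model_utils

    bins = jnp.broadcast_to(jnp.linspace(2., 6., 65), (4, 65))
    weights = jnp.ones((4, 64))
    z_samples = model_utils.piecewise_constant_pdf(
        jax.random.PRNGKey(0), bins, weights, num_samples=128, randomized=False)
    # z_samples: (4, 128)
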
288
+
289
+ def sample_pdf(key, bins, weights, origins, directions, z_vals, num_samples,
290
+ randomized):
291
+ """
292
+ Hierarchical sampling.
293
+
294
+ Args:
295
+ key: jnp.ndarray(float32), [2,], random number generator.
296
+ bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
297
+ weights: jnp.ndarray(float32), [batch_size, num_bins].
298
+ origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
299
+ directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
300
+ z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
301
+ num_samples: int, the number of samples.
302
+ randomized: bool, use randomized samples.
303
+
304
+ Returns:
305
+ z_vals: jnp.ndarray(float32),
306
+ [batch_size, num_coarse_samples + num_fine_samples].
307
+ points: jnp.ndarray(float32),
308
+ [batch_size, num_coarse_samples + num_fine_samples, 3].
309
+ """
310
+ z_samples = piecewise_constant_pdf(key, bins, weights, num_samples,
311
+ randomized)
312
+ # Compute united z_vals and sample points
313
+ z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
314
+ coords = cast_rays(z_vals, origins, directions)
315
+ return z_vals, coords
316
+
317
+
318
+ def add_gaussian_noise(key, raw, noise_std, randomized):
319
+ """
320
+ Adds gaussian noise to `raw`, which can be used to regularize it.
321
+
322
+ Args:
323
+ key: jnp.ndarray(float32), [2,], random number generator.
324
+ raw: jnp.ndarray(float32), arbitrary shape.
325
+ noise_std: float, The standard deviation of the noise to be added.
326
+ randomized: bool, add noise if randomized is True.
327
+
328
+ Returns:
329
+ raw + noise: jnp.ndarray(float32), with the same shape as `raw`.
330
+ """
331
+ if (noise_std is not None) and randomized:
332
+ return raw + random.normal(key, raw.shape, dtype=raw.dtype) * noise_std
333
+ else:
334
+ return raw
nerf/models.py ADDED
@@ -0,0 +1,264 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Different model implementation plus a general port for all the models."""
18
+ from typing import Any, Callable
19
+ from flax import linen as nn
20
+ from jax import random
21
+ import jax.numpy as jnp
22
+
23
+ from nerf import model_utils
24
+ from nerf import utils
25
+
26
+
27
+ def get_model(key, example_batch, args):
28
+ """A helper function that wraps around a 'model zoo'."""
29
+ model_dict = {"nerf": construct_nerf}
30
+ return model_dict[args.model](key, example_batch, args)
31
+
32
+
33
+ class NerfModel(nn.Module):
34
+ """Nerf NN Model with both coarse and fine MLPs."""
35
+ num_coarse_samples: int # The number of samples for the coarse nerf.
36
+ num_fine_samples: int # The number of samples for the fine nerf.
37
+ use_viewdirs: bool # If True, use viewdirs as an input.
38
+ near: float # The distance to the near plane
39
+ far: float # The distance to the far plane
40
+ noise_std: float # The std dev of noise added to raw sigma.
41
+ net_depth: int # The depth of the first part of MLP.
42
+ net_width: int # The width of the first part of MLP.
43
+ net_depth_condition: int # The depth of the second part of MLP.
44
+ net_width_condition: int # The width of the second part of MLP.
45
+ net_activation: Callable[..., Any] # MLP activation
46
+ skip_layer: int # How often to add skip connections.
47
+ num_rgb_channels: int # The number of RGB channels.
48
+ num_sigma_channels: int # The number of density channels.
49
+ white_bkgd: bool # If True, use a white background.
50
+ min_deg_point: int # The minimum degree of positional encoding for positions.
51
+ max_deg_point: int # The maximum degree of positional encoding for positions.
52
+ deg_view: int # The degree of positional encoding for viewdirs.
53
+ lindisp: bool # If True, sample linearly in disparity rather than in depth.
54
+ rgb_activation: Callable[..., Any] # Output RGB activation.
55
+ sigma_activation: Callable[..., Any] # Output sigma activation.
56
+ legacy_posenc_order: bool # Keep the same ordering as the original tf code.
57
+
58
+ @nn.compact
59
+ def __call__(self, rng_0, rng_1, rays, randomized, rgb_only = False):
60
+ """Nerf Model.
61
+
62
+ Args:
63
+ rng_0: jnp.ndarray, random number generator for coarse model sampling.
64
+ rng_1: jnp.ndarray, random number generator for fine model sampling.
65
+ rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
66
+ randomized: bool, use randomized stratified sampling.
67
+
68
+ Returns:
69
+ ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]
70
+ """
71
+ # Stratified sampling along rays
72
+ key, rng_0 = random.split(rng_0)
73
+ dtype = rays[0].dtype
74
+
75
+ z_vals, samples = model_utils.sample_along_rays(
76
+ key,
77
+ rays.origins,
78
+ rays.directions,
79
+ self.num_coarse_samples,
80
+ self.near,
81
+ self.far,
82
+ randomized,
83
+ self.lindisp,
84
+ )
85
+
86
+ samples_enc = model_utils.posenc(
87
+ samples,
88
+ self.min_deg_point,
89
+ self.max_deg_point,
90
+ self.legacy_posenc_order,
91
+ )
92
+
93
+ # Construct the "coarse" MLP.
94
+ coarse_mlp = model_utils.MLP(
95
+ net_depth=self.net_depth,
96
+ net_width=self.net_width,
97
+ net_depth_condition=self.net_depth_condition,
98
+ net_width_condition=self.net_width_condition,
99
+ net_activation=self.net_activation,
100
+ skip_layer=self.skip_layer,
101
+ num_rgb_channels=self.num_rgb_channels,
102
+ num_sigma_channels=self.num_sigma_channels)
103
+
104
+ # Point attribute predictions
105
+ if self.use_viewdirs:
106
+ viewdirs_enc = model_utils.posenc(
107
+ rays.viewdirs,
108
+ 0,
109
+ self.deg_view,
110
+ self.legacy_posenc_order,
111
+ )
112
+ raw_rgb, raw_sigma = coarse_mlp(samples_enc, viewdirs_enc)
113
+ else:
114
+ viewdirs_enc = None
115
+ raw_rgb, raw_sigma = coarse_mlp(samples_enc)
116
+ # Add noises to regularize the density predictions if needed
117
+ key, rng_0 = random.split(rng_0)
118
+ raw_sigma = model_utils.add_gaussian_noise(
119
+ key,
120
+ raw_sigma,
121
+ self.noise_std,
122
+ randomized,
123
+ )
124
+ rgb = self.rgb_activation(raw_rgb)
125
+ sigma = self.sigma_activation(raw_sigma)
126
+ # Volumetric rendering.
127
+ comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
128
+ rgb,
129
+ sigma,
130
+ z_vals,
131
+ rays.directions,
132
+ white_bkgd=self.white_bkgd,
133
+ )
134
+
135
+ ret = [
136
+ (comp_rgb, disp, acc),
137
+ ]
138
+
139
+ if self.num_fine_samples > 0 and not(rgb_only):
140
+ z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
141
+ key, rng_1 = random.split(rng_1)
142
+
143
+ z_vals, samples = model_utils.sample_pdf(
144
+ key,
145
+ z_vals_mid,
146
+ weights[..., 1:-1],
147
+ rays.origins,
148
+ rays.directions,
149
+ z_vals,
150
+ self.num_fine_samples,
151
+ randomized,
152
+ )
153
+ samples_enc = model_utils.posenc(
154
+ samples,
155
+ self.min_deg_point,
156
+ self.max_deg_point,
157
+ self.legacy_posenc_order,
158
+ )
159
+
160
+ # Construct the "fine" MLP.
161
+ fine_mlp = model_utils.MLP(
162
+ net_depth=self.net_depth,
163
+ net_width=self.net_width,
164
+ net_depth_condition=self.net_depth_condition,
165
+ net_width_condition=self.net_width_condition,
166
+ net_activation=self.net_activation,
167
+ skip_layer=self.skip_layer,
168
+ num_rgb_channels=self.num_rgb_channels,
169
+ num_sigma_channels=self.num_sigma_channels)
170
+
171
+ if self.use_viewdirs:
172
+ raw_rgb, raw_sigma = fine_mlp(samples_enc, viewdirs_enc)
173
+ else:
174
+ raw_rgb, raw_sigma = fine_mlp(samples_enc)
175
+ key, rng_1 = random.split(rng_1)
176
+ raw_sigma = model_utils.add_gaussian_noise(
177
+ key,
178
+ raw_sigma,
179
+ self.noise_std,
180
+ randomized,
181
+ )
182
+ rgb = self.rgb_activation(raw_rgb)
183
+ sigma = self.sigma_activation(raw_sigma)
184
+
185
+ comp_rgb, disp, acc, unused_weights = model_utils.volumetric_rendering(
186
+ rgb,
187
+ sigma,
188
+ z_vals,
189
+ rays.directions,
190
+ white_bkgd=self.white_bkgd,
191
+ )
192
+ ret.append((comp_rgb, disp, acc))
193
+ if rgb_only:
194
+ #return [ret[0][0], ret[1][0]]
195
+ return [None, ret[0][0]]
196
+ return ret
197
+
198
+ def construct_nerf(key, example_batch, args):
199
+ """Construct a Neural Radiance Field.
200
+
201
+ Args:
202
+ key: jnp.ndarray. Random number generator.
203
+ example_batch: dict, an example of a batch of data.
204
+ args: FLAGS class. Hyperparameters of nerf.
205
+
206
+ Returns:
207
+ model: nn.Model. Nerf model with parameters.
208
+ state: flax.Module.state. Nerf model state for stateful parameters.
209
+ """
210
+ net_activation = getattr(nn, str(args.net_activation))
211
+ rgb_activation = getattr(nn, str(args.rgb_activation))
212
+ sigma_activation = getattr(nn, str(args.sigma_activation))
213
+
214
+ # Assert that rgb_activation always produces outputs in [0, 1], and
215
+ # sigma_activation always produce non-negative outputs.
216
+ x = jnp.exp(jnp.linspace(-90, 90, 1024))
217
+ x = jnp.concatenate([-x[::-1], x], 0)
218
+
219
+ rgb = rgb_activation(x)
220
+ if jnp.any(rgb < 0) or jnp.any(rgb > 1):
221
+ raise NotImplementedError(
222
+ "Choice of rgb_activation `{}` produces colors outside of [0, 1]"
223
+ .format(args.rgb_activation))
224
+
225
+ sigma = sigma_activation(x)
226
+ if jnp.any(sigma < 0):
227
+ raise NotImplementedError(
228
+ "Choice of sigma_activation `{}` produces negative densities".format(
229
+ args.sigma_activation))
230
+
231
+ model = NerfModel(
232
+ min_deg_point=args.min_deg_point,
233
+ max_deg_point=args.max_deg_point,
234
+ deg_view=args.deg_view,
235
+ num_coarse_samples=args.num_coarse_samples,
236
+ num_fine_samples=args.num_fine_samples,
237
+ use_viewdirs=args.use_viewdirs,
238
+ near=args.near,
239
+ far=args.far,
240
+ noise_std=args.noise_std,
241
+ white_bkgd=args.white_bkgd,
242
+ net_depth=args.net_depth,
243
+ net_width=args.net_width,
244
+ net_depth_condition=args.net_depth_condition,
245
+ net_width_condition=args.net_width_condition,
246
+ skip_layer=args.skip_layer,
247
+ num_rgb_channels=args.num_rgb_channels,
248
+ num_sigma_channels=args.num_sigma_channels,
249
+ lindisp=args.lindisp,
250
+ net_activation=net_activation,
251
+ rgb_activation=rgb_activation,
252
+ sigma_activation=sigma_activation,
253
+ legacy_posenc_order=args.legacy_posenc_order)
254
+ rays = example_batch["rays"]
255
+ key1, key2, key3 = random.split(key, num=3)
256
+
257
+ init_variables = model.init(
258
+ key1,
259
+ rng_0=key2,
260
+ rng_1=key3,
261
+ rays=utils.namedtuple_map(lambda x: x[0], rays),
262
+ randomized=args.randomized)
263
+
264
+ return model, init_variables
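
A sketch of building the model directly with hand-picked hyperparameters instead of FLAGS, then initializing and applying it on a toy batch of rays. All values below are illustrative assumptions, not the repository's configuration:

    import jax.numpy as jnp
    from flax import linen as nn
    from jax import random
    from nerf import models, utils

    model = models.NerfModel(
        num_coarse_samples=64, num_fine_samples=128, use_viewdirs=True,
        near=2., far=6., noise_std=None, net_depth=8, net_width=256,
        net_depth_condition=1, net_width_condition=128, net_activation=nn.relu,
        skip_layer=4, num_rgb_channels=3, num_sigma_channels=1, white_bkgd=True,
        min_deg_point=0, max_deg_point=10, deg_view=4, lindisp=False,
        rgb_activation=nn.sigmoid, sigma_activation=nn.relu,
        legacy_posenc_order=False)
    rays = utils.Rays(origins=jnp.zeros((16, 3)),
                      directions=jnp.tile(jnp.array([[0., 0., 1.]]), (16, 1)),
                      viewdirs=jnp.tile(jnp.array([[0., 0., 1.]]), (16, 1)))
    key, key_0, key_1 = random.split(random.PRNGKey(0), 3)
    variables = model.init(key, key_0, key_1, rays, randomized=True)
    (rgb_c, disp_c, acc_c), (rgb, disp, acc) = model.apply(
        variables, key_0, key_1, rays, randomized=True)
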
nerf/utils.py ADDED
@@ -0,0 +1,454 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Utility functions."""
18
+ import collections
19
+ import os
20
+ from os import path
21
+ import pickle
22
+ from absl import flags
23
+ import flax
24
+ import jax
25
+ import jax.numpy as jnp
26
+ import jax.scipy as jsp
27
+ import numpy as np
28
+ from PIL import Image
29
+ import yaml
30
+ from nerf import datasets
31
+
32
+ BASE_DIR = ""
33
+ INTERNAL = False
34
+
35
+
36
+ @flax.struct.dataclass
37
+ class TrainState:
38
+ optimizer: flax.optim.Optimizer
39
+
40
+
41
+ @flax.struct.dataclass
42
+ class Stats:
43
+ loss: float
44
+ psnr: float
45
+ loss_c: float
46
+ psnr_c: float
47
+ weight_l2: float
48
+
49
+
50
+ Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))
51
+
52
+
53
+ def namedtuple_map(fn, tup):
54
+ """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
55
+ return type(tup)(*map(fn, tup))
56
+
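
A small sketch (illustrative, assuming the `nerf` package is importable): namedtuple_map applies a function to every field while preserving the Rays type, which is how rays are reshaped and chunked throughout this file:

    import jax.numpy as jnp
    from nerf import utils

    rays = utils.Rays(origins=jnp.zeros((4, 3)),
                      directions=jnp.ones((4, 3)),
                      viewdirs=jnp.ones((4, 3)))
    flat = utils.namedtuple_map(lambda r: r.reshape(-1), rays)
    # flat is still a Rays namedtuple; each field now has shape (12,).
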
57
+
58
+ def define_flags():
59
+ """Define flags for both training and evaluation modes."""
60
+ flags.DEFINE_string("train_dir", None, "where to store ckpts and logs")
61
+ flags.DEFINE_string("data_dir", None, "input data directory.")
62
+ flags.DEFINE_string("config", None,
63
+ "using config files to set hyperparameters.")
64
+
65
+ # CLIP part Flags
66
+ flags.DEFINE_bool("use_semantic_loss", True,
67
+ "whether use semantic loss or not")
68
+ flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32", "model type for CLIP")
69
+ flags.DEFINE_string("clip_output_dtype", "float32",
70
+ "float32/ float16 (float16 for memory saving)")
71
+ flags.DEFINE_integer("sc_loss_every", 16,
72
+ "no. of steps to take before performing semantic loss evaluation")
73
+ flags.DEFINE_float("sc_loss_mult", 1e-3,
74
+ "weighting for semantic loss from CLIP")
75
+
76
+ # Dataset Flags
77
+ # TODO(pratuls): rename to dataset_loader and consider cleaning up
78
+ flags.DEFINE_enum("dataset", "blender",
79
+ list(k for k in datasets.dataset_dict.keys()),
80
+ "The type of dataset feed to nerf.")
81
+ flags.DEFINE_enum(
82
+ "batching", "single_image", ["single_image", "all_images"],
83
+ "source of ray sampling when collecting training batch,"
84
+ "single_image for sampling from only one image in a batch,"
85
+ "all_images for sampling from all the training images.")
86
+ flags.DEFINE_bool(
87
+ "white_bkgd", True, "using white color as default background."
88
+ "(used in the blender dataset only)")
89
+ flags.DEFINE_integer("batch_size", 1024,
90
+ "the number of rays in a mini-batch (for training).")
91
+ flags.DEFINE_integer("factor", 4,
92
+ "the downsample factor of images, 0 for no downsample.")
93
+ flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
94
+ flags.DEFINE_bool(
95
+ "render_path", False, "render generated path if set true."
96
+ "(used in the llff dataset only)")
97
+ flags.DEFINE_integer(
98
+ "llffhold", 8, "will take every 1/N images as LLFF test set."
99
+ "(used in the llff dataset only)")
100
+ flags.DEFINE_bool(
101
+ "use_pixel_centers", False,
102
+ "If True, generate rays through the center of each pixel. Note: While "
103
+ "this is the correct way to handle rays, it is not the way rays are "
104
+ "handled in the original NeRF paper. Setting this TRUE yields ~ +1 PSNR "
105
+ "compared to Vanilla NeRF.")
106
+
107
+ # Model Flags
108
+ flags.DEFINE_string("model", "nerf", "name of model to use.")
109
+ flags.DEFINE_float("near", 2., "near clip of volumetric rendering.")
110
+ flags.DEFINE_float("far", 6., "far clip of volumentric rendering.")
111
+ flags.DEFINE_integer("net_depth", 8, "depth of the first part of MLP.")
112
+ flags.DEFINE_integer("net_width", 256, "width of the first part of MLP.")
113
+ flags.DEFINE_integer("net_depth_condition", 1,
114
+ "depth of the second part of MLP.")
115
+ flags.DEFINE_integer("net_width_condition", 128,
116
+ "width of the second part of MLP.")
117
+ flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay")
118
+ flags.DEFINE_integer(
119
+ "skip_layer", 4, "add a skip connection to the output vector of every"
120
+ "skip_layer layers.")
121
+ flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.")
122
+ flags.DEFINE_integer("num_sigma_channels", 1,
123
+ "the number of density channels.")
124
+ flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
125
+ flags.DEFINE_integer("min_deg_point", 0,
126
+ "Minimum degree of positional encoding for points.")
127
+ flags.DEFINE_integer("max_deg_point", 10,
128
+ "Maximum degree of positional encoding for points.")
129
+ flags.DEFINE_integer("deg_view", 4,
130
+ "Degree of positional encoding for viewdirs.")
131
+ flags.DEFINE_integer(
132
+ "num_coarse_samples", 64,
133
+ "the number of samples on each ray for the coarse model.")
134
+ flags.DEFINE_integer("num_fine_samples", 128,
135
+ "the number of samples on each ray for the fine model.")
136
+ flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
137
+ flags.DEFINE_float(
138
+ "noise_std", None, "std dev of noise added to regularize sigma output."
139
+ "(used in the llff dataset only)")
140
+ flags.DEFINE_bool("lindisp", False,
141
+ "sampling linearly in disparity rather than depth.")
142
+ flags.DEFINE_string("net_activation", "relu",
143
+ "activation function used within the MLP.")
144
+ flags.DEFINE_string("rgb_activation", "sigmoid",
145
+ "activation function used to produce RGB.")
146
+ flags.DEFINE_string("sigma_activation", "relu",
147
+ "activation function used to produce density.")
148
+ flags.DEFINE_bool(
149
+ "legacy_posenc_order", False,
150
+ "If True, revert the positional encoding feature order to an older version of this codebase."
151
+ )
152
+
153
+ # Train Flags
154
+ flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.")
155
+ flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.")
156
+ flags.DEFINE_integer(
157
+ "lr_delay_steps", 0, "The number of steps at the beginning of "
158
+ "training to reduce the learning rate by lr_delay_mult")
159
+ flags.DEFINE_float(
160
+ "lr_delay_mult", 1., "A multiplier on the learning rate when the step "
161
+ "is < lr_delay_steps")
162
+ flags.DEFINE_float("grad_max_norm", 0.,
163
+ "The gradient clipping magnitude (disabled if == 0).")
164
+ flags.DEFINE_float("grad_max_val", 0.,
165
+ "The gradient clipping value (disabled if == 0).")
166
+
167
+ flags.DEFINE_integer("max_steps", 1000000,
168
+ "the number of optimization steps.")
169
+ flags.DEFINE_integer("save_every", 10000,
170
+ "the number of steps to save a checkpoint.")
171
+ flags.DEFINE_integer("print_every", 100,
172
+ "the number of steps between reports to tensorboard.")
173
+ flags.DEFINE_integer(
174
+ "render_every", 5000, "the number of steps to render a test image,"
175
+ "better to be x00 for accurate step time record.")
176
+ flags.DEFINE_integer("gc_every", 10000,
177
+ "the number of steps to run python garbage collection.")
178
+ flags.DEFINE_integer("few_shot", -1,
179
+ "the number of images.")
180
+
181
+ # Eval Flags
182
+ flags.DEFINE_bool(
183
+ "eval_once", True,
184
+ "evaluate the model only once if true, otherwise keeping evaluating new"
185
+ "checkpoints if there's any.")
186
+ flags.DEFINE_bool("save_output", True,
187
+ "save predicted images to disk if True.")
188
+ flags.DEFINE_integer(
189
+ "chunk", 1024,
190
+ "the size of chunks for evaluation inferences, set to the value that"
191
+ "fits your GPU/TPU memory.")
192
+ flags.DEFINE_bool("generate_gif_only", False,
193
+ "in eval.py, we only generate GIF file for the trained model")
194
+
195
+
196
+ def update_flags(args):
197
+ """Update the flags in `args` with the contents of the config YAML file."""
198
+ pth = path.join(BASE_DIR, args.config + ".yaml")
199
+ with open_file(pth, "r") as fin:
200
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
201
+ # Only allow args to be updated if they already exist.
202
+ invalid_args = list(set(configs.keys()) - set(dir(args)))
203
+ if invalid_args:
204
+ raise ValueError(f"Invalid args {invalid_args} in {pth}.")
205
+ args.__dict__.update(configs)
206
+
207
+ def open_file(pth, mode="r"):
208
+ if not INTERNAL:
209
+ return open(pth, mode=mode)
210
+
211
+
212
+ def file_exists(pth):
213
+ if not INTERNAL:
214
+ return path.exists(pth)
215
+
216
+
217
+ def listdir(pth):
218
+ if not INTERNAL:
219
+ return os.listdir(pth)
220
+
221
+
222
+ def isdir(pth):
223
+ if not INTERNAL:
224
+ return path.isdir(pth)
225
+
226
+
227
+ def makedirs(pth):
228
+ if not INTERNAL:
229
+ os.makedirs(pth)
230
+
231
+
232
+ def render_image(render_fn, rays, rng, normalize_disp, chunk=8192):
233
+ """Render all the pixels of an image (in test mode).
234
+
235
+ Args:
236
+ render_fn: function, jit-ed render function.
237
+ rays: a `Rays` namedtuple, the rays to be rendered.
238
+ rng: jnp.ndarray, random number generator (used in training mode only).
239
+ normalize_disp: bool, if true then normalize `disp` to [0, 1].
240
+ chunk: int, the size of chunks to render sequentially.
241
+
242
+ Returns:
243
+ rgb: jnp.ndarray, rendered color image.
244
+ disp: jnp.ndarray, rendered disparity image.
245
+ acc: jnp.ndarray, rendered accumulated weights per pixel.
246
+ """
247
+ height, width = rays[0].shape[:2]
248
+ num_rays = height * width
249
+ rays = namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays)
250
+ unused_rng, key_0, key_1 = jax.random.split(rng, 3)
251
+ host_id = jax.host_id()
252
+ results = []
253
+ for i in range(0, num_rays, chunk):
254
+ # pylint: disable=cell-var-from-loop
255
+ chunk_rays = namedtuple_map(lambda r: r[i:i + chunk], rays)
256
+ chunk_size = chunk_rays[0].shape[0]
257
+ rays_remaining = chunk_size % jax.device_count()
258
+ if rays_remaining != 0:
259
+ padding = jax.device_count() - rays_remaining
260
+ chunk_rays = namedtuple_map(
261
+ lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"), chunk_rays)
262
+ else:
263
+ padding = 0
264
+ # After padding the number of chunk_rays is always divisible by
265
+ # host_count.
266
+ rays_per_host = chunk_rays[0].shape[0] // jax.process_count()
267
+ start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host
268
+ chunk_rays = namedtuple_map(lambda r: shard(r[start:stop]), chunk_rays)
269
+ chunk_results = render_fn(key_0, key_1, chunk_rays)[-1]
270
+ results.append([unshard(x, padding) for x in chunk_results])
271
+ # pylint: enable=cell-var-from-loop
272
+ rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)]
273
+ # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
274
+ if normalize_disp:
275
+ disp = (disp - disp.min()) / (disp.max() - disp.min())
276
+ return (rgb.reshape((height, width, -1)), disp.reshape(
277
+ (height, width, -1)), acc.reshape((height, width, -1)))
278
+
279
+
280
+ def compute_psnr(mse):
281
+ """Compute psnr value given mse (we assume the maximum pixel value is 1).
282
+
283
+ Args:
284
+ mse: float, mean square error of pixels.
285
+
286
+ Returns:
287
+ psnr: float, the psnr value.
288
+ """
289
+ return -10. * jnp.log(mse) / jnp.log(10.)
290
+
291
+
292
+ def compute_ssim(img0,
293
+ img1,
294
+ max_val,
295
+ filter_size=11,
296
+ filter_sigma=1.5,
297
+ k1=0.01,
298
+ k2=0.03,
299
+ return_map=False):
300
+ """Computes SSIM from two images.
301
+
302
+ This function was modeled after tf.image.ssim, and should produce comparable
303
+ output.
304
+
305
+ Args:
306
+ img0: array. An image of size [..., width, height, num_channels].
307
+ img1: array. An image of size [..., width, height, num_channels].
308
+ max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
309
+ filter_size: int >= 1. Window size.
310
+ filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
311
+ k1: float > 0. One of the SSIM dampening parameters.
312
+ k2: float > 0. One of the SSIM dampening parameters.
313
+ return_map: Bool. If True, the per-pixel SSIM "map" will be returned.
314
+
315
+ Returns:
316
+ Each image's mean SSIM, or a tensor of individual values if `return_map`.
317
+ """
318
+ # Construct a 1D Gaussian blur filter.
319
+ hw = filter_size // 2
320
+ shift = (2 * hw - filter_size + 1) / 2
321
+ f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma) ** 2
322
+ filt = jnp.exp(-0.5 * f_i)
323
+ filt /= jnp.sum(filt)
324
+
325
+ # Blur in x and y (faster than the 2D convolution).
326
+ filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
327
+ filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")
328
+
329
+ # Vmap the blurs to the tensor size, and then compose them.
330
+ num_dims = len(img0.shape)
331
+ map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
332
+ for d in map_axes:
333
+ filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
334
+ filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
335
+ filt_fn = lambda z: filt_fn1(filt_fn2(z))
336
+
337
+ mu0 = filt_fn(img0)
338
+ mu1 = filt_fn(img1)
339
+ mu00 = mu0 * mu0
340
+ mu11 = mu1 * mu1
341
+ mu01 = mu0 * mu1
342
+ sigma00 = filt_fn(img0 ** 2) - mu00
343
+ sigma11 = filt_fn(img1 ** 2) - mu11
344
+ sigma01 = filt_fn(img0 * img1) - mu01
345
+
346
+ # Clip the variances and covariances to valid values.
347
+ # Variance must be non-negative:
348
+ sigma00 = jnp.maximum(0., sigma00)
349
+ sigma11 = jnp.maximum(0., sigma11)
350
+ sigma01 = jnp.sign(sigma01) * jnp.minimum(
351
+ jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))
352
+
353
+ c1 = (k1 * max_val) ** 2
354
+ c2 = (k2 * max_val) ** 2
355
+ numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
356
+ denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
357
+ ssim_map = numer / denom
358
+ ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
359
+ return ssim_map if return_map else ssim
360
+
361
+
362
+ def save_img(img, pth):
363
+ """Save an image to disk.
364
+
365
+ Args:
366
+ img: jnp.ndarray, [height, width, channels], img will be clipped to [0, 1]
367
+ before saved to pth.
368
+ pth: string, path to save the image to.
369
+ """
370
+ with open_file(pth, "wb") as imgout:
371
+ Image.fromarray(np.array(
372
+ (np.clip(img, 0., 1.) * 255.).astype(jnp.uint8))).save(imgout, "PNG")
373
+
374
+
375
+ def learning_rate_decay(step,
376
+ lr_init,
377
+ lr_final,
378
+ max_steps,
379
+ lr_delay_steps=0,
380
+ lr_delay_mult=1):
381
+ """Continuous learning rate decay function.
382
+
383
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
384
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
385
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
386
+ function of lr_delay_mult, such that the initial learning rate is
387
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
388
+ to the normal learning rate when steps>lr_delay_steps.
389
+
390
+ Args:
391
+ step: int, the current optimization step.
392
+ lr_init: float, the initial learning rate.
393
+ lr_final: float, the final learning rate.
394
+ max_steps: int, the number of steps during optimization.
395
+ lr_delay_steps: int, the number of steps to delay the full learning rate.
396
+ lr_delay_mult: float, the multiplier on the rate when delaying it.
397
+
398
+ Returns:
399
+ lr: the learning for current step 'step'.
400
+ """
401
+ if lr_delay_steps > 0:
402
+ # A kind of reverse cosine decay.
403
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
404
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
405
+ else:
406
+ delay_rate = 1.
407
+ t = np.clip(step / max_steps, 0, 1)
408
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
409
+ return delay_rate * log_lerp
410
+
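
A short sketch (illustrative, assuming the `nerf` package is importable): with the default flag values above (lr_init=5e-4, lr_final=5e-6, max_steps=1e6, no delay) the schedule is log-linear, so the midpoint is the geometric mean of the endpoints:

    from nerf import utils

    lr_start = utils.learning_rate_decay(0, 5e-4, 5e-6, max_steps=1000000)
    lr_mid = utils.learning_rate_decay(500000, 5e-4, 5e-6, max_steps=1000000)
    # lr_start ~= 5e-4 and lr_mid ~= sqrt(5e-4 * 5e-6) = 5e-5;
    # at step == max_steps the rate reaches 5e-6.
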
411
+
412
+ def shard(xs):
413
+ """Split data into shards for multiple devices along the first dimension."""
414
+ '''
415
+ if 'embedding' in xs:
416
+ xs['pixels'] = jax.tree_map(lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs['pixels'])
417
+ xs['rays'] = jax.tree_map(lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs['rays'])
418
+ xs['embedding'] = np.stack([xs['embedding']]*jax.local_device_count(),0)
419
+ xs['random_rays'] = jax.tree_map(lambda x: np.stack([x]*jax.local_device_count(),0), xs['random_rays'])
420
+ else:
421
+ xs = jax.tree_map(
422
+ lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]) if len(x.shape) != 0 else x
423
+ , xs)
424
+
425
+ return xs
426
+ '''
427
+ return jax.tree_map(
428
+ lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]) if len(x.shape) != 0 else x
429
+ , xs)
430
+
431
+
432
+ def to_device(xs):
433
+ """Transfer data to devices (GPU/TPU)."""
434
+ return jax.tree_map(jnp.array, xs)
435
+
436
+
437
+ def unshard(x, padding=0):
438
+ """Collect the sharded tensor to the shape before sharding."""
439
+ y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:]))
440
+ if padding > 0:
441
+ y = y[:-padding]
442
+ return y
443
+
444
+
445
+ def write_pickle(data, fn):
446
+ with open(fn, 'wb') as f:
447
+ pickle.dump(data, f)
448
+ return None
449
+
450
+
451
+ def read_pickle(fn):
452
+ with open(fn, 'rb') as f:
453
+ data = pickle.load(f)
454
+ return data
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ numpy>=1.16.4
2
+ jax>=0.2.6
3
+ jaxlib>=0.1.57
4
+ flax>=0.2.2
5
+ opencv-python>=4.4.0
6
+ Pillow>=7.2.0
7
+ pyyaml>=5.3.1
8
+ tensorboard>=2.4.0
9
+ tensorflow>=2.3.1
10
+ tensorflow-hub>=0.11.0
11
+ transformers==4.8.2
12
+ wandb==0.10.33
13
+ tqdm==4.61.2
14
+ # pip install git+https://github.com/deepmind/jmp # mixed precision for JAX
run.sh ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright 2021 The Google Research Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ #!/bin/bash
16
+ set -e
17
+ set -x
18
+
19
+ virtualenv -p python3 .
20
+ source ./bin/activate
21
+
22
+ pip install -r jaxnerf/requirements.txt
23
+ pip uninstall jax
24
+ pip install --upgrade pip
25
+ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
26
+ python -m jaxnerf.train \
27
+ --data_dir=/mnt/data/NeRF_Data/nerf_synthetic/lego \
28
+ --train_dir=test_output \
29
+ --max_steps=5 \
30
+ --factor=2 \
31
+ --batch_size=512 \
32
+ --config=configs/orig_nerf_tpu_vm_test \
33
+ --precompute_pkl_path /mnt/data/NeRF_Data/nerf_synthetic/lego/clip_cache_train_factor4_float32.pkl
train.py ADDED
@@ -0,0 +1,347 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Training script for Nerf."""
18
+ import functools
19
+ import gc
20
+ import time
21
+ from absl import app
22
+ from absl import flags
23
+ import flax
24
+ from flax.metrics import tensorboard
25
+ from flax.training import checkpoints
26
+ import jax
27
+ from jax import config
28
+ from jax import random
29
+ import jax.numpy as jnp
30
+ import numpy as np
31
+ # import wandb
32
+ from tqdm import tqdm
33
+
34
+ from nerf import datasets
35
+ from nerf import models
36
+ from nerf import utils
37
+ from nerf import clip_utils
38
+
39
+ FLAGS = flags.FLAGS
40
+
41
+ utils.define_flags()
42
+ config.parse_flags_with_absl()
43
+
44
+ # set up TPU for colab
45
+ import os
46
+ if "COLAB_TPU_ADDR" in os.environ:
47
+ import jax.tools.colab_tpu
48
+ jax.tools.colab_tpu.setup_tpu()
49
+ print(f"detected device: {jax.local_devices()}")
50
+
51
+
52
+ def train_step(model, clip_model, rng, state, batch, lr, step, K,):
53
+ # TODO: enable passing clip_grad as an input.
54
+ """One optimization step.
55
+
56
+ Args:
57
+ model: The linen model.
58
+ rng: jnp.ndarray, random number generator.
59
+ state: utils.TrainState, state of the model/optimizer.
60
+ batch: dict, a mini-batch of data for training.
61
+ lr: float, real-time learning rate.
62
+
63
+ Returns:
64
+ grad: the gradient pytree for the current mini-batch (the optimizer update is applied in update_step).
65
+ stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)].
66
+ rng: jnp.ndarray, updated random number generator.
67
+ """
68
+ rng, key_0, key_1 = random.split(rng, 3)
69
+
70
+ def loss_fn(variables):
71
+ rays = batch["rays"]
72
+ ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
73
+ if len(ret) not in (1, 2):
74
+ raise ValueError(
75
+ "ret should contain either 1 set of output (coarse only), or 2 sets"
76
+ "of output (coarse as ret[0] and fine as ret[1]).")
77
+ # The main prediction is always at the end of the ret list.
78
+ rgb, unused_disp, unused_acc = ret[-1]
79
+ loss = ((rgb - batch["pixels"][Ellipsis, :3]) ** 2).mean()
80
+ psnr = utils.compute_psnr(loss)
81
+ if len(ret) > 1:
82
+ # If there are both coarse and fine predictions, we compute the loss for
83
+ # the coarse prediction (ret[0]) as well.
84
+ rgb_c, unused_disp_c, unused_acc_c = ret[0]
85
+ loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3]) ** 2).mean()
86
+ psnr_c = utils.compute_psnr(loss_c)
87
+ else:
88
+ loss_c = 0.
89
+ psnr_c = 0.
90
+
91
+ def tree_sum_fn(fn):
92
+ return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
93
+ variables, initializer=0)
94
+
95
+ weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z ** 2)) /
96
+ tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))
97
+
98
+ total_loss = loss + loss_c + FLAGS.weight_decay_mult * weight_l2
99
+ stats = utils.Stats(loss=loss, psnr=psnr, loss_c=loss_c,
100
+ psnr_c=psnr_c, weight_l2=weight_l2)
101
+ return total_loss, stats
102
+
103
+ (_, stats), grad = (
104
+ jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
105
+ #grad = jax.lax.pmean(grad, axis_name="batch")
106
+ stats = jax.lax.pmean(stats, axis_name="batch")
107
+
108
+ # Clip the gradient by value.
109
+ if FLAGS.grad_max_val > 0:
110
+ clip_fn = lambda z: jnp.clip(z, -FLAGS.grad_max_val, FLAGS.grad_max_val)
111
+ grad = jax.tree_util.tree_map(clip_fn, grad)
112
+
113
+ # Clip the (possibly value-clipped) gradient by norm.
114
+ if FLAGS.grad_max_norm > 0:
115
+ grad_norm = jnp.sqrt(
116
+ jax.tree_util.tree_reduce(
117
+ lambda x, y: x + jnp.sum(y ** 2), grad, initializer=0))
118
+ mult = jnp.minimum(1, FLAGS.grad_max_norm / (1e-7 + grad_norm))
119
+ grad = jax.tree_util.tree_map(lambda z: mult * z, grad)
120
+
121
+ return grad, stats, rng
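+ # NOTE: the lines below are never reached after this return; the pmean and
+ # optimizer update are instead applied in update_step below.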
122
+ new_optimizer = state.optimizer.apply_gradient(grad, learning_rate =lr)
123
+ new_state = state.replace(optimizer=new_optimizer)
124
+ return new_state, stats, rng
125
+
126
+ def update_step(state, grad, lr):
127
+ grad = jax.lax.pmean(grad, axis_name="batch")
128
+ new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
129
+ new_state = state.replace(optimizer=new_optimizer)
130
+ return new_state
131
+
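
A toy sketch of the pattern used above: per-device gradients from one pmapped function, averaged with lax.pmean and applied in a second pmapped step. The names below are chosen for illustration only and are not part of this codebase:

    import jax
    import jax.numpy as jnp

    def toy_grad(w, x):
        return jax.grad(lambda w_: jnp.mean((w_ * x) ** 2))(w)

    def toy_update(w, g, lr):
        g = jax.lax.pmean(g, axis_name="batch")
        return w - lr * g

    n_dev = jax.local_device_count()
    xs = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
    ws = jnp.ones((n_dev,))
    grads = jax.pmap(toy_grad, axis_name="batch")(ws, xs)
    ws = jax.pmap(toy_update, axis_name="batch", in_axes=(0, 0, None))(ws, grads, 0.1)
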
132
+ def main(unused_argv):
133
+ #wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
134
+ rng = random.PRNGKey(20200823)
135
+ # Shift the numpy random seed by host_id() to shuffle data loaded by different
136
+ # hosts.
137
+ np.random.seed(20201473 + jax.host_id())
138
+
139
+ if FLAGS.config is not None:
140
+ utils.update_flags(FLAGS)
141
+ if FLAGS.batch_size % jax.device_count() != 0:
142
+ raise ValueError("Batch size must be divisible by the number of devices.")
143
+ if FLAGS.train_dir is None:
144
+ raise ValueError("train_dir must be set. None set now.")
145
+ if FLAGS.data_dir is None:
146
+ raise ValueError("data_dir must be set. None set now.")
147
+
148
+ # setup CLIP model
149
+ if FLAGS.use_semantic_loss:
150
+ clip_model = clip_utils.init_CLIP(FLAGS.clip_output_dtype,
151
+ FLAGS.clip_model_name)
152
+ print(f'semantic loss ACTIVATED, CLIP is set up '
153
+ f'(sc_loss_mult: {FLAGS.sc_loss_mult})')
154
+ else:
155
+ clip_model = None
156
+ print('semantic loss DEACTIVATED, CLIP is set to None')
157
+
158
+ dataset = datasets.get_dataset("train", FLAGS, clip_model)
159
+ test_dataset = datasets.get_dataset("test", FLAGS, clip_model)
160
+
161
+ # setup NeRF model
162
+ rng, key = random.split(rng)
163
+ model, variables = models.get_model(key, dataset.peek(), FLAGS)
164
+ optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
165
+ state = utils.TrainState(optimizer=optimizer)
166
+ del optimizer, variables
167
+ learning_rate_fn = functools.partial(
168
+ utils.learning_rate_decay,
169
+ lr_init=FLAGS.lr_init,
170
+ lr_final=FLAGS.lr_final,
171
+ max_steps=FLAGS.max_steps,
172
+ lr_delay_steps=FLAGS.lr_delay_steps,
173
+ lr_delay_mult=FLAGS.lr_delay_mult)
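# Sketch (not part of this change): utils.learning_rate_decay is assumed to follow
# the standard jaxnerf schedule -- log-linear interpolation from lr_init to
# lr_final over max_steps, optionally scaled by a sine warm-up over lr_delay_steps.
def _lr_decay_sketch(step, lr_init, lr_final, max_steps,
                     lr_delay_steps=0, lr_delay_mult=1.0):
    if lr_delay_steps > 0:
        delay = lr_delay_mult + (1.0 - lr_delay_mult) * np.sin(
            0.5 * np.pi * np.clip(step / lr_delay_steps, 0.0, 1.0))
    else:
        delay = 1.0
    t = np.clip(step / max_steps, 0.0, 1.0)
    return delay * np.exp(np.log(lr_init) * (1.0 - t) + np.log(lr_final) * t)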
174
+
175
+ train_pstep = jax.pmap(
176
+ functools.partial(train_step, model, clip_model),
177
+ axis_name="batch",
178
+ in_axes=(0, 0, 0, None, None, None),
179
+ donate_argnums=(2,))
180
+
181
+ update_pstep = jax.pmap(
182
+ functools.partial(update_step,),
183
+ axis_name="batch",
184
+ in_axes=(0, 0, None),
185
+ donate_argnums=(0,))
186
+
187
+ def render_fn(variables, key_0, key_1, rays):
188
+ return model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
189
+
190
+ render_pfn = jax.pmap(
191
+ render_fn,
192
+ in_axes=(None, None, None, 0), # Only distribute the data input.
193
+ donate_argnums=(3,),
194
+ axis_name="batch")
195
+
196
+ def render_fn_(variables, key_0, key_1, rays):
197
+ return model.apply(variables, key_0, key_1, rays, False, True)
198
+
199
+ render_pfn_ = jax.pmap(
200
+ render_fn_,
201
+ in_axes=(None, None, None, 0), # Only distribute the data input.
202
+ donate_argnums=(3,),
203
+ axis_name="batch")
204
+
205
+ # Compiling to the CPU because it's faster and more accurate.
206
+ ssim_fn = jax.jit(
207
+ functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")
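# Sketch (not part of this change): backend="cpu" pins the jitted SSIM computation
# to the host CPU even when a TPU/GPU is the default backend, e.g.:
_cpu_sum = jax.jit(jnp.sum, backend="cpu")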
208
+
209
+ if not utils.isdir(FLAGS.train_dir):
210
+ utils.makedirs(FLAGS.train_dir)
211
+ state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
212
+ # Resume training at the step of the last checkpoint.
213
+ init_step = state.optimizer.state.step + 1
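# Sketch (not part of this change): restore_checkpoint returns `state` unchanged
# when train_dir holds no checkpoint, so init_step is 1 on a fresh run. The
# matching save/restore round trip looks like this (the path is a placeholder):
#   checkpoints.save_checkpoint("/tmp/jaxnerf/demo", state, step=1000, keep=100)
#   state = checkpoints.restore_checkpoint("/tmp/jaxnerf/demo", state)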
214
+
215
+ # for distributed training
216
+ state = flax.jax_utils.replicate(state)
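# Sketch (not part of this change): flax.jax_utils.replicate adds a leading device
# axis to every leaf of the state, which is what the in_axes=0 arguments of
# train_pstep / update_pstep expect; in_axes=None arguments (lr, step, K) are
# broadcast to every device instead.
_replicated_example = flax.jax_utils.replicate({"w": jnp.zeros(3)})
# each leaf of _replicated_example now has shape (jax.local_device_count(), 3)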
217
+ if jax.host_id() == 0:
218
+ summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
219
+
220
+ # Prefetch_buffer_size = 3 x batch_size
221
+ pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
222
+ n_local_devices = jax.local_device_count()
223
+ rng = rng + jax.host_id() # Make random seed separate across hosts.
224
+ keys = random.split(rng, n_local_devices) # For pmapping RNG keys.
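# Sketch (not part of this change): with the legacy PRNG key format, random.split
# returns an array of shape (n_local_devices, 2) -- one key per device -- matching
# in_axes=0 for `keys` in train_pstep.
_example_keys = random.split(random.PRNGKey(0), 8)  # shape (8, 2)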
225
+ gc.disable() # Disable automatic garbage collection for efficiency.
226
+ stats_trace = []
227
+ reset_timer = True
228
+
229
+ # for semantic loss update
230
+ sc_image = None
231
+ sc_loss = 0.
232
+
233
+ for step, batch in tqdm(zip(range(init_step, FLAGS.max_steps + 1), pdataset)):
234
+ if reset_timer:
235
+ t_loop_start = time.time()
236
+ reset_timer = False
237
+ lr = learning_rate_fn(step)
238
+
239
+ grad, stats, keys = train_pstep(keys, state, batch, lr, step, FLAGS.sc_loss_every)
240
+
241
+ if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
242
+ sc_batch = dataset.get_clip_data()
243
+ if jax.local_device_count() > 1:
244
+ sc_loss, sc_grad, sc_image = clip_utils.semantic_step_multi(render_pfn_, clip_model, keys[0], state, sc_batch, lr)
245
+ else:
246
+ sc_loss, sc_grad, sc_image = clip_utils.semantic_step_single(model, clip_model, keys[0], state, sc_batch, lr)
247
+
248
+ if jax.host_id() == 0 and step % FLAGS.print_every == 0:
249
+ for mlp_k, mlp in grad['params'].items():
250
+ for layer_k, layer_g in mlp.items():
251
+ summary_writer.scalar("%s/%s/kernel_grad"%(mlp_k, layer_k), jnp.linalg.norm(jnp.mean(layer_g['kernel'],0)), step)
252
+ for mlp_k, mlp in sc_grad['params'].items():
253
+ for layer_k, layer_g in mlp.items():
254
+ summary_writer.scalar("%s/%s/kernel_sc_grad"%(mlp_k, layer_k), jnp.linalg.norm(layer_g['kernel']), step)
255
+
256
+ leaves, treedef = jax.tree_flatten(grad)
257
+ sc_leaves, _ = jax.tree_flatten(sc_grad)
258
+ grad = treedef.unflatten([g + jnp.expand_dims(sc_g, 0) for g, sc_g in zip(leaves, sc_leaves)])
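# Sketch (not part of this change): the flatten/unflatten above is equivalent to a
# tree_map over the two gradient pytrees; expand_dims presumably broadcasts the
# single-replica semantic gradient against the per-device leading axis of `grad`.
#   grad = jax.tree_util.tree_map(
#       lambda g, sc_g: g + jnp.expand_dims(sc_g, 0), grad, sc_grad)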
259
+
260
+
261
+
262
+ state = update_pstep(state, grad, lr)
263
+
264
+ if jax.host_id() == 0:
265
+ stats_trace.append(stats)
266
+ if step % FLAGS.gc_every == 0:
267
+ gc.collect()
268
+
269
+ # Log training summaries. This is put behind a host_id check because in
270
+ # multi-host evaluation, all hosts need to run inference even though we
271
+ # only use host 0 to record results.
272
+ if jax.host_id() == 0:
273
+ if step % FLAGS.print_every == 0:
274
+ summary_writer.scalar("loss/train", stats.loss[0], step)
275
+ summary_writer.scalar("sc_loss", sc_loss, step)
276
+ summary_writer.scalar("psnr/train", stats.psnr[0], step)
277
+ summary_writer.scalar("train_coarse/loss", stats.loss_c[0], step)
278
+ summary_writer.scalar("train_coarse/psnr", stats.psnr_c[0], step)
279
+ summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
280
+ avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
281
+ avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
282
+ stats_trace = []
283
+ summary_writer.scalar("train_avg/loss", avg_loss, step)
284
+ summary_writer.scalar("train_avg/psnr", avg_psnr, step)
285
+ summary_writer.scalar("learning_rate", lr, step)
286
+ steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
287
+ reset_timer = True
288
+ rays_per_sec = FLAGS.batch_size * steps_per_sec
289
+ summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
290
+ summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
291
+ precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
292
+ print(("{:" + "{:d}".format(precision) + "d}").format(step) +
293
+ f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
294
+ f"avg_loss={avg_loss:0.4f}, " +
295
+ f"weight_l2={stats.weight_l2[0]:0.2e}, " +
296
+ # f"sc_loss={sc_loss:0.4f}, " +
297
+ f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
298
+ if step % FLAGS.save_every == 0:
299
+ state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
300
+ checkpoints.save_checkpoint(
301
+ FLAGS.train_dir, state_to_save, int(step), keep=100)
302
+
303
+ # Test-set evaluation.
304
+ if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
305
+ # We reuse the same random number generator from the optimization step
306
+ # here on purpose so that the visualization matches what happened in
307
+ # training.
308
+ t_eval_start = time.time()
309
+ eval_variables = jax.device_get(jax.tree_map(lambda x: x[0],
310
+ state)).optimizer.target
311
+ test_case = next(test_dataset)
312
+ pred_color, pred_disp, pred_acc = utils.render_image(
313
+ functools.partial(render_pfn, eval_variables),
314
+ test_case["rays"],
315
+ keys[0],
316
+ FLAGS.dataset == "llff",
317
+ chunk=FLAGS.chunk)
318
+
319
+ # Log eval summaries on host 0.
320
+ if jax.host_id() == 0:
321
+ psnr = utils.compute_psnr(
322
+ ((pred_color - test_case["pixels"]) ** 2).mean())
323
+ ssim = ssim_fn(pred_color, test_case["pixels"])
324
+ eval_time = time.time() - t_eval_start
325
+ num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
326
+ rays_per_sec = num_rays / eval_time
327
+ summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
328
+ print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
329
+ summary_writer.scalar("psnr/test", psnr, step)
330
+ summary_writer.scalar("test_psnr", psnr, step)
331
+ summary_writer.scalar("ssim/ssim", ssim, step)
332
+ summary_writer.scalar("test_ssim", ssim, step)
333
+ if sc_image is not None:
334
+ summary_writer.image("random_ray_image", sc_image, step)
335
+ summary_writer.image("test_pred_color", pred_color, step)
336
+ summary_writer.image("test_pred_disp", pred_disp, step)
337
+ summary_writer.image("test_pred_acc", pred_acc, step)
338
+ summary_writer.image("test_target", test_case["pixels"], step)
339
+
340
+ if FLAGS.max_steps % FLAGS.save_every != 0:
341
+ state = jax.device_get(jax.tree_map(lambda x: x[0], state))
342
+ checkpoints.save_checkpoint(
343
+ FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
344
+
345
+
346
+ if __name__ == "__main__":
347
+ app.run(main)
train.sh ADDED
@@ -0,0 +1,34 @@
1
+ #!/bin/bash
+ # Copyright 2021 The Google Research Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
16
+ CONFIG=$1
17
+ DATA_ROOT=$2
18
+ ROOT_DIR=/tmp/jaxnerf/"$CONFIG"
19
+ if [ $CONFIG == "llff" ]
20
+ then
21
+ SCENES="room fern leaves fortress orchids flower trex horns"
22
+ DATA_FOLDER="nerf_llff_data"
23
+ else
24
+ SCENES="lego chair drums ficus hotdog materials mic ship"
25
+ DATA_FOLDER="nerf_synthetic"
26
+ fi
27
+
28
+ # launch training jobs for all scenes.
29
+ for scene in $SCENES; do
30
+ python -m jaxnerf.train \
31
+ --data_dir="$DATA_ROOT"/"$DATA_FOLDER"/"$scene" \
32
+ --train_dir="$ROOT_DIR"/"$scene" \
33
+ --config=configs/"$CONFIG"
34
+ done
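# Usage sketch (not part of this change): the data root below is a placeholder, and
# the "blender" config name is an assumption carried over from upstream jaxnerf.
#   bash train.sh llff /path/to/data      # trains all eight LLFF scenes in turn
#   bash train.sh blender /path/to/data   # trains all eight synthetic scenes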