liuyuan-pal committed on
Commit
d0f39be
1 Parent(s): df62e57

add models

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ckpt/* filter=lfs diff=lfs merge=lfs -text
blender_script.py DELETED
@@ -1,282 +0,0 @@
1
- """Blender script to render images of 3D models.
2
-
3
- This script is used to render images of 3D models. It takes in a list of paths
4
- to .glb files and renders images of each model. The images are from rotating the
5
- object around the origin. The images are saved to the output directory.
6
-
7
- Example usage:
8
- blender -b -P blender_script.py -- \
9
- --object_path my_object.glb \
10
- --output_dir ./views \
11
- --engine CYCLES \
12
- --scale 0.8 \
13
- --num_images 12 \
14
- --camera_dist 1.2
15
-
16
- Here, input_model_paths.json is a json file containing a list of paths to .glb.
17
- """
18
-
19
- import argparse
20
- import json
21
- import math
22
- import os
23
- import random
24
- import sys
25
- import time
26
- import urllib.request
27
- from pathlib import Path
28
-
29
- from mathutils import Vector, Matrix
30
- import numpy as np
31
-
32
- import bpy
33
- from mathutils import Vector
34
- import pickle
35
-
36
- def read_pickle(pkl_path):
37
- with open(pkl_path, 'rb') as f:
38
- return pickle.load(f)
39
-
40
- def save_pickle(data, pkl_path):
41
- # os.system('mkdir -p {}'.format(os.path.dirname(pkl_path)))
42
- with open(pkl_path, 'wb') as f:
43
- pickle.dump(data, f)
44
-
45
- parser = argparse.ArgumentParser()
46
- parser.add_argument("--object_path", type=str, required=True)
47
- parser.add_argument("--output_dir", type=str, required=True)
48
- parser.add_argument("--engine", type=str, default="CYCLES", choices=["CYCLES", "BLENDER_EEVEE"])
49
- parser.add_argument("--camera_type", type=str, default='even')
50
- parser.add_argument("--num_images", type=int, default=16)
51
- parser.add_argument("--elevation", type=float, default=30)
52
- parser.add_argument("--elevation_start", type=float, default=-10)
53
- parser.add_argument("--elevation_end", type=float, default=40)
54
- parser.add_argument("--device", type=str, default='CUDA')
55
-
56
- argv = sys.argv[sys.argv.index("--") + 1 :]
57
- args = parser.parse_args(argv)
58
-
59
- print('===================', args.engine, '===================')
60
-
61
- context = bpy.context
62
- scene = context.scene
63
- render = scene.render
64
-
65
- cam = scene.objects["Camera"]
66
- cam.location = (0, 1.2, 0)
67
- cam.data.lens = 35
68
- cam.data.sensor_width = 32
69
-
70
- cam_constraint = cam.constraints.new(type="TRACK_TO")
71
- cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
72
- cam_constraint.up_axis = "UP_Y"
73
-
74
- render.engine = args.engine
75
- render.image_settings.file_format = "PNG"
76
- render.image_settings.color_mode = "RGBA"
77
- render.resolution_x = 256
78
- render.resolution_y = 256
79
- render.resolution_percentage = 100
80
-
81
- scene.cycles.device = "GPU"
82
- scene.cycles.samples = 128
83
- scene.cycles.diffuse_bounces = 1
84
- scene.cycles.glossy_bounces = 1
85
- scene.cycles.transparent_max_bounces = 3
86
- scene.cycles.transmission_bounces = 3
87
- scene.cycles.filter_width = 0.01
88
- scene.cycles.use_denoising = True
89
- scene.render.film_transparent = True
90
-
91
- bpy.context.preferences.addons["cycles"].preferences.get_devices()
92
- # Set the device_type
93
- bpy.context.preferences.addons["cycles"].preferences.compute_device_type = args.device # or "OPENCL"
94
- bpy.context.scene.cycles.tile_size = 8192
95
-
96
-
97
- def az_el_to_points(azimuths, elevations):
98
- x = np.cos(azimuths)*np.cos(elevations)
99
- y = np.sin(azimuths)*np.cos(elevations)
100
- z = np.sin(elevations)
101
- return np.stack([x,y,z],-1) #
102
-
103
- def set_camera_location(cam_pt):
104
- # from https://blender.stackexchange.com/questions/18530/
105
- x, y, z = cam_pt # sample_spherical(radius_min=1.5, radius_max=2.2, maxz=2.2, minz=-2.2)
106
- camera = bpy.data.objects["Camera"]
107
- camera.location = x, y, z
108
-
109
- return camera
110
-
111
- def get_calibration_matrix_K_from_blender(camera):
112
- f_in_mm = camera.data.lens
113
- scene = bpy.context.scene
114
- resolution_x_in_px = scene.render.resolution_x
115
- resolution_y_in_px = scene.render.resolution_y
116
- scale = scene.render.resolution_percentage / 100
117
- sensor_width_in_mm = camera.data.sensor_width
118
- sensor_height_in_mm = camera.data.sensor_height
119
- pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
120
-
121
- if camera.data.sensor_fit == 'VERTICAL':
122
- # the sensor height is fixed (sensor fit is horizontal),
123
- # the sensor width is effectively changed with the pixel aspect ratio
124
- s_u = resolution_x_in_px * scale / sensor_width_in_mm / pixel_aspect_ratio
125
- s_v = resolution_y_in_px * scale / sensor_height_in_mm
126
- else: # 'HORIZONTAL' and 'AUTO'
127
- # the sensor width is fixed (sensor fit is horizontal),
128
- # the sensor height is effectively changed with the pixel aspect ratio
129
- s_u = resolution_x_in_px * scale / sensor_width_in_mm
130
- s_v = resolution_y_in_px * scale * pixel_aspect_ratio / sensor_height_in_mm
131
-
132
- # Parameters of intrinsic calibration matrix K
133
- alpha_u = f_in_mm * s_u
134
- alpha_v = f_in_mm * s_u
135
- u_0 = resolution_x_in_px * scale / 2
136
- v_0 = resolution_y_in_px * scale / 2
137
- skew = 0 # only use rectangular pixels
138
-
139
- K = np.asarray(((alpha_u, skew, u_0),
140
- (0, alpha_v, v_0),
141
- (0, 0, 1)),np.float32)
142
- return K
143
-
144
-
145
- def reset_scene() -> None:
146
- """Resets the scene to a clean state."""
147
- # delete everything that isn't part of a camera or a light
148
- for obj in bpy.data.objects:
149
- if obj.type not in {"CAMERA", "LIGHT"}:
150
- bpy.data.objects.remove(obj, do_unlink=True)
151
- # delete all the materials
152
- for material in bpy.data.materials:
153
- bpy.data.materials.remove(material, do_unlink=True)
154
- # delete all the textures
155
- for texture in bpy.data.textures:
156
- bpy.data.textures.remove(texture, do_unlink=True)
157
- # delete all the images
158
- for image in bpy.data.images:
159
- bpy.data.images.remove(image, do_unlink=True)
160
-
161
-
162
- # load the glb model
163
- def load_object(object_path: str) -> None:
164
- """Loads a glb model into the scene."""
165
- if object_path.endswith(".glb"):
166
- bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=True)
167
- elif object_path.endswith(".fbx"):
168
- bpy.ops.import_scene.fbx(filepath=object_path)
169
- else:
170
- raise ValueError(f"Unsupported file type: {object_path}")
171
-
172
-
173
- def scene_bbox(single_obj=None, ignore_matrix=False):
174
- bbox_min = (math.inf,) * 3
175
- bbox_max = (-math.inf,) * 3
176
- found = False
177
- for obj in scene_meshes() if single_obj is None else [single_obj]:
178
- found = True
179
- for coord in obj.bound_box:
180
- coord = Vector(coord)
181
- if not ignore_matrix:
182
- coord = obj.matrix_world @ coord
183
- bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
184
- bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
185
- if not found:
186
- raise RuntimeError("no objects in scene to compute bounding box for")
187
- return Vector(bbox_min), Vector(bbox_max)
188
-
189
-
190
- def scene_root_objects():
191
- for obj in bpy.context.scene.objects.values():
192
- if not obj.parent:
193
- yield obj
194
-
195
-
196
- def scene_meshes():
197
- for obj in bpy.context.scene.objects.values():
198
- if isinstance(obj.data, (bpy.types.Mesh)):
199
- yield obj
200
-
201
- # function from https://github.com/panmari/stanford-shapenet-renderer/blob/master/render_blender.py
202
- def get_3x4_RT_matrix_from_blender(cam):
203
- bpy.context.view_layer.update()
204
- location, rotation = cam.matrix_world.decompose()[0:2]
205
- R = np.asarray(rotation.to_matrix())
206
- t = np.asarray(location)
207
-
208
- cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
209
- R = R.T
210
- t = -R @ t
211
- R_world2cv = cam_rec @ R
212
- t_world2cv = cam_rec @ t
213
-
214
- RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1)
215
- return RT
216
-
217
- def normalize_scene():
218
- bbox_min, bbox_max = scene_bbox()
219
- scale = 1 / max(bbox_max - bbox_min)
220
- for obj in scene_root_objects():
221
- obj.scale = obj.scale * scale
222
- # Apply scale to matrix_world.
223
- bpy.context.view_layer.update()
224
- bbox_min, bbox_max = scene_bbox()
225
- offset = -(bbox_min + bbox_max) / 2
226
- for obj in scene_root_objects():
227
- obj.matrix_world.translation += offset
228
- bpy.ops.object.select_all(action="DESELECT")
229
-
230
- def save_images(object_file: str) -> None:
231
- object_uid = os.path.basename(object_file).split(".")[0]
232
- os.makedirs(args.output_dir, exist_ok=True)
233
-
234
- reset_scene()
235
- # load the object
236
- load_object(object_file)
237
- # object_uid = os.path.basename(object_file).split(".")[0]
238
- normalize_scene()
239
-
240
- # create an empty object to track
241
- empty = bpy.data.objects.new("Empty", None)
242
- scene.collection.objects.link(empty)
243
- cam_constraint.target = empty
244
-
245
- world_tree = bpy.context.scene.world.node_tree
246
- back_node = world_tree.nodes['Background']
247
- env_light = 0.5
248
- back_node.inputs['Color'].default_value = Vector([env_light, env_light, env_light, 1.0])
249
- back_node.inputs['Strength'].default_value = 1.0
250
-
251
- distances = np.asarray([1.5 for _ in range(args.num_images)])
252
- if args.camera_type=='fixed':
253
- azimuths = (np.arange(args.num_images)/args.num_images*np.pi*2).astype(np.float32)
254
- elevations = np.deg2rad(np.asarray([args.elevation] * args.num_images).astype(np.float32))
255
- elif args.camera_type=='random':
256
- azimuths = (np.arange(args.num_images) / args.num_images * np.pi * 2).astype(np.float32)
257
- elevations = np.random.uniform(args.elevation_start, args.elevation_end, args.num_images)
258
- elevations = np.deg2rad(elevations)
259
- else:
260
- raise NotImplementedError
261
-
262
- cam_pts = az_el_to_points(azimuths, elevations) * distances[:,None]
263
- cam_poses = []
264
- (Path(args.output_dir) / object_uid).mkdir(exist_ok=True, parents=True)
265
- for i in range(args.num_images):
266
- # set camera
267
- camera = set_camera_location(cam_pts[i])
268
- RT = get_3x4_RT_matrix_from_blender(camera)
269
- cam_poses.append(RT)
270
-
271
- render_path = os.path.join(args.output_dir, object_uid, f"{i:03d}.png")
272
- if os.path.exists(render_path): continue
273
- scene.render.filepath = os.path.abspath(render_path)
274
- bpy.ops.render.render(write_still=True)
275
-
276
- if args.camera_type=='random':
277
- K = get_calibration_matrix_K_from_blender(camera)
278
- cam_poses = np.stack(cam_poses, 0)
279
- save_pickle([K, azimuths, elevations, distances, cam_poses], os.path.join(args.output_dir, object_uid, "meta.pkl"))
280
-
281
- if __name__ == "__main__":
282
- save_images(args.object_path)
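Note: when --camera_type random is used, the script above writes the camera metadata [K, azimuths, elevations, distances, cam_poses] to meta.pkl next to the rendered views. A minimal sketch of reading that file back; the output path is a placeholder:

import pickle

def read_meta(pkl_path):
    # meta.pkl stores [K, azimuths, elevations, distances, cam_poses]; see the save_pickle call above
    with open(pkl_path, 'rb') as f:
        K, azimuths, elevations, distances, cam_poses = pickle.load(f)
    # K: 3x3 intrinsics; cam_poses: [num_images, 3, 4] world-to-camera matrices (OpenCV convention)
    return K, azimuths, elevations, distances, cam_poses

K, azimuths, elevations, distances, cam_poses = read_meta('./views/my_object/meta.pkl')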
ckpt/ViT-L-14.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836
+ size 932768134
ckpt/syncdreamer-pretrain.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5ebb31334d9e4002b2590dd805e25238beaf95fa082f6e39a132344624448dcb
+ size 5570034171
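Both checkpoints added here are Git LFS pointer files, so a plain clone only contains the three-line stubs above; `git lfs pull` (or huggingface_hub) fetches the real weights. A minimal download sketch, where the repo_id is an assumed placeholder for this model repository:

from huggingface_hub import hf_hub_download

# repo_id below is an assumption -- substitute the actual repository id
ckpt_path = hf_hub_download(repo_id="liuyuan-pal/SyncDreamer", filename="ckpt/syncdreamer-pretrain.ckpt")
clip_path = hf_hub_download(repo_id="liuyuan-pal/SyncDreamer", filename="ckpt/ViT-L-14.pt")
print(ckpt_path, clip_path)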
foreground_segment.py DELETED
@@ -1,50 +0,0 @@
- import cv2
- import argparse
- import numpy as np
-
- import torch
- from PIL import Image
-
-
- class BackgroundRemoval:
-     def __init__(self, device='cuda'):
-         from carvekit.api.high import HiInterface
-         self.interface = HiInterface(
-             object_type="object", # Can be "object" or "hairs-like".
-             batch_size_seg=5,
-             batch_size_matting=1,
-             device=device,
-             seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
-             matting_mask_size=2048,
-             trimap_prob_threshold=231,
-             trimap_dilation=30,
-             trimap_erosion_iters=5,
-             fp16=True,
-         )
-
-     @torch.no_grad()
-     def __call__(self, image):
-         # image: [H, W, 3] array in [0, 255].
-         image = Image.fromarray(image)
-         image = self.interface([image])[0]
-         image = np.array(image)
-         return image
-
- def process(image_path, mask_path):
-     mask_predictor = BackgroundRemoval()
-     image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
-     if image.shape[-1] == 4:
-         image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
-     else:
-         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-     rgba = mask_predictor(image) # [H, W, 4]
-     cv2.imwrite(mask_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--input', required=True, type=str)
-     parser.add_argument('--output', required=True, type=str)
-     opt = parser.parse_args()
-
-     process(opt.input, opt.output)
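A minimal sketch of how this removed helper is used, either via the CLI shown in __main__ above or directly from Python (requires carvekit and a CUDA device); the file names are placeholders:

# equivalent to: python foreground_segment.py --input input.png --output masked.png
from foreground_segment import process

process('input.png', 'masked.png')  # writes an RGBA image with the background removed to masked.png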
raymarching/__init__.py DELETED
@@ -1 +0,0 @@
- from .raymarching import *
raymarching/backend.py DELETED
@@ -1,40 +0,0 @@
- import os
- from torch.utils.cpp_extension import load
-
- _src_path = os.path.dirname(os.path.abspath(__file__))
-
- nvcc_flags = [
-     '-O3', '-std=c++14',
-     '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
- ]
-
- if os.name == "posix":
-     c_flags = ['-O3', '-std=c++14']
- elif os.name == "nt":
-     c_flags = ['/O2', '/std:c++17']
-
-     # find cl.exe
-     def find_cl_path():
-         import glob
-         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
-             paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
-             if paths:
-                 return paths[0]
-
-     # If cl.exe is not on path, try to find it.
-     if os.system("where cl.exe >nul 2>nul") != 0:
-         cl_path = find_cl_path()
-         if cl_path is None:
-             raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
-         os.environ["PATH"] += ";" + cl_path
-
- _backend = load(name='_raymarching',
-                 extra_cflags=c_flags,
-                 extra_cuda_cflags=nvcc_flags,
-                 sources=[os.path.join(_src_path, 'src', f) for f in [
-                     'raymarching.cu',
-                     'bindings.cpp',
-                 ]],
-                 )
-
- __all__ = ['_backend']
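For context, this module JIT-compiles the CUDA extension with torch.utils.cpp_extension.load at import time, so the first import can take a few minutes while nvcc runs; later imports reuse the cached build. A minimal sketch, assuming the raymarching package directory is on the Python path and a CUDA toolchain is available:

import raymarching  # falls back to backend.py and JIT-builds _raymarching if no installed extension is found

print(raymarching.near_far_from_aabb)  # CUDA op exposed through raymarching/raymarching.py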
raymarching/raymarching.py DELETED
@@ -1,373 +0,0 @@
1
- import numpy as np
2
- import time
3
-
4
- import torch
5
- import torch.nn as nn
6
- from torch.autograd import Function
7
- from torch.cuda.amp import custom_bwd, custom_fwd
8
-
9
- try:
10
- import _raymarching as _backend
11
- except ImportError:
12
- from .backend import _backend
13
-
14
-
15
- # ----------------------------------------
16
- # utils
17
- # ----------------------------------------
18
-
19
- class _near_far_from_aabb(Function):
20
- @staticmethod
21
- @custom_fwd(cast_inputs=torch.float32)
22
- def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
23
- ''' near_far_from_aabb, CUDA implementation
24
- Calculate rays' intersection time (near and far) with aabb
25
- Args:
26
- rays_o: float, [N, 3]
27
- rays_d: float, [N, 3]
28
- aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
29
- min_near: float, scalar
30
- Returns:
31
- nears: float, [N]
32
- fars: float, [N]
33
- '''
34
- if not rays_o.is_cuda: rays_o = rays_o.cuda()
35
- if not rays_d.is_cuda: rays_d = rays_d.cuda()
36
-
37
- rays_o = rays_o.contiguous().view(-1, 3)
38
- rays_d = rays_d.contiguous().view(-1, 3)
39
-
40
- N = rays_o.shape[0] # num rays
41
-
42
- nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
43
- fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
44
-
45
- _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
46
-
47
- return nears, fars
48
-
49
- near_far_from_aabb = _near_far_from_aabb.apply
50
-
51
-
52
- class _sph_from_ray(Function):
53
- @staticmethod
54
- @custom_fwd(cast_inputs=torch.float32)
55
- def forward(ctx, rays_o, rays_d, radius):
56
- ''' sph_from_ray, CUDA implementation
57
- get spherical coordinate on the background sphere from rays.
58
- Assume rays_o are inside the Sphere(radius).
59
- Args:
60
- rays_o: [N, 3]
61
- rays_d: [N, 3]
62
- radius: scalar, float
63
- Return:
64
- coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
65
- '''
66
- if not rays_o.is_cuda: rays_o = rays_o.cuda()
67
- if not rays_d.is_cuda: rays_d = rays_d.cuda()
68
-
69
- rays_o = rays_o.contiguous().view(-1, 3)
70
- rays_d = rays_d.contiguous().view(-1, 3)
71
-
72
- N = rays_o.shape[0] # num rays
73
-
74
- coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
75
-
76
- _backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
77
-
78
- return coords
79
-
80
- sph_from_ray = _sph_from_ray.apply
81
-
82
-
83
- class _morton3D(Function):
84
- @staticmethod
85
- def forward(ctx, coords):
86
- ''' morton3D, CUDA implementation
87
- Args:
88
- coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
89
- TODO: check if the coord range is valid! (current 128 is safe)
90
- Returns:
91
- indices: [N], int32, in [0, 128^3)
92
-
93
- '''
94
- if not coords.is_cuda: coords = coords.cuda()
95
-
96
- N = coords.shape[0]
97
-
98
- indices = torch.empty(N, dtype=torch.int32, device=coords.device)
99
-
100
- _backend.morton3D(coords.int(), N, indices)
101
-
102
- return indices
103
-
104
- morton3D = _morton3D.apply
105
-
106
- class _morton3D_invert(Function):
107
- @staticmethod
108
- def forward(ctx, indices):
109
- ''' morton3D_invert, CUDA implementation
110
- Args:
111
- indices: [N], int32, in [0, 128^3)
112
- Returns:
113
- coords: [N, 3], int32, in [0, 128)
114
-
115
- '''
116
- if not indices.is_cuda: indices = indices.cuda()
117
-
118
- N = indices.shape[0]
119
-
120
- coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
121
-
122
- _backend.morton3D_invert(indices.int(), N, coords)
123
-
124
- return coords
125
-
126
- morton3D_invert = _morton3D_invert.apply
127
-
128
-
129
- class _packbits(Function):
130
- @staticmethod
131
- @custom_fwd(cast_inputs=torch.float32)
132
- def forward(ctx, grid, thresh, bitfield=None):
133
- ''' packbits, CUDA implementation
134
- Pack up the density grid into a bit field to accelerate ray marching.
135
- Args:
136
- grid: float, [C, H * H * H], assume H % 2 == 0
137
- thresh: float, threshold
138
- Returns:
139
- bitfield: uint8, [C, H * H * H / 8]
140
- '''
141
- if not grid.is_cuda: grid = grid.cuda()
142
- grid = grid.contiguous()
143
-
144
- C = grid.shape[0]
145
- H3 = grid.shape[1]
146
- N = C * H3 // 8
147
-
148
- if bitfield is None:
149
- bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
150
-
151
- _backend.packbits(grid, N, thresh, bitfield)
152
-
153
- return bitfield
154
-
155
- packbits = _packbits.apply
156
-
157
- # ----------------------------------------
158
- # train functions
159
- # ----------------------------------------
160
-
161
- class _march_rays_train(Function):
162
- @staticmethod
163
- @custom_fwd(cast_inputs=torch.float32)
164
- def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
165
- ''' march rays to generate points (forward only)
166
- Args:
167
- rays_o/d: float, [N, 3]
168
- bound: float, scalar
169
- density_bitfield: uint8: [CHHH // 8]
170
- C: int
171
- H: int
172
- nears/fars: float, [N]
173
- step_counter: int32, (2), used to count the actual number of generated points.
174
- mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
175
- perturb: bool
176
- align: int, pad output so its size is dividable by align, set to -1 to disable.
177
- force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
178
- dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
179
- max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
180
- Returns:
181
- xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
182
- dirs: float, [M, 3], all generated points' view dirs.
183
- deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
184
- rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]
185
- '''
186
-
187
- if not rays_o.is_cuda: rays_o = rays_o.cuda()
188
- if not rays_d.is_cuda: rays_d = rays_d.cuda()
189
- if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
190
-
191
- rays_o = rays_o.contiguous().view(-1, 3)
192
- rays_d = rays_d.contiguous().view(-1, 3)
193
- density_bitfield = density_bitfield.contiguous()
194
-
195
- N = rays_o.shape[0] # num rays
196
- M = N * max_steps # init max points number in total
197
-
198
- # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
199
- # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated.
200
- if not force_all_rays and mean_count > 0:
201
- if align > 0:
202
- mean_count += align - mean_count % align
203
- M = mean_count
204
-
205
- xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
206
- dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
207
- deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
208
- rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
209
-
210
- if step_counter is None:
211
- step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
212
-
213
- if perturb:
214
- noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
215
- else:
216
- noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
217
-
218
- _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
219
-
220
- #print(step_counter, M)
221
-
222
- # only used at the first (few) epochs.
223
- if force_all_rays or mean_count <= 0:
224
- m = step_counter[0].item() # D2H copy
225
- if align > 0:
226
- m += align - m % align
227
- xyzs = xyzs[:m]
228
- dirs = dirs[:m]
229
- deltas = deltas[:m]
230
-
231
- torch.cuda.empty_cache()
232
-
233
- return xyzs, dirs, deltas, rays
234
-
235
- march_rays_train = _march_rays_train.apply
236
-
237
-
238
- class _composite_rays_train(Function):
239
- @staticmethod
240
- @custom_fwd(cast_inputs=torch.float32)
241
- def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
242
- ''' composite rays' rgbs, according to the ray marching formula.
243
- Args:
244
- rgbs: float, [M, 3]
245
- sigmas: float, [M,]
246
- deltas: float, [M, 2]
247
- rays: int32, [N, 3]
248
- Returns:
249
- weights_sum: float, [N,], the alpha channel
250
- depth: float, [N, ], the Depth
251
- image: float, [N, 3], the RGB channel (after multiplying alpha!)
252
- '''
253
-
254
- sigmas = sigmas.contiguous()
255
- rgbs = rgbs.contiguous()
256
-
257
- M = sigmas.shape[0]
258
- N = rays.shape[0]
259
-
260
- weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
261
- depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
262
- image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
263
-
264
- _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
265
-
266
- ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
267
- ctx.dims = [M, N, T_thresh]
268
-
269
- return weights_sum, depth, image
270
-
271
- @staticmethod
272
- @custom_bwd
273
- def backward(ctx, grad_weights_sum, grad_depth, grad_image):
274
-
275
- # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
276
-
277
- grad_weights_sum = grad_weights_sum.contiguous()
278
- grad_image = grad_image.contiguous()
279
-
280
- sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
281
- M, N, T_thresh = ctx.dims
282
-
283
- grad_sigmas = torch.zeros_like(sigmas)
284
- grad_rgbs = torch.zeros_like(rgbs)
285
-
286
- _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)
287
-
288
- return grad_sigmas, grad_rgbs, None, None, None
289
-
290
-
291
- composite_rays_train = _composite_rays_train.apply
292
-
293
- # ----------------------------------------
294
- # infer functions
295
- # ----------------------------------------
296
-
297
- class _march_rays(Function):
298
- @staticmethod
299
- @custom_fwd(cast_inputs=torch.float32)
300
- def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
301
- ''' march rays to generate points (forward only, for inference)
302
- Args:
303
- n_alive: int, number of alive rays
304
- n_step: int, how many steps we march
305
- rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
306
- rays_t: float, [N], the alive rays' time, we only use the first n_alive.
307
- rays_o/d: float, [N, 3]
308
- bound: float, scalar
309
- density_bitfield: uint8: [CHHH // 8]
310
- C: int
311
- H: int
312
- nears/fars: float, [N]
313
- align: int, pad output so its size is dividable by align, set to -1 to disable.
314
- perturb: bool/int, int > 0 is used as the random seed.
315
- dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
316
- max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
317
- Returns:
318
- xyzs: float, [n_alive * n_step, 3], all generated points' coords
319
- dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
320
- deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
321
- '''
322
-
323
- if not rays_o.is_cuda: rays_o = rays_o.cuda()
324
- if not rays_d.is_cuda: rays_d = rays_d.cuda()
325
-
326
- rays_o = rays_o.contiguous().view(-1, 3)
327
- rays_d = rays_d.contiguous().view(-1, 3)
328
-
329
- M = n_alive * n_step
330
-
331
- if align > 0:
332
- M += align - (M % align)
333
-
334
- xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
335
- dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
336
- deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
337
-
338
- if perturb:
339
- # torch.manual_seed(perturb) # test_gui uses spp index as seed
340
- noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
341
- else:
342
- noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
343
-
344
- _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
345
-
346
- return xyzs, dirs, deltas
347
-
348
- march_rays = _march_rays.apply
349
-
350
-
351
- class _composite_rays(Function):
352
- @staticmethod
353
- @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
354
- def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
355
- ''' composite rays' rgbs, according to the ray marching formula. (for inference)
356
- Args:
357
- n_alive: int, number of alive rays
358
- n_step: int, how many steps we march
359
- rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
360
- rays_t: float, [N], the alive rays' time
361
- sigmas: float, [n_alive * n_step,]
362
- rgbs: float, [n_alive * n_step, 3]
363
- deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
364
- In-place Outputs:
365
- weights_sum: float, [N,], the alpha channel
366
- depth: float, [N,], the depth value
367
- image: float, [N, 3], the RGB channel (after multiplying alpha!)
368
- '''
369
- _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
370
- return tuple()
371
-
372
-
373
- composite_rays = _composite_rays.apply
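A minimal sketch of calling one of these removed wrappers, following the shapes documented in the docstrings above (requires a GPU and the compiled _raymarching extension); the tensor values are placeholders:

import torch
import raymarching

N = 1024
rays_o = torch.zeros(N, 3, device='cuda')                                          # ray origins, [N, 3]
rays_d = torch.nn.functional.normalize(torch.randn(N, 3, device='cuda'), dim=-1)   # unit view directions, [N, 3]
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device='cuda')              # (xmin, ymin, zmin, xmax, ymax, zmax)

nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, 0.05)           # per-ray entry/exit t, each [N]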
raymarching/setup.py DELETED
@@ -1,62 +0,0 @@
- import os
- from setuptools import setup
- from torch.utils.cpp_extension import BuildExtension, CUDAExtension
-
- _src_path = os.path.dirname(os.path.abspath(__file__))
-
- nvcc_flags = [
-     '-O3', '-std=c++14',
-     '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
- ]
-
- if os.name == "posix":
-     c_flags = ['-O3', '-std=c++14']
- elif os.name == "nt":
-     c_flags = ['/O2', '/std:c++17']
-
-     # find cl.exe
-     def find_cl_path():
-         import glob
-         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
-             paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
-             if paths:
-                 return paths[0]
-
-     # If cl.exe is not on path, try to find it.
-     if os.system("where cl.exe >nul 2>nul") != 0:
-         cl_path = find_cl_path()
-         if cl_path is None:
-             raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
-         os.environ["PATH"] += ";" + cl_path
-
- '''
- Usage:
-
- python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
-
- python setup.py install # build extensions and install (copy) to PATH.
- pip install . # ditto but better (e.g., dependency & metadata handling)
-
- python setup.py develop # build extensions and install (symbolic) to PATH.
- pip install -e . # ditto but better (e.g., dependency & metadata handling)
-
- '''
- setup(
-     name='raymarching', # package name, import this to use python API
-     ext_modules=[
-         CUDAExtension(
-             name='_raymarching', # extension name, import this to use CUDA API
-             sources=[os.path.join(_src_path, 'src', f) for f in [
-                 'raymarching.cu',
-                 'bindings.cpp',
-             ]],
-             extra_compile_args={
-                 'cxx': c_flags,
-                 'nvcc': nvcc_flags,
-             }
-         ),
-     ],
-     cmdclass={
-         'build_ext': BuildExtension,
-     }
- )
raymarching/src/bindings.cpp DELETED
@@ -1,19 +0,0 @@
- #include <torch/extension.h>
-
- #include "raymarching.h"
-
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-     // utils
-     m.def("packbits", &packbits, "packbits (CUDA)");
-     m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
-     m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
-     m.def("morton3D", &morton3D, "morton3D (CUDA)");
-     m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
-     // train
-     m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
-     m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
-     m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
-     // infer
-     m.def("march_rays", &march_rays, "march rays (CUDA)");
-     m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
- }
raymarching/src/raymarching.cu DELETED
@@ -1,914 +0,0 @@
1
- #include <cuda.h>
2
- #include <cuda_fp16.h>
3
- #include <cuda_runtime.h>
4
-
5
- #include <ATen/cuda/CUDAContext.h>
6
- #include <torch/torch.h>
7
-
8
- #include <cstdio>
9
- #include <stdint.h>
10
- #include <stdexcept>
11
- #include <limits>
12
-
13
- #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
14
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
15
- #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
16
- #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
17
-
18
-
19
- inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
20
- inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
21
- inline constexpr __device__ float PI() { return 3.141592653589793f; }
22
- inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
23
-
24
-
25
- template <typename T>
26
- inline __host__ __device__ T div_round_up(T val, T divisor) {
27
- return (val + divisor - 1) / divisor;
28
- }
29
-
30
- inline __host__ __device__ float signf(const float x) {
31
- return copysignf(1.0, x);
32
- }
33
-
34
- inline __host__ __device__ float clamp(const float x, const float min, const float max) {
35
- return fminf(max, fmaxf(min, x));
36
- }
37
-
38
- inline __host__ __device__ void swapf(float& a, float& b) {
39
- float c = a; a = b; b = c;
40
- }
41
-
42
- inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
43
- const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
44
- int exponent;
45
- frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
46
- return fminf(max_cascade - 1, fmaxf(0, exponent));
47
- }
48
-
49
- inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
50
- const float mx = dt * H * 0.5;
51
- int exponent;
52
- frexpf(mx, &exponent);
53
- return fminf(max_cascade - 1, fmaxf(0, exponent));
54
- }
55
-
56
- inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
57
- {
58
- v = (v * 0x00010001u) & 0xFF0000FFu;
59
- v = (v * 0x00000101u) & 0x0F00F00Fu;
60
- v = (v * 0x00000011u) & 0xC30C30C3u;
61
- v = (v * 0x00000005u) & 0x49249249u;
62
- return v;
63
- }
64
-
65
- inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
66
- {
67
- uint32_t xx = __expand_bits(x);
68
- uint32_t yy = __expand_bits(y);
69
- uint32_t zz = __expand_bits(z);
70
- return xx | (yy << 1) | (zz << 2);
71
- }
72
-
73
- inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
74
- {
75
- x = x & 0x49249249;
76
- x = (x | (x >> 2)) & 0xc30c30c3;
77
- x = (x | (x >> 4)) & 0x0f00f00f;
78
- x = (x | (x >> 8)) & 0xff0000ff;
79
- x = (x | (x >> 16)) & 0x0000ffff;
80
- return x;
81
- }
82
-
83
-
84
- ////////////////////////////////////////////////////
85
- ///////////// utils /////////////
86
- ////////////////////////////////////////////////////
87
-
88
- // rays_o/d: [N, 3]
89
- // nears/fars: [N]
90
- // scalar_t should always be float in use.
91
- template <typename scalar_t>
92
- __global__ void kernel_near_far_from_aabb(
93
- const scalar_t * __restrict__ rays_o,
94
- const scalar_t * __restrict__ rays_d,
95
- const scalar_t * __restrict__ aabb,
96
- const uint32_t N,
97
- const float min_near,
98
- scalar_t * nears, scalar_t * fars
99
- ) {
100
- // parallel per ray
101
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
102
- if (n >= N) return;
103
-
104
- // locate
105
- rays_o += n * 3;
106
- rays_d += n * 3;
107
-
108
- const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
109
- const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
110
- const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
111
-
112
- // get near far (assume cube scene)
113
- float near = (aabb[0] - ox) * rdx;
114
- float far = (aabb[3] - ox) * rdx;
115
- if (near > far) swapf(near, far);
116
-
117
- float near_y = (aabb[1] - oy) * rdy;
118
- float far_y = (aabb[4] - oy) * rdy;
119
- if (near_y > far_y) swapf(near_y, far_y);
120
-
121
- if (near > far_y || near_y > far) {
122
- nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
123
- return;
124
- }
125
-
126
- if (near_y > near) near = near_y;
127
- if (far_y < far) far = far_y;
128
-
129
- float near_z = (aabb[2] - oz) * rdz;
130
- float far_z = (aabb[5] - oz) * rdz;
131
- if (near_z > far_z) swapf(near_z, far_z);
132
-
133
- if (near > far_z || near_z > far) {
134
- nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
135
- return;
136
- }
137
-
138
- if (near_z > near) near = near_z;
139
- if (far_z < far) far = far_z;
140
-
141
- if (near < min_near) near = min_near;
142
-
143
- nears[n] = near;
144
- fars[n] = far;
145
- }
146
-
147
-
148
- void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
149
-
150
- static constexpr uint32_t N_THREAD = 128;
151
-
152
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
153
- rays_o.scalar_type(), "near_far_from_aabb", ([&] {
154
- kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
155
- }));
156
- }
157
-
158
-
159
- // rays_o/d: [N, 3]
160
- // radius: float
161
- // coords: [N, 2]
162
- template <typename scalar_t>
163
- __global__ void kernel_sph_from_ray(
164
- const scalar_t * __restrict__ rays_o,
165
- const scalar_t * __restrict__ rays_d,
166
- const float radius,
167
- const uint32_t N,
168
- scalar_t * coords
169
- ) {
170
- // parallel per ray
171
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
172
- if (n >= N) return;
173
-
174
- // locate
175
- rays_o += n * 3;
176
- rays_d += n * 3;
177
- coords += n * 2;
178
-
179
- const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
180
- const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
181
- const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
182
-
183
- // solve t from || o + td || = radius
184
- const float A = dx * dx + dy * dy + dz * dz;
185
- const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
186
- const float C = ox * ox + oy * oy + oz * oz - radius * radius;
187
-
188
- const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
189
-
190
- // solve theta, phi (assume y is the up axis)
191
- const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
192
- const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
193
- const float phi = atan2(z, x); // [-PI, PI)
194
-
195
- // normalize to [-1, 1]
196
- coords[0] = 2 * theta * RPI() - 1;
197
- coords[1] = phi * RPI();
198
- }
199
-
200
-
201
- void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
202
-
203
- static constexpr uint32_t N_THREAD = 128;
204
-
205
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
206
- rays_o.scalar_type(), "sph_from_ray", ([&] {
207
- kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
208
- }));
209
- }
210
-
211
-
212
- // coords: int32, [N, 3]
213
- // indices: int32, [N]
214
- __global__ void kernel_morton3D(
215
- const int * __restrict__ coords,
216
- const uint32_t N,
217
- int * indices
218
- ) {
219
- // parallel
220
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
221
- if (n >= N) return;
222
-
223
- // locate
224
- coords += n * 3;
225
- indices[n] = __morton3D(coords[0], coords[1], coords[2]);
226
- }
227
-
228
-
229
- void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
230
- static constexpr uint32_t N_THREAD = 128;
231
- kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
232
- }
233
-
234
-
235
- // indices: int32, [N]
236
- // coords: int32, [N, 3]
237
- __global__ void kernel_morton3D_invert(
238
- const int * __restrict__ indices,
239
- const uint32_t N,
240
- int * coords
241
- ) {
242
- // parallel
243
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
244
- if (n >= N) return;
245
-
246
- // locate
247
- coords += n * 3;
248
-
249
- const int ind = indices[n];
250
-
251
- coords[0] = __morton3D_invert(ind >> 0);
252
- coords[1] = __morton3D_invert(ind >> 1);
253
- coords[2] = __morton3D_invert(ind >> 2);
254
- }
255
-
256
-
257
- void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
258
- static constexpr uint32_t N_THREAD = 128;
259
- kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
260
- }
261
-
262
-
263
- // grid: float, [C, H, H, H]
264
- // N: int, C * H * H * H / 8
265
- // density_thresh: float
266
- // bitfield: uint8, [N]
267
- template <typename scalar_t>
268
- __global__ void kernel_packbits(
269
- const scalar_t * __restrict__ grid,
270
- const uint32_t N,
271
- const float density_thresh,
272
- uint8_t * bitfield
273
- ) {
274
- // parallel per byte
275
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
276
- if (n >= N) return;
277
-
278
- // locate
279
- grid += n * 8;
280
-
281
- uint8_t bits = 0;
282
-
283
- #pragma unroll
284
- for (uint8_t i = 0; i < 8; i++) {
285
- bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
286
- }
287
-
288
- bitfield[n] = bits;
289
- }
290
-
291
-
292
- void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
293
-
294
- static constexpr uint32_t N_THREAD = 128;
295
-
296
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
297
- grid.scalar_type(), "packbits", ([&] {
298
- kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
299
- }));
300
- }
301
-
302
- ////////////////////////////////////////////////////
303
- ///////////// training /////////////
304
- ////////////////////////////////////////////////////
305
-
306
- // rays_o/d: [N, 3]
307
- // grid: [CHHH / 8]
308
- // xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
309
- // dirs: [M, 3]
310
- // rays: [N, 3], idx, offset, num_steps
311
- template <typename scalar_t>
312
- __global__ void kernel_march_rays_train(
313
- const scalar_t * __restrict__ rays_o,
314
- const scalar_t * __restrict__ rays_d,
315
- const uint8_t * __restrict__ grid,
316
- const float bound,
317
- const float dt_gamma, const uint32_t max_steps,
318
- const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
319
- const scalar_t* __restrict__ nears,
320
- const scalar_t* __restrict__ fars,
321
- scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
322
- int * rays,
323
- int * counter,
324
- const scalar_t* __restrict__ noises
325
- ) {
326
- // parallel per ray
327
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
328
- if (n >= N) return;
329
-
330
- // locate
331
- rays_o += n * 3;
332
- rays_d += n * 3;
333
-
334
- // ray marching
335
- const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
336
- const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
337
- const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
338
- const float rH = 1 / (float)H;
339
- const float H3 = H * H * H;
340
-
341
- const float near = nears[n];
342
- const float far = fars[n];
343
- const float noise = noises[n];
344
-
345
- const float dt_min = 2 * SQRT3() / max_steps;
346
- const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
347
-
348
- float t0 = near;
349
-
350
- // perturb
351
- t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
352
-
353
- // first pass: estimation of num_steps
354
- float t = t0;
355
- uint32_t num_steps = 0;
356
-
357
- //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
358
-
359
- while (t < far && num_steps < max_steps) {
360
- // current point
361
- const float x = clamp(ox + t * dx, -bound, bound);
362
- const float y = clamp(oy + t * dy, -bound, bound);
363
- const float z = clamp(oz + t * dz, -bound, bound);
364
-
365
- const float dt = clamp(t * dt_gamma, dt_min, dt_max);
366
-
367
- // get mip level
368
- const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
369
-
370
- const float mip_bound = fminf(scalbnf(1.0f, level), bound);
371
- const float mip_rbound = 1 / mip_bound;
372
-
373
- // convert to nearest grid position
374
- const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
375
- const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
376
- const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
377
-
378
- const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
379
- const bool occ = grid[index / 8] & (1 << (index % 8));
380
-
381
- // if occpuied, advance a small step, and write to output
382
- //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);
383
-
384
- if (occ) {
385
- num_steps++;
386
- t += dt;
387
- // else, skip a large step (basically skip a voxel grid)
388
- } else {
389
- // calc distance to next voxel
390
- const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
391
- const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
392
- const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
393
-
394
- const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
395
- // step until next voxel
396
- do {
397
- t += clamp(t * dt_gamma, dt_min, dt_max);
398
- } while (t < tt);
399
- }
400
- }
401
-
402
- //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);
403
-
404
- // second pass: really locate and write points & dirs
405
- uint32_t point_index = atomicAdd(counter, num_steps);
406
- uint32_t ray_index = atomicAdd(counter + 1, 1);
407
-
408
- //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);
409
-
410
- // write rays
411
- rays[ray_index * 3] = n;
412
- rays[ray_index * 3 + 1] = point_index;
413
- rays[ray_index * 3 + 2] = num_steps;
414
-
415
- if (num_steps == 0) return;
416
- if (point_index + num_steps > M) return;
417
-
418
- xyzs += point_index * 3;
419
- dirs += point_index * 3;
420
- deltas += point_index * 2;
421
-
422
- t = t0;
423
- uint32_t step = 0;
424
-
425
- float last_t = t;
426
-
427
- while (t < far && step < num_steps) {
428
- // current point
429
- const float x = clamp(ox + t * dx, -bound, bound);
430
- const float y = clamp(oy + t * dy, -bound, bound);
431
- const float z = clamp(oz + t * dz, -bound, bound);
432
-
433
- const float dt = clamp(t * dt_gamma, dt_min, dt_max);
434
-
435
- // get mip level
436
- const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
437
-
438
- const float mip_bound = fminf(scalbnf(1.0f, level), bound);
439
- const float mip_rbound = 1 / mip_bound;
440
-
441
- // convert to nearest grid position
442
- const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
443
- const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
444
- const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
445
-
446
- // query grid
447
- const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
448
- const bool occ = grid[index / 8] & (1 << (index % 8));
449
-
450
- // if occpuied, advance a small step, and write to output
451
- if (occ) {
452
- // write step
453
- xyzs[0] = x;
454
- xyzs[1] = y;
455
- xyzs[2] = z;
456
- dirs[0] = dx;
457
- dirs[1] = dy;
458
- dirs[2] = dz;
459
- t += dt;
460
- deltas[0] = dt;
461
- deltas[1] = t - last_t; // used to calc depth
462
- last_t = t;
463
- xyzs += 3;
464
- dirs += 3;
465
- deltas += 2;
466
- step++;
467
- // else, skip a large step (basically skip a voxel grid)
468
- } else {
469
- // calc distance to next voxel
470
- const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
471
- const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
472
- const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
473
- const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
474
- // step until next voxel
475
- do {
476
- t += clamp(t * dt_gamma, dt_min, dt_max);
477
- } while (t < tt);
478
- }
479
- }
480
- }
481
-
482
- void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
483
-
484
- static constexpr uint32_t N_THREAD = 128;
485
-
486
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
487
- rays_o.scalar_type(), "march_rays_train", ([&] {
488
- kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
489
- }));
490
- }
491
-
492
-
493
- // sigmas: [M]
494
- // rgbs: [M, 3]
495
- // deltas: [M, 2]
496
- // rays: [N, 3], idx, offset, num_steps
497
- // weights_sum: [N], final pixel alpha
498
- // depth: [N,]
499
- // image: [N, 3]
500
- template <typename scalar_t>
501
- __global__ void kernel_composite_rays_train_forward(
502
- const scalar_t * __restrict__ sigmas,
503
- const scalar_t * __restrict__ rgbs,
504
- const scalar_t * __restrict__ deltas,
505
- const int * __restrict__ rays,
506
- const uint32_t M, const uint32_t N, const float T_thresh,
507
- scalar_t * weights_sum,
508
- scalar_t * depth,
509
- scalar_t * image
510
- ) {
511
- // parallel per ray
512
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
513
- if (n >= N) return;
514
-
515
- // locate
516
- uint32_t index = rays[n * 3];
517
- uint32_t offset = rays[n * 3 + 1];
518
- uint32_t num_steps = rays[n * 3 + 2];
519
-
520
- // empty ray, or ray that exceed max step count.
521
- if (num_steps == 0 || offset + num_steps > M) {
522
- weights_sum[index] = 0;
523
- depth[index] = 0;
524
- image[index * 3] = 0;
525
- image[index * 3 + 1] = 0;
526
- image[index * 3 + 2] = 0;
527
- return;
528
- }
529
-
530
- sigmas += offset;
531
- rgbs += offset * 3;
532
- deltas += offset * 2;
533
-
534
- // accumulate
535
- uint32_t step = 0;
536
-
537
- scalar_t T = 1.0f;
538
- scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;
539
-
540
- while (step < num_steps) {
541
-
542
- const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
543
- const scalar_t weight = alpha * T;
544
-
545
- r += weight * rgbs[0];
546
- g += weight * rgbs[1];
547
- b += weight * rgbs[2];
548
-
549
- t += deltas[1]; // real delta
550
- d += weight * t;
551
-
552
- ws += weight;
553
-
554
- T *= 1.0f - alpha;
555
-
556
- // minimal remained transmittence
557
- if (T < T_thresh) break;
558
-
559
- //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
560
-
561
- // locate
562
- sigmas++;
563
- rgbs += 3;
564
- deltas += 2;
565
-
566
- step++;
567
- }
568
-
569
- //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
570
-
571
- // write
572
- weights_sum[index] = ws; // weights_sum
573
- depth[index] = d;
574
- image[index * 3] = r;
575
- image[index * 3 + 1] = g;
576
- image[index * 3 + 2] = b;
577
- }
578
-
579
-
580
- void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
581
-
582
- static constexpr uint32_t N_THREAD = 128;
583
-
584
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
585
- sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
586
- kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
587
- }));
588
- }
589
-
590
-
591
- // grad_weights_sum: [N,]
592
- // grad: [N, 3]
593
- // sigmas: [M]
594
- // rgbs: [M, 3]
595
- // deltas: [M, 2]
596
- // rays: [N, 3], idx, offset, num_steps
597
- // weights_sum: [N,], weights_sum here
598
- // image: [N, 3]
599
- // grad_sigmas: [M]
600
- // grad_rgbs: [M, 3]
601
- template <typename scalar_t>
602
- __global__ void kernel_composite_rays_train_backward(
603
- const scalar_t * __restrict__ grad_weights_sum,
604
- const scalar_t * __restrict__ grad_image,
605
- const scalar_t * __restrict__ sigmas,
606
- const scalar_t * __restrict__ rgbs,
607
- const scalar_t * __restrict__ deltas,
608
- const int * __restrict__ rays,
609
- const scalar_t * __restrict__ weights_sum,
610
- const scalar_t * __restrict__ image,
611
- const uint32_t M, const uint32_t N, const float T_thresh,
612
- scalar_t * grad_sigmas,
613
- scalar_t * grad_rgbs
614
- ) {
615
- // parallel per ray
616
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
617
- if (n >= N) return;
618
-
619
- // locate
620
- uint32_t index = rays[n * 3];
621
- uint32_t offset = rays[n * 3 + 1];
622
- uint32_t num_steps = rays[n * 3 + 2];
623
-
624
- if (num_steps == 0 || offset + num_steps > M) return;
625
-
626
- grad_weights_sum += index;
627
- grad_image += index * 3;
628
- weights_sum += index;
629
- image += index * 3;
630
- sigmas += offset;
631
- rgbs += offset * 3;
632
- deltas += offset * 2;
633
- grad_sigmas += offset;
634
- grad_rgbs += offset * 3;
635
-
636
- // accumulate
637
- uint32_t step = 0;
638
-
639
- scalar_t T = 1.0f;
640
- const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
641
- scalar_t r = 0, g = 0, b = 0, ws = 0;
642
-
643
- while (step < num_steps) {
644
-
645
- const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
646
- const scalar_t weight = alpha * T;
647
-
648
- r += weight * rgbs[0];
649
- g += weight * rgbs[1];
650
- b += weight * rgbs[2];
651
- ws += weight;
652
-
653
- T *= 1.0f - alpha;
654
-
655
- // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
656
- // write grad_rgbs
657
- grad_rgbs[0] = grad_image[0] * weight;
658
- grad_rgbs[1] = grad_image[1] * weight;
659
- grad_rgbs[2] = grad_image[2] * weight;
660
-
661
- // write grad_sigmas
662
- grad_sigmas[0] = deltas[0] * (
663
- grad_image[0] * (T * rgbs[0] - (r_final - r)) +
664
- grad_image[1] * (T * rgbs[1] - (g_final - g)) +
665
- grad_image[2] * (T * rgbs[2] - (b_final - b)) +
666
- grad_weights_sum[0] * (1 - ws_final)
667
- );
668
-
669
- //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
670
-             // minimal remaining transmittance
671
- if (T < T_thresh) break;
672
-
673
- // locate
674
- sigmas++;
675
- rgbs += 3;
676
- deltas += 2;
677
- grad_sigmas++;
678
- grad_rgbs += 3;
679
-
680
- step++;
681
- }
682
- }
683
-
684
-
685
- void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
686
-
687
- static constexpr uint32_t N_THREAD = 128;
688
-
689
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
690
- grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
691
- kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
692
- }));
693
- }
694
-
695
-
696
- ////////////////////////////////////////////////////
697
- /////////////            inference            /////////////
698
- ////////////////////////////////////////////////////
699
-
700
- template <typename scalar_t>
701
- __global__ void kernel_march_rays(
702
- const uint32_t n_alive,
703
- const uint32_t n_step,
704
- const int* __restrict__ rays_alive,
705
- const scalar_t* __restrict__ rays_t,
706
- const scalar_t* __restrict__ rays_o,
707
- const scalar_t* __restrict__ rays_d,
708
- const float bound,
709
- const float dt_gamma, const uint32_t max_steps,
710
- const uint32_t C, const uint32_t H,
711
- const uint8_t * __restrict__ grid,
712
- const scalar_t* __restrict__ nears,
713
- const scalar_t* __restrict__ fars,
714
- scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
715
- const scalar_t* __restrict__ noises
716
- ) {
717
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
718
- if (n >= n_alive) return;
719
-
720
- const int index = rays_alive[n]; // ray id
721
- const float noise = noises[n];
722
-
723
- // locate
724
- rays_o += index * 3;
725
- rays_d += index * 3;
726
- xyzs += n * n_step * 3;
727
- dirs += n * n_step * 3;
728
- deltas += n * n_step * 2;
729
-
730
- const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
731
- const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
732
- const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
733
- const float rH = 1 / (float)H;
734
- const float H3 = H * H * H;
735
-
736
- float t = rays_t[index]; // current ray's t
737
- const float near = nears[index], far = fars[index];
738
-
739
- const float dt_min = 2 * SQRT3() / max_steps;
740
- const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
741
-
742
- // march for n_step steps, record points
743
- uint32_t step = 0;
744
-
745
- // introduce some randomness
746
- t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
747
-
748
- float last_t = t;
749
-
750
- while (t < far && step < n_step) {
751
- // current point
752
- const float x = clamp(ox + t * dx, -bound, bound);
753
- const float y = clamp(oy + t * dy, -bound, bound);
754
- const float z = clamp(oz + t * dz, -bound, bound);
755
-
756
- const float dt = clamp(t * dt_gamma, dt_min, dt_max);
757
-
758
- // get mip level
759
- const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
760
-
761
- const float mip_bound = fminf(scalbnf(1, level), bound);
762
- const float mip_rbound = 1 / mip_bound;
763
-
764
- // convert to nearest grid position
765
- const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
766
- const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
767
- const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
768
-
769
- const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
770
- const bool occ = grid[index / 8] & (1 << (index % 8));
771
-
772
-         // if occupied, advance a small step and write to output
773
- if (occ) {
774
- // write step
775
- xyzs[0] = x;
776
- xyzs[1] = y;
777
- xyzs[2] = z;
778
- dirs[0] = dx;
779
- dirs[1] = dy;
780
- dirs[2] = dz;
781
- // calc dt
782
- t += dt;
783
- deltas[0] = dt;
784
- deltas[1] = t - last_t; // used to calc depth
785
- last_t = t;
786
- // step
787
- xyzs += 3;
788
- dirs += 3;
789
- deltas += 2;
790
- step++;
791
-
792
- // else, skip a large step (basically skip a voxel grid)
793
- } else {
794
- // calc distance to next voxel
795
- const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
796
- const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
797
- const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
798
- const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
799
- // step until next voxel
800
- do {
801
- t += clamp(t * dt_gamma, dt_min, dt_max);
802
- } while (t < tt);
803
- }
804
- }
805
- }
806
-
807
-
808
- void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
809
- static constexpr uint32_t N_THREAD = 128;
810
-
811
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
812
- rays_o.scalar_type(), "march_rays", ([&] {
813
- kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
814
- }));
815
- }
816
-
817
-
818
- template <typename scalar_t>
819
- __global__ void kernel_composite_rays(
820
- const uint32_t n_alive,
821
- const uint32_t n_step,
822
- const float T_thresh,
823
- int* rays_alive,
824
- scalar_t* rays_t,
825
- const scalar_t* __restrict__ sigmas,
826
- const scalar_t* __restrict__ rgbs,
827
- const scalar_t* __restrict__ deltas,
828
- scalar_t* weights_sum, scalar_t* depth, scalar_t* image
829
- ) {
830
- const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
831
- if (n >= n_alive) return;
832
-
833
- const int index = rays_alive[n]; // ray id
834
-
835
- // locate
836
- sigmas += n * n_step;
837
- rgbs += n * n_step * 3;
838
- deltas += n * n_step * 2;
839
-
840
- rays_t += index;
841
- weights_sum += index;
842
- depth += index;
843
- image += index * 3;
844
-
845
- scalar_t t = rays_t[0]; // current ray's t
846
-
847
- scalar_t weight_sum = weights_sum[0];
848
- scalar_t d = depth[0];
849
- scalar_t r = image[0];
850
- scalar_t g = image[1];
851
- scalar_t b = image[2];
852
-
853
- // accumulate
854
- uint32_t step = 0;
855
- while (step < n_step) {
856
-
857
- // ray is terminated if delta == 0
858
- if (deltas[0] == 0) break;
859
-
860
- const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
861
-
862
- /*
863
- T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
864
- w_i = alpha_i * T_i
865
- -->
866
- T_i = 1 - \sum_{j=0}^{i-1} w_j
867
- */
868
- const scalar_t T = 1 - weight_sum;
869
- const scalar_t weight = alpha * T;
870
- weight_sum += weight;
871
-
872
- t += deltas[1]; // real delta
873
- d += weight * t;
874
- r += weight * rgbs[0];
875
- g += weight * rgbs[1];
876
- b += weight * rgbs[2];
877
-
878
- //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
879
-
880
- // ray is terminated if T is too small
881
- // use a larger bound to further accelerate inference
882
- if (T < T_thresh) break;
883
-
884
- // locate
885
- sigmas++;
886
- rgbs += 3;
887
- deltas += 2;
888
- step++;
889
- }
890
-
891
- //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
892
-
893
- // rays_alive = -1 means ray is terminated early.
894
- if (step < n_step) {
895
- rays_alive[n] = -1;
896
- } else {
897
- rays_t[0] = t;
898
- }
899
-
900
- weights_sum[0] = weight_sum; // this is the thing I needed!
901
- depth[0] = d;
902
- image[0] = r;
903
- image[1] = g;
904
- image[2] = b;
905
- }
906
-
907
-
908
- void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
909
- static constexpr uint32_t N_THREAD = 128;
910
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
911
- image.scalar_type(), "composite_rays", ([&] {
912
- kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
913
- }));
914
- }
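For reference, the per-ray accumulation performed by kernel_composite_rays_train_forward can be written in plain PyTorch. This is a minimal single-ray sketch (not part of the repository, and far slower than the CUDA kernel), assuming packed inputs for one ray: sigmas of shape [S], rgbs of shape [S, 3], and deltas of shape [S, 2], where deltas[:, 0] is the sampling step and deltas[:, 1] is the distance advanced since the previous kept sample.

```python
import torch

def composite_ray_reference(sigmas, rgbs, deltas, T_thresh=1e-4):
    # Mirrors kernel_composite_rays_train_forward for a single ray.
    T, t = 1.0, 0.0            # transmittance, accumulated ray distance
    ws, d = 0.0, 0.0           # weights_sum, expected depth
    rgb = torch.zeros(3)
    for s in range(sigmas.shape[0]):
        alpha = 1.0 - torch.exp(-sigmas[s] * deltas[s, 0])
        weight = alpha * T
        rgb = rgb + weight * rgbs[s]
        t = t + deltas[s, 1]   # "real delta": distance since the last kept sample
        d = d + weight * t
        ws = ws + weight
        T = T * (1.0 - alpha)
        if T < T_thresh:       # minimal remaining transmittance
            break
    return ws, d, rgb
```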
 
 
raymarching/src/raymarching.h DELETED
@@ -1,18 +0,0 @@
1
- #pragma once
2
-
3
- #include <stdint.h>
4
- #include <torch/torch.h>
5
-
6
-
7
- void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
8
- void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
9
- void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
10
- void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
11
- void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
12
-
13
- void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
14
- void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
15
- void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
16
-
17
- void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
18
- void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
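Assuming the extension declared in this header has been built and its Python-side wrappers match the call sites in renderer/ngp_renderer.py (which allocate the output tensors internally), a minimal sanity check of the ray/AABB intersection looks like the sketch below; the ray count and AABB values are placeholders.

```python
import torch
import torch.nn.functional as F
import raymarching  # python wrapper around the CUDA kernels declared in this header

rays_o = torch.zeros(1024, 3, device='cuda')                        # rays starting at the origin
rays_d = F.normalize(torch.randn(1024, 3, device='cuda'), dim=-1)   # random unit directions
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device='cuda')  # (xmin, ymin, zmin, xmax, ymax, zmax)
nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, 0.2)  # min_near = 0.2
print(nears.shape, fars.shape)  # both (1024,)
```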
 
 
render_batch.py DELETED
@@ -1,20 +0,0 @@
1
- import subprocess
2
-
3
- from ldm.base_utils import save_pickle
4
-
5
- uids=['6f99fb8c2f1a4252b986ed5a765e1db9','8bba4678f9a349d6a29314ccf337975c','063b1b7d877a402ead76cedb06341681',
6
- '199b7a080622422fac8140b61cc7544a','83784b6f7a064212ab50aaaaeb1d7fa7','5501434a052c49d6a8a8d9a1120fee10',
7
- 'cca62f95635f4b20aea4f35014632a55','d2e8612a21044111a7176da2bd78de05','f9e172dd733644a2b47a824e202c89d5']
8
-
9
- # for uid in uids:
10
- # cmds = ['blender','--background','--python','blender_script.py','--',
11
- # '--object_path',f'objaverse_examples/{uid}/{uid}.glb',
12
- # '--output_dir','./training_examples/input','--camera_type','random']
13
- # subprocess.run(cmds)
14
- #
15
- # cmds = ['blender','--background','--python','blender_script.py','--',
16
- # '--object_path',f'objaverse_examples/{uid}/{uid}.glb',
17
- # '--output_dir','./training_examples/target','--camera_type','fixed']
18
- # subprocess.run(cmds)
19
-
20
- save_pickle(uids, f'training_examples/uid_set.pkl')
 
 
renderer/agg_net.py DELETED
@@ -1,83 +0,0 @@
1
- import torch.nn.functional as F
2
- import torch.nn as nn
3
- import torch
4
-
5
- def weights_init(m):
6
- if isinstance(m, nn.Linear):
7
- nn.init.kaiming_normal_(m.weight.data)
8
- if m.bias is not None:
9
- nn.init.zeros_(m.bias.data)
10
-
11
- class NeRF(nn.Module):
12
- def __init__(self, vol_n=8+8, feat_ch=8+16+32+3, hid_n=64):
13
- super(NeRF, self).__init__()
14
- self.hid_n = hid_n
15
- self.agg = Agg(feat_ch)
16
- self.lr0 = nn.Sequential(nn.Linear(vol_n+16, hid_n), nn.ReLU())
17
- self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus())
18
- self.color = nn.Sequential(
19
- nn.Linear(16+vol_n+feat_ch+hid_n+4, hid_n), # agg_feats+vox_feat+img_feat+lr0_feats+dir
20
- nn.ReLU(),
21
- nn.Linear(hid_n, 1)
22
- )
23
- self.lr0.apply(weights_init)
24
- self.sigma.apply(weights_init)
25
- self.color.apply(weights_init)
26
-
27
- def forward(self, vox_feat, img_feat_rgb_dir, source_img_mask):
28
- # assert torch.sum(torch.sum(source_img_mask,1)<2)==0
29
- b, d, n, _ = img_feat_rgb_dir.shape # b,d,n,f=8+16+32+3+4
30
- agg_feat = self.agg(img_feat_rgb_dir, source_img_mask) # b,d,f=16
31
- x = self.lr0(torch.cat((vox_feat, agg_feat), dim=-1)) # b,d,f=64
32
- sigma = self.sigma(x) # b,d,1
33
-
34
- x = torch.cat((x, vox_feat, agg_feat), dim=-1) # b,d,f=16+16+64
35
- x = x.view(b, d, 1, x.shape[-1]).repeat(1, 1, n, 1)
36
- x = torch.cat((x, img_feat_rgb_dir), dim=-1)
37
- logits = self.color(x)
38
- source_img_mask_ = source_img_mask.reshape(b, 1, n, 1).repeat(1, logits.shape[1], 1, 1) == 0
39
- logits[source_img_mask_] = -1e7
40
- color_weight = F.softmax(logits, dim=-2)
41
- color = torch.sum((img_feat_rgb_dir[..., -7:-4] * color_weight), dim=-2)
42
- return color, sigma
43
-
44
- class Agg(nn.Module):
45
- def __init__(self, feat_ch):
46
- super(Agg, self).__init__()
47
- self.feat_ch = feat_ch
48
- self.view_fc = nn.Sequential(nn.Linear(4, feat_ch), nn.ReLU())
49
- self.view_fc.apply(weights_init)
50
- self.global_fc = nn.Sequential(nn.Linear(feat_ch*3, 32), nn.ReLU())
51
-
52
- self.agg_w_fc = nn.Linear(32, 1)
53
- self.fc = nn.Linear(32, 16)
54
- self.global_fc.apply(weights_init)
55
- self.agg_w_fc.apply(weights_init)
56
- self.fc.apply(weights_init)
57
-
58
- def masked_mean_var(self, img_feat_rgb, source_img_mask):
59
- # img_feat_rgb: b,d,n,f source_img_mask: b,n
60
- b, n = source_img_mask.shape
61
- source_img_mask = source_img_mask.view(b, 1, n, 1)
62
- mean = torch.sum(source_img_mask * img_feat_rgb, dim=-2)/ (torch.sum(source_img_mask, dim=-2) + 1e-5)
63
- var = torch.sum((img_feat_rgb - mean.unsqueeze(-2)) ** 2 * source_img_mask, dim=-2) / (torch.sum(source_img_mask, dim=-2) + 1e-5)
64
- return mean, var
65
-
66
- def forward(self, img_feat_rgb_dir, source_img_mask):
67
- # img_feat_rgb_dir b,d,n,f
68
- b, d, n, _ = img_feat_rgb_dir.shape
69
- view_feat = self.view_fc(img_feat_rgb_dir[..., -4:]) # b,d,n,f-4
70
- img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_feat
71
-
72
- mean_feat, var_feat = self.masked_mean_var(img_feat_rgb, source_img_mask)
73
- var_feat = var_feat.view(b, -1, 1, self.feat_ch).repeat(1, 1, n, 1)
74
- avg_feat = mean_feat.view(b, -1, 1, self.feat_ch).repeat(1, 1, n, 1)
75
-
76
- feat = torch.cat([img_feat_rgb, var_feat, avg_feat], dim=-1) # b,d,n,f
77
- global_feat = self.global_fc(feat) # b,d,n,f
78
- logits = self.agg_w_fc(global_feat) # b,d,n,1
79
- source_img_mask_ = source_img_mask.reshape(b, 1, n, 1).repeat(1, logits.shape[1], 1, 1) == 0
80
- logits[source_img_mask_] = -1e7
81
- agg_w = F.softmax(logits, dim=-2)
82
- im_feat = (global_feat * agg_w).sum(dim=-2)
83
- return self.fc(im_feat)
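A minimal shape-check sketch for the aggregation NeRF head above, assuming the default constructor arguments (vol_n=16 and feat_ch=8+16+32+3=59, so each per-view feature vector carries 59 channels plus RGB and a 4-d view-direction encoding already concatenated into the last 7 channels). The random tensors are placeholders, not real features.

```python
import torch
from renderer.agg_net import NeRF

net = NeRF()                                     # vol_n=16, feat_ch=59, hid_n=64
b, d, n = 2, 64, 8                               # rays, samples per ray, source views
vox_feat = torch.randn(b, d, 16)                 # volume features per sample
img_feat_rgb_dir = torch.randn(b, d, n, 59 + 4)  # per-view features (incl. rgb) + dir encoding
mask = torch.ones(b, n)                          # all source views valid
color, sigma = net(vox_feat, img_feat_rgb_dir, mask)
print(color.shape, sigma.shape)                  # (2, 64, 3) (2, 64, 1)
```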
 
 
renderer/cost_reg_net.py DELETED
@@ -1,95 +0,0 @@
1
- import torch.nn as nn
2
-
3
- class ConvBnReLU3D(nn.Module):
4
- def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=nn.BatchNorm3d):
5
- super(ConvBnReLU3D, self).__init__()
6
- self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
7
- self.bn = norm_act(out_channels)
8
- self.relu = nn.ReLU(inplace=True)
9
-
10
- def forward(self, x):
11
- return self.relu(self.bn(self.conv(x)))
12
-
13
- class CostRegNet(nn.Module):
14
- def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
15
- super(CostRegNet, self).__init__()
16
- self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
17
-
18
- self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
19
- self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
20
-
21
- self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
22
- self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
23
-
24
- self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act)
25
- self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)
26
-
27
- self.conv7 = nn.Sequential(
28
- nn.ConvTranspose3d(64, 32, 3, padding=1, output_padding=1, stride=2, bias=False),
29
- norm_act(32)
30
- )
31
-
32
- self.conv9 = nn.Sequential(
33
- nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, stride=2, bias=False),
34
- norm_act(16)
35
- )
36
-
37
- self.conv11 = nn.Sequential(
38
- nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,stride=2, bias=False),
39
- norm_act(8)
40
- )
41
- self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
42
- self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))
43
-
44
- def forward(self, x):
45
- conv0 = self.conv0(x)
46
- conv2 = self.conv2(self.conv1(conv0))
47
- conv4 = self.conv4(self.conv3(conv2))
48
- x = self.conv6(self.conv5(conv4))
49
- x = conv4 + self.conv7(x)
50
- del conv4
51
- x = conv2 + self.conv9(x)
52
- del conv2
53
- x = conv0 + self.conv11(x)
54
- del conv0
55
- feat = self.feat_conv(x)
56
- depth = self.depth_conv(x)
57
- return feat, depth
58
-
59
-
60
- class MinCostRegNet(nn.Module):
61
- def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
62
- super(MinCostRegNet, self).__init__()
63
- self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
64
-
65
- self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
66
- self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
67
-
68
- self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
69
- self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
70
-
71
- self.conv9 = nn.Sequential(
72
- nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1,
73
- stride=2, bias=False),
74
- norm_act(16))
75
-
76
- self.conv11 = nn.Sequential(
77
- nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,
78
- stride=2, bias=False),
79
- norm_act(8))
80
-
81
- self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
82
- self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))
83
-
84
- def forward(self, x):
85
- conv0 = self.conv0(x)
86
- conv2 = self.conv2(self.conv1(conv0))
87
- conv4 = self.conv4(self.conv3(conv2))
88
- x = conv4
89
- x = conv2 + self.conv9(x)
90
- del conv2
91
- x = conv0 + self.conv11(x)
92
- del conv0
93
- feat = self.feat_conv(x)
94
- depth = self.depth_conv(x)
95
- return feat, depth
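A quick shape check for the 3D cost-volume regularizers above; the 32-channel input and the 32^3 volume are arbitrary placeholders (the spatial size only needs to be divisible by 8 for CostRegNet and by 4 for MinCostRegNet, because of the stride-2 downsampling stages).

```python
import torch
from renderer.cost_reg_net import CostRegNet, MinCostRegNet

x = torch.randn(1, 32, 32, 32, 32)      # (B, C, D, H, W)
feat, depth = CostRegNet(in_channels=32)(x)
print(feat.shape, depth.shape)          # (1, 8, 32, 32, 32) (1, 1, 32, 32, 32)
feat, depth = MinCostRegNet(in_channels=32)(x)
print(feat.shape, depth.shape)          # same output shapes, lighter network
```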
 
 
renderer/dummy_dataset.py DELETED
@@ -1,40 +0,0 @@
1
- import pytorch_lightning as pl
2
- from torch.utils.data import Dataset
3
- import webdataset as wds
4
- from torch.utils.data.distributed import DistributedSampler
5
- class DummyDataset(pl.LightningDataModule):
6
- def __init__(self,seed):
7
- super().__init__()
8
-
9
- def setup(self, stage):
10
- if stage in ['fit']:
11
- self.train_dataset = DummyData(True)
12
- self.val_dataset = DummyData(False)
13
- else:
14
- raise NotImplementedError
15
-
16
- def train_dataloader(self):
17
- return wds.WebLoader(self.train_dataset, batch_size=1, num_workers=0, shuffle=False)
18
-
19
- def val_dataloader(self):
20
- return wds.WebLoader(self.val_dataset, batch_size=1, num_workers=0, shuffle=False)
21
-
22
- def test_dataloader(self):
23
- return wds.WebLoader(DummyData(False))
24
-
25
- class DummyData(Dataset):
26
- def __init__(self,is_train):
27
- self.is_train=is_train
28
-
29
- def __len__(self):
30
- if self.is_train:
31
- return 99999999
32
- else:
33
- return 1
34
-
35
- def __getitem__(self, index):
36
- return {}
37
-
38
-
39
-
40
-
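The data module above only exists to satisfy PyTorch Lightning's training loop: the batches it yields are empty dicts and the real inputs are produced inside the model. A minimal sketch of driving it directly (assuming pytorch_lightning and webdataset are installed; the exact Trainer wiring lives elsewhere in the repo):

```python
from renderer.dummy_dataset import DummyDataset

dm = DummyDataset(seed=0)
dm.setup('fit')
loader = dm.train_dataloader()
batch = next(iter(loader))   # {}: placeholder batch, the model generates its own data
```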
 
 
renderer/feature_net.py DELETED
@@ -1,42 +0,0 @@
1
- import torch.nn as nn
2
- import torch.nn.functional as F
3
-
4
- class ConvBnReLU(nn.Module):
5
- def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=nn.BatchNorm2d):
6
- super(ConvBnReLU, self).__init__()
7
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
8
- self.bn = norm_act(out_channels)
9
- self.relu = nn.ReLU(inplace=True)
10
-
11
- def forward(self, x):
12
- return self.relu(self.bn(self.conv(x)))
13
-
14
- class FeatureNet(nn.Module):
15
- def __init__(self, norm_act=nn.BatchNorm2d):
16
- super(FeatureNet, self).__init__()
17
- self.conv0 = nn.Sequential(ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act), ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))
18
- self.conv1 = nn.Sequential(ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act), ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))
19
- self.conv2 = nn.Sequential(ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))
20
-
21
- self.toplayer = nn.Conv2d(32, 32, 1)
22
- self.lat1 = nn.Conv2d(16, 32, 1)
23
- self.lat0 = nn.Conv2d(8, 32, 1)
24
-
25
- self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
26
- self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
27
-
28
- def _upsample_add(self, x, y):
29
- return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y
30
-
31
- def forward(self, x):
32
- conv0 = self.conv0(x)
33
- conv1 = self.conv1(conv0)
34
- conv2 = self.conv2(conv1)
35
- feat2 = self.toplayer(conv2)
36
- feat1 = self._upsample_add(feat2, self.lat1(conv1))
37
- feat0 = self._upsample_add(feat1, self.lat0(conv0))
38
- feat1 = self.smooth1(feat1)
39
- feat0 = self.smooth0(feat0)
40
- return feat2, feat1, feat0
41
-
42
-
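A shape check for the small FPN-style feature extractor above; the input height and width only need to be divisible by 4, and the 256x256 size below is just an example.

```python
import torch
from renderer.feature_net import FeatureNet

net = FeatureNet()
x = torch.randn(1, 3, 256, 256)          # (B, 3, H, W)
feat2, feat1, feat0 = net(x)
print(feat2.shape, feat1.shape, feat0.shape)
# (1, 32, 64, 64) (1, 16, 128, 128) (1, 8, 256, 256)
```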
 
 
renderer/neus_networks.py DELETED
@@ -1,503 +0,0 @@
1
- import math
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- import tinycudann as tcnn
8
-
9
- # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
10
- class Embedder:
11
- def __init__(self, **kwargs):
12
- self.kwargs = kwargs
13
- self.create_embedding_fn()
14
-
15
- def create_embedding_fn(self):
16
- embed_fns = []
17
- d = self.kwargs['input_dims']
18
- out_dim = 0
19
- if self.kwargs['include_input']:
20
- embed_fns.append(lambda x: x)
21
- out_dim += d
22
-
23
- max_freq = self.kwargs['max_freq_log2']
24
- N_freqs = self.kwargs['num_freqs']
25
-
26
- if self.kwargs['log_sampling']:
27
- freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
28
- else:
29
- freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)
30
-
31
- for freq in freq_bands:
32
- for p_fn in self.kwargs['periodic_fns']:
33
- embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
34
- out_dim += d
35
-
36
- self.embed_fns = embed_fns
37
- self.out_dim = out_dim
38
-
39
- def embed(self, inputs):
40
- return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
41
-
42
-
43
- def get_embedder(multires, input_dims=3):
44
- embed_kwargs = {
45
- 'include_input': True,
46
- 'input_dims': input_dims,
47
- 'max_freq_log2': multires - 1,
48
- 'num_freqs': multires,
49
- 'log_sampling': True,
50
- 'periodic_fns': [torch.sin, torch.cos],
51
- }
52
-
53
- embedder_obj = Embedder(**embed_kwargs)
54
-
55
- def embed(x, eo=embedder_obj): return eo.embed(x)
56
-
57
- return embed, embedder_obj.out_dim
58
-
59
-
60
- class SDFNetwork(nn.Module):
61
- def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
62
- scale=1, geometric_init=True, weight_norm=True, inside_outside=False):
63
- super(SDFNetwork, self).__init__()
64
-
65
- dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
66
-
67
- self.embed_fn_fine = None
68
-
69
- if multires > 0:
70
- embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
71
- self.embed_fn_fine = embed_fn
72
- dims[0] = input_ch
73
-
74
- self.num_layers = len(dims)
75
- self.skip_in = skip_in
76
- self.scale = scale
77
-
78
- for l in range(0, self.num_layers - 1):
79
- if l + 1 in self.skip_in:
80
- out_dim = dims[l + 1] - dims[0]
81
- else:
82
- out_dim = dims[l + 1]
83
-
84
- lin = nn.Linear(dims[l], out_dim)
85
-
86
- if geometric_init:
87
- if l == self.num_layers - 2:
88
- if not inside_outside:
89
- torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
90
- torch.nn.init.constant_(lin.bias, -bias)
91
- else:
92
- torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
93
- torch.nn.init.constant_(lin.bias, bias)
94
- elif multires > 0 and l == 0:
95
- torch.nn.init.constant_(lin.bias, 0.0)
96
- torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
97
- torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
98
- elif multires > 0 and l in self.skip_in:
99
- torch.nn.init.constant_(lin.bias, 0.0)
100
- torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
101
- torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
102
- else:
103
- torch.nn.init.constant_(lin.bias, 0.0)
104
- torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
105
-
106
- if weight_norm:
107
- lin = nn.utils.weight_norm(lin)
108
-
109
- setattr(self, "lin" + str(l), lin)
110
-
111
- self.activation = nn.Softplus(beta=100)
112
-
113
- def forward(self, inputs):
114
- inputs = inputs * self.scale
115
- if self.embed_fn_fine is not None:
116
- inputs = self.embed_fn_fine(inputs)
117
-
118
- x = inputs
119
- for l in range(0, self.num_layers - 1):
120
- lin = getattr(self, "lin" + str(l))
121
-
122
- if l in self.skip_in:
123
- x = torch.cat([x, inputs], -1) / np.sqrt(2)
124
-
125
- x = lin(x)
126
-
127
- if l < self.num_layers - 2:
128
- x = self.activation(x)
129
-
130
- return x
131
-
132
- def sdf(self, x):
133
- return self.forward(x)[..., :1]
134
-
135
- def sdf_hidden_appearance(self, x):
136
- return self.forward(x)
137
-
138
- def gradient(self, x):
139
- x.requires_grad_(True)
140
- with torch.enable_grad():
141
- y = self.sdf(x)
142
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
143
- gradients = torch.autograd.grad(
144
- outputs=y,
145
- inputs=x,
146
- grad_outputs=d_output,
147
- create_graph=True,
148
- retain_graph=True,
149
- only_inputs=True)[0]
150
- return gradients
151
-
152
- def sdf_normal(self, x):
153
- x.requires_grad_(True)
154
- with torch.enable_grad():
155
- y = self.sdf(x)
156
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
157
- gradients = torch.autograd.grad(
158
- outputs=y,
159
- inputs=x,
160
- grad_outputs=d_output,
161
- create_graph=True,
162
- retain_graph=True,
163
- only_inputs=True)[0]
164
- return y[..., :1].detach(), gradients.detach()
165
-
166
- class SDFNetworkWithFeature(nn.Module):
167
- def __init__(self, cube, dp_in, df_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
168
- scale=1, geometric_init=True, weight_norm=True, inside_outside=False, cube_length=0.5):
169
- super().__init__()
170
-
171
- self.register_buffer("cube", cube)
172
- self.cube_length = cube_length
173
- dims = [dp_in+df_in] + [d_hidden for _ in range(n_layers)] + [d_out]
174
-
175
- self.embed_fn_fine = None
176
-
177
- if multires > 0:
178
- embed_fn, input_ch = get_embedder(multires, input_dims=dp_in)
179
- self.embed_fn_fine = embed_fn
180
- dims[0] = input_ch + df_in
181
-
182
- self.num_layers = len(dims)
183
- self.skip_in = skip_in
184
- self.scale = scale
185
-
186
- for l in range(0, self.num_layers - 1):
187
- if l + 1 in self.skip_in:
188
- out_dim = dims[l + 1] - dims[0]
189
- else:
190
- out_dim = dims[l + 1]
191
-
192
- lin = nn.Linear(dims[l], out_dim)
193
-
194
- if geometric_init:
195
- if l == self.num_layers - 2:
196
- if not inside_outside:
197
- torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
198
- torch.nn.init.constant_(lin.bias, -bias)
199
- else:
200
- torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
201
- torch.nn.init.constant_(lin.bias, bias)
202
- elif multires > 0 and l == 0:
203
- torch.nn.init.constant_(lin.bias, 0.0)
204
- torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
205
- torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
206
- elif multires > 0 and l in self.skip_in:
207
- torch.nn.init.constant_(lin.bias, 0.0)
208
- torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
209
- torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
210
- else:
211
- torch.nn.init.constant_(lin.bias, 0.0)
212
- torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
213
-
214
- if weight_norm:
215
- lin = nn.utils.weight_norm(lin)
216
-
217
- setattr(self, "lin" + str(l), lin)
218
-
219
- self.activation = nn.Softplus(beta=100)
220
-
221
- def forward(self, points):
222
- points = points * self.scale
223
-
224
-         # note: dividing by cube_length (0.5) scales points by 2, since the feature cube spans [-0.5, 0.5]
225
- with torch.no_grad():
226
- feats = F.grid_sample(self.cube, points.view(1,-1,1,1,3)/self.cube_length, mode='bilinear', align_corners=True, padding_mode='zeros').detach()
227
- feats = feats.view(self.cube.shape[1], -1).permute(1,0).view(*points.shape[:-1], -1)
228
- if self.embed_fn_fine is not None:
229
- points = self.embed_fn_fine(points)
230
-
231
- x = torch.cat([points, feats], -1)
232
- for l in range(0, self.num_layers - 1):
233
- lin = getattr(self, "lin" + str(l))
234
-
235
- if l in self.skip_in:
236
- x = torch.cat([x, points, feats], -1) / np.sqrt(2)
237
-
238
- x = lin(x)
239
-
240
- if l < self.num_layers - 2:
241
- x = self.activation(x)
242
-
243
- # concat feats
244
- x = torch.cat([x, feats], -1)
245
- return x
246
-
247
- def sdf(self, x):
248
- return self.forward(x)[..., :1]
249
-
250
- def sdf_hidden_appearance(self, x):
251
- return self.forward(x)
252
-
253
- def gradient(self, x):
254
- x.requires_grad_(True)
255
- with torch.enable_grad():
256
- y = self.sdf(x)
257
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
258
- gradients = torch.autograd.grad(
259
- outputs=y,
260
- inputs=x,
261
- grad_outputs=d_output,
262
- create_graph=True,
263
- retain_graph=True,
264
- only_inputs=True)[0]
265
- return gradients
266
-
267
- def sdf_normal(self, x):
268
- x.requires_grad_(True)
269
- with torch.enable_grad():
270
- y = self.sdf(x)
271
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
272
- gradients = torch.autograd.grad(
273
- outputs=y,
274
- inputs=x,
275
- grad_outputs=d_output,
276
- create_graph=True,
277
- retain_graph=True,
278
- only_inputs=True)[0]
279
- return y[..., :1].detach(), gradients.detach()
280
-
281
-
282
- class VanillaMLP(nn.Module):
283
- def __init__(self, dim_in, dim_out, n_neurons, n_hidden_layers):
284
- super().__init__()
285
- self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers
286
- self.sphere_init, self.weight_norm = True, True
287
- self.sphere_init_radius = 0.5
288
- self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
289
- for i in range(self.n_hidden_layers - 1):
290
- self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
291
- self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
292
- self.layers = nn.Sequential(*self.layers)
293
-
294
- @torch.cuda.amp.autocast(False)
295
- def forward(self, x):
296
- x = self.layers(x.float())
297
- return x
298
-
299
- def make_linear(self, dim_in, dim_out, is_first, is_last):
300
- layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality
301
- if self.sphere_init:
302
- if is_last:
303
- torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
304
- torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
305
- elif is_first:
306
- torch.nn.init.constant_(layer.bias, 0.0)
307
- torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
308
- torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
309
- else:
310
- torch.nn.init.constant_(layer.bias, 0.0)
311
- torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
312
- else:
313
- torch.nn.init.constant_(layer.bias, 0.0)
314
- torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
315
-
316
- if self.weight_norm:
317
- layer = nn.utils.weight_norm(layer)
318
- return layer
319
-
320
- def make_activation(self):
321
- if self.sphere_init:
322
- return nn.Softplus(beta=100)
323
- else:
324
- return nn.ReLU(inplace=True)
325
-
326
-
327
- class SDFHashGridNetwork(nn.Module):
328
- def __init__(self, bound=0.5, feats_dim=13):
329
- super().__init__()
330
- self.bound = bound
331
- # max_resolution = 32
332
- # base_resolution = 16
333
- # n_levels = 4
334
- # log2_hashmap_size = 16
335
- # n_features_per_level = 8
336
- max_resolution = 2048
337
- base_resolution = 16
338
- n_levels = 16
339
- log2_hashmap_size = 19
340
- n_features_per_level = 2
341
-
342
- # max_res = base_res * t^(k-1)
343
- per_level_scale = (max_resolution / base_resolution)** (1 / (n_levels - 1))
344
-
345
- self.encoder = tcnn.Encoding(
346
- n_input_dims=3,
347
- encoding_config={
348
- "otype": "HashGrid",
349
- "n_levels": n_levels,
350
- "n_features_per_level": n_features_per_level,
351
- "log2_hashmap_size": log2_hashmap_size,
352
- "base_resolution": base_resolution,
353
- "per_level_scale": per_level_scale,
354
- },
355
- )
356
- self.sdf_mlp = VanillaMLP(n_levels*n_features_per_level+3,feats_dim,64,1)
357
-
358
- def forward(self, x):
359
- shape = x.shape[:-1]
360
- x = x.reshape(-1, 3)
361
- x_ = (x + self.bound) / (2 * self.bound)
362
- feats = self.encoder(x_)
363
- feats = torch.cat([x, feats], 1)
364
-
365
- feats = self.sdf_mlp(feats)
366
- feats = feats.reshape(*shape,-1)
367
- return feats
368
-
369
- def sdf(self, x):
370
- return self(x)[...,:1]
371
-
372
- def gradient(self, x):
373
- x.requires_grad_(True)
374
- with torch.enable_grad():
375
- y = self.sdf(x)
376
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
377
- gradients = torch.autograd.grad(
378
- outputs=y,
379
- inputs=x,
380
- grad_outputs=d_output,
381
- create_graph=True,
382
- retain_graph=True,
383
- only_inputs=True)[0]
384
- return gradients
385
-
386
- def sdf_normal(self, x):
387
- x.requires_grad_(True)
388
- with torch.enable_grad():
389
- y = self.sdf(x)
390
- d_output = torch.ones_like(y, requires_grad=False, device=y.device)
391
- gradients = torch.autograd.grad(
392
- outputs=y,
393
- inputs=x,
394
- grad_outputs=d_output,
395
- create_graph=True,
396
- retain_graph=True,
397
- only_inputs=True)[0]
398
- return y[..., :1].detach(), gradients.detach()
399
-
400
- class RenderingFFNetwork(nn.Module):
401
- def __init__(self, in_feats_dim=12):
402
- super().__init__()
403
- self.dir_encoder = tcnn.Encoding(
404
- n_input_dims=3,
405
- encoding_config={
406
- "otype": "SphericalHarmonics",
407
- "degree": 4,
408
- },
409
- )
410
- self.color_mlp = tcnn.Network(
411
- n_input_dims = in_feats_dim + 3 + self.dir_encoder.n_output_dims,
412
- n_output_dims = 3,
413
- network_config={
414
- "otype": "FullyFusedMLP",
415
- "activation": "ReLU",
416
- "output_activation": "none",
417
- "n_neurons": 64,
418
- "n_hidden_layers": 2,
419
- },
420
- )
421
-
422
- def forward(self, points, normals, view_dirs, feature_vectors):
423
- normals = F.normalize(normals, dim=-1)
424
- view_dirs = F.normalize(view_dirs, dim=-1)
425
- reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs
426
-
427
- x = torch.cat([feature_vectors, normals, self.dir_encoder(reflective)], -1)
428
- colors = self.color_mlp(x).float()
429
- colors = F.sigmoid(colors)
430
- return colors
431
-
432
- # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
433
- class RenderingNetwork(nn.Module):
434
- def __init__(self, d_feature, d_in, d_out, d_hidden,
435
- n_layers, weight_norm=True, multires_view=0, squeeze_out=True, use_view_dir=True):
436
- super().__init__()
437
-
438
- self.squeeze_out = squeeze_out
439
- self.rgb_act=F.sigmoid
440
- self.use_view_dir=use_view_dir
441
-
442
- dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
443
-
444
- self.embedview_fn = None
445
- if multires_view > 0:
446
- embedview_fn, input_ch = get_embedder(multires_view)
447
- self.embedview_fn = embedview_fn
448
- dims[0] += (input_ch - 3)
449
-
450
- self.num_layers = len(dims)
451
-
452
- for l in range(0, self.num_layers - 1):
453
- out_dim = dims[l + 1]
454
- lin = nn.Linear(dims[l], out_dim)
455
-
456
- if weight_norm:
457
- lin = nn.utils.weight_norm(lin)
458
-
459
- setattr(self, "lin" + str(l), lin)
460
-
461
- self.relu = nn.ReLU()
462
-
463
- def forward(self, points, normals, view_dirs, feature_vectors):
464
- if self.use_view_dir:
465
- view_dirs = F.normalize(view_dirs, dim=-1)
466
- normals = F.normalize(normals, dim=-1)
467
- reflective = torch.sum(view_dirs*normals, -1, keepdim=True) * normals * 2 - view_dirs
468
- if self.embedview_fn is not None: reflective = self.embedview_fn(reflective)
469
- rendering_input = torch.cat([points, reflective, normals, feature_vectors], dim=-1)
470
- else:
471
- rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
472
-
473
- x = rendering_input
474
-
475
- for l in range(0, self.num_layers - 1):
476
- lin = getattr(self, "lin" + str(l))
477
-
478
- x = lin(x)
479
-
480
- if l < self.num_layers - 2:
481
- x = self.relu(x)
482
-
483
- if self.squeeze_out:
484
- x = self.rgb_act(x)
485
- return x
486
-
487
-
488
- class SingleVarianceNetwork(nn.Module):
489
- def __init__(self, init_val, activation='exp'):
490
- super(SingleVarianceNetwork, self).__init__()
491
- self.act = activation
492
- self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))
493
-
494
- def forward(self, x):
495
- device = x.device
496
- if self.act=='exp':
497
- return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * torch.exp(self.variance * 10.0)
498
- else:
499
- raise NotImplementedError
500
-
501
- def warp(self, x, inv_s):
502
- device = x.device
503
- return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * inv_s
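A small usage sketch for the SDF MLP defined above. Note that this module imports tinycudann at the top, so that package must be importable even though SDFNetwork itself does not use it; the constructor arguments below are illustrative choices, not values mandated by the repo.

```python
import torch
from renderer.neus_networks import SDFNetwork  # requires tinycudann to be importable

sdf_net = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8, multires=6)
pts = 0.3 * torch.randn(1024, 3)             # query points inside the unit cube
out = sdf_net(pts)                           # (1024, 257): SDF value + geometry feature
sdf = sdf_net.sdf(pts)                       # (1024, 1)
sdf_val, normals = sdf_net.sdf_normal(pts)   # detached SDF and autograd gradient, (1024, 3)
print(out.shape, sdf.shape, normals.shape)
```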
 
 
renderer/ngp_renderer.py DELETED
@@ -1,721 +0,0 @@
1
- import math
2
- import trimesh
3
- import numpy as np
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from packaging import version as pver
9
-
10
- import tinycudann as tcnn
11
- from torch.autograd import Function
12
-
13
- from torch.cuda.amp import custom_bwd, custom_fwd
14
-
15
- import raymarching
16
-
17
- def custom_meshgrid(*args):
18
- # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
19
- if pver.parse(torch.__version__) < pver.parse('1.10'):
20
- return torch.meshgrid(*args)
21
- else:
22
- return torch.meshgrid(*args, indexing='ij')
23
-
24
- def sample_pdf(bins, weights, n_samples, det=False):
25
- # This implementation is from NeRF
26
- # bins: [B, T], old_z_vals
27
- # weights: [B, T - 1], bin weights.
28
- # return: [B, n_samples], new_z_vals
29
-
30
- # Get pdf
31
- weights = weights + 1e-5 # prevent nans
32
- pdf = weights / torch.sum(weights, -1, keepdim=True)
33
- cdf = torch.cumsum(pdf, -1)
34
- cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
35
- # Take uniform samples
36
- if det:
37
- u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
38
- u = u.expand(list(cdf.shape[:-1]) + [n_samples])
39
- else:
40
- u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
41
-
42
- # Invert CDF
43
- u = u.contiguous()
44
- inds = torch.searchsorted(cdf, u, right=True)
45
- below = torch.max(torch.zeros_like(inds - 1), inds - 1)
46
- above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
47
- inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
48
-
49
- matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
50
- cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
51
- bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
52
-
53
- denom = (cdf_g[..., 1] - cdf_g[..., 0])
54
- denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
55
- t = (u - cdf_g[..., 0]) / denom
56
- samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
57
-
58
- return samples
59
-
60
-
61
- def plot_pointcloud(pc, color=None):
62
- # pc: [N, 3]
63
- # color: [N, 3/4]
64
- print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
65
- pc = trimesh.PointCloud(pc, color)
66
- # axis
67
- axes = trimesh.creation.axis(axis_length=4)
68
- # sphere
69
- sphere = trimesh.creation.icosphere(radius=1)
70
- trimesh.Scene([pc, axes, sphere]).show()
71
-
72
-
73
- class NGPRenderer(nn.Module):
74
- def __init__(self,
75
- bound=1,
76
- cuda_ray=True,
77
-                  density_scale=1, # scale up deltas (or sigmas) to make the density grid sharper; a value larger than 1 usually improves performance.
78
- min_near=0.2,
79
- density_thresh=0.01,
80
- bg_radius=-1,
81
- ):
82
- super().__init__()
83
-
84
- self.bound = bound
85
- self.cascade = 1
86
- self.grid_size = 128
87
- self.density_scale = density_scale
88
- self.min_near = min_near
89
- self.density_thresh = density_thresh
90
- self.bg_radius = bg_radius # radius of the background sphere.
91
-
92
- # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
93
- # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
94
- aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])
95
- aabb_infer = aabb_train.clone()
96
- self.register_buffer('aabb_train', aabb_train)
97
- self.register_buffer('aabb_infer', aabb_infer)
98
-
99
- # extra state for cuda raymarching
100
- self.cuda_ray = cuda_ray
101
- if cuda_ray:
102
- # density grid
103
- density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
104
- density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
105
- self.register_buffer('density_grid', density_grid)
106
- self.register_buffer('density_bitfield', density_bitfield)
107
- self.mean_density = 0
108
- self.iter_density = 0
109
- # step counter
110
- step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
111
- self.register_buffer('step_counter', step_counter)
112
- self.mean_count = 0
113
- self.local_step = 0
114
-
115
- def forward(self, x, d):
116
- raise NotImplementedError()
117
-
118
- # separated density and color query (can accelerate non-cuda-ray mode.)
119
- def density(self, x):
120
- raise NotImplementedError()
121
-
122
- def color(self, x, d, mask=None, **kwargs):
123
- raise NotImplementedError()
124
-
125
- def reset_extra_state(self):
126
- if not self.cuda_ray:
127
- return
128
- # density grid
129
- self.density_grid.zero_()
130
- self.mean_density = 0
131
- self.iter_density = 0
132
- # step counter
133
- self.step_counter.zero_()
134
- self.mean_count = 0
135
- self.local_step = 0
136
-
137
- def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs):
138
- # rays_o, rays_d: [B, N, 3], assumes B == 1
139
- # bg_color: [3] in range [0, 1]
140
- # return: image: [B, N, 3], depth: [B, N]
141
-
142
- prefix = rays_o.shape[:-1]
143
- rays_o = rays_o.contiguous().view(-1, 3)
144
- rays_d = rays_d.contiguous().view(-1, 3)
145
-
146
- N = rays_o.shape[0] # N = B * N, in fact
147
- device = rays_o.device
148
-
149
- # choose aabb
150
- aabb = self.aabb_train if self.training else self.aabb_infer
151
-
152
- # sample steps
153
- nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
154
- nears.unsqueeze_(-1)
155
- fars.unsqueeze_(-1)
156
-
157
- #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')
158
-
159
- z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T]
160
- z_vals = z_vals.expand((N, num_steps)) # [N, T]
161
- z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
162
-
163
- # perturb z_vals
164
- sample_dist = (fars - nears) / num_steps
165
- if perturb:
166
- z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
167
- #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.
168
-
169
- # generate xyzs
170
- xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
171
- xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.
172
-
173
- #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
174
-
175
- # query SDF and RGB
176
- density_outputs = self.density(xyzs.reshape(-1, 3))
177
-
178
- #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
179
- for k, v in density_outputs.items():
180
- density_outputs[k] = v.view(N, num_steps, -1)
181
-
182
- # upsample z_vals (nerf-like)
183
- if upsample_steps > 0:
184
- with torch.no_grad():
185
-
186
- deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
187
- deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
188
-
189
- alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T]
190
- alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
191
- weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]
192
-
193
- # sample new z_vals
194
- z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
195
- new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t]
196
-
197
- new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
198
- new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.
199
-
200
- # only forward new points to save computation
201
- new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
202
- #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
203
- for k, v in new_density_outputs.items():
204
- new_density_outputs[k] = v.view(N, upsample_steps, -1)
205
-
206
- # re-order
207
- z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
208
- z_vals, z_index = torch.sort(z_vals, dim=1)
209
-
210
- xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
211
- xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))
212
-
213
- for k in density_outputs:
214
- tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
215
- density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))
216
-
217
- deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
218
- deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
219
- alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
220
- alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
221
- weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]
222
-
223
- dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
224
- for k, v in density_outputs.items():
225
- density_outputs[k] = v.view(-1, v.shape[-1])
226
-
227
- mask = weights > 1e-4 # hard coded
228
- rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs)
229
- rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
230
-
231
- #print(xyzs.shape, 'valid_rgb:', mask.sum().item())
232
-
233
- # calculate weight_sum (mask)
234
- weights_sum = weights.sum(dim=-1) # [N]
235
-
236
- # calculate depth
237
- ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
238
- depth = torch.sum(weights * ori_z_vals, dim=-1)
239
-
240
- # calculate color
241
- image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]
242
-
243
- # mix background color
244
- if self.bg_radius > 0:
245
- # use the bg model to calculate bg_color
246
- sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
247
- bg_color = self.background(sph, rays_d.reshape(-1, 3)) # [N, 3]
248
- elif bg_color is None:
249
- bg_color = 1
250
-
251
- image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
252
-
253
- image = image.view(*prefix, 3)
254
- depth = depth.view(*prefix)
255
-
256
- # tmp: reg loss in mip-nerf 360
257
- # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)
258
- # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]
259
- # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum()
260
-
261
- return {
262
- 'depth': depth,
263
- 'image': image,
264
- 'weights_sum': weights_sum,
265
- }
266
-
267
-
268
- def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
269
- # rays_o, rays_d: [B, N, 3], assumes B == 1
270
- # return: image: [B, N, 3], depth: [B, N]
271
-
272
- prefix = rays_o.shape[:-1]
273
- rays_o = rays_o.contiguous().view(-1, 3)
274
- rays_d = rays_d.contiguous().view(-1, 3)
275
-
276
- N = rays_o.shape[0] # N = B * N, in fact
277
- device = rays_o.device
278
-
279
- # pre-calculate near far
280
- nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)
281
-
282
- # mix background color
283
- if self.bg_radius > 0:
284
- # use the bg model to calculate bg_color
285
- sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
286
- bg_color = self.background(sph, rays_d) # [N, 3]
287
- elif bg_color is None:
288
- bg_color = 1
289
-
290
- results = {}
291
-
292
- if self.training:
293
- # setup counter
294
- counter = self.step_counter[self.local_step % 16]
295
- counter.zero_() # set to 0
296
- self.local_step += 1
297
-
298
- xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
299
-
300
- #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
301
-
302
- sigmas, rgbs = self(xyzs, dirs)
303
- sigmas = self.density_scale * sigmas
304
-
305
- weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)
306
- image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
307
- depth = torch.clamp(depth - nears, min=0) / (fars - nears)
308
- image = image.view(*prefix, 3)
309
- depth = depth.view(*prefix)
310
-
311
- else:
312
-
313
- # allocate outputs
314
- # if use autocast, must init as half so it won't be autocasted and lose reference.
315
- #dtype = torch.half if torch.is_autocast_enabled() else torch.float32
316
- # output should always be float32! only network inference uses half.
317
- dtype = torch.float32
318
-
319
- weights_sum = torch.zeros(N, dtype=dtype, device=device)
320
- depth = torch.zeros(N, dtype=dtype, device=device)
321
- image = torch.zeros(N, 3, dtype=dtype, device=device)
322
-
323
- n_alive = N
324
- rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
325
- rays_t = nears.clone() # [N]
326
-
327
- step = 0
328
-
329
- while step < max_steps:
330
-
331
- # count alive rays
332
- n_alive = rays_alive.shape[0]
333
-
334
- # exit loop
335
- if n_alive <= 0:
336
- break
337
-
338
- # decide compact_steps
339
- n_step = max(min(N // n_alive, 8), 1)
340
-
341
- xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)
342
-
343
- sigmas, rgbs = self(xyzs, dirs)
344
- # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
345
- # sigmas = density_outputs['sigma']
346
- # rgbs = self.color(xyzs, dirs, **density_outputs)
347
- sigmas = self.density_scale * sigmas
348
-
349
- raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)
350
-
351
- rays_alive = rays_alive[rays_alive >= 0]
352
-
353
- #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
354
-
355
- step += n_step
356
-
357
- image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
358
- depth = torch.clamp(depth - nears, min=0) / (fars - nears)
359
- image = image.view(*prefix, 3)
360
- depth = depth.view(*prefix)
361
-
362
- results['weights_sum'] = weights_sum
363
- results['depth'] = depth
364
- results['image'] = image
365
-
366
- return results
367
-
368
- @torch.no_grad()
369
- def mark_untrained_grid(self, poses, intrinsic, S=64):
370
- # poses: [B, 4, 4]
371
- # intrinsic: [3, 3]
372
-
373
- if not self.cuda_ray:
374
- return
375
-
376
- if isinstance(poses, np.ndarray):
377
- poses = torch.from_numpy(poses)
378
-
379
- B = poses.shape[0]
380
-
381
- fx, fy, cx, cy = intrinsic
382
-
383
- X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
384
- Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
385
- Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
386
-
387
- count = torch.zeros_like(self.density_grid)
388
- poses = poses.to(count.device)
389
-
390
- # 5-level loop, forgive me...
391
-
392
- for xs in X:
393
- for ys in Y:
394
- for zs in Z:
395
-
396
- # construct points
397
- xx, yy, zz = custom_meshgrid(xs, ys, zs)
398
- coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
399
- indices = raymarching.morton3D(coords).long() # [N]
400
- world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1]
401
-
402
- # cascading
403
- for cas in range(self.cascade):
404
- bound = min(2 ** cas, self.bound)
405
- half_grid_size = bound / self.grid_size
406
- # scale to current cascade's resolution
407
- cas_world_xyzs = world_xyzs * (bound - half_grid_size)
408
-
409
- # split batch to avoid OOM
410
- head = 0
411
- while head < B:
412
- tail = min(head + S, B)
413
-
414
- # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
415
- cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
416
- cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3]
417
-
418
- # query if point is covered by any camera
419
- mask_z = cam_xyzs[:, :, 2] > 0 # [S, N]
420
- mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
421
- mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
422
- mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N]
423
-
424
- # update count
425
- count[cas, indices] += mask
426
- head += S
427
-
428
- # mark untrained grid as -1
429
- self.density_grid[count == 0] = -1
430
-
431
- print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}')
432
-
433
- @torch.no_grad()
434
- def update_extra_state(self, decay=0.95, S=128):
435
- # call before each epoch to update extra states.
436
-
437
- if not self.cuda_ray:
438
- return
439
-
440
- ### update density grid
441
- tmp_grid = - torch.ones_like(self.density_grid)
442
-
443
- # full update.
444
- if self.iter_density < 16:
445
- #if True:
446
- X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
447
- Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
448
- Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
449
-
450
- for xs in X:
451
- for ys in Y:
452
- for zs in Z:
453
-
454
- # construct points
455
- xx, yy, zz = custom_meshgrid(xs, ys, zs)
456
- coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
457
- indices = raymarching.morton3D(coords).long() # [N]
458
- xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
459
-
460
- # cascading
461
- for cas in range(self.cascade):
462
- bound = min(2 ** cas, self.bound)
463
- half_grid_size = bound / self.grid_size
464
- # scale to current cascade's resolution
465
- cas_xyzs = xyzs * (bound - half_grid_size)
466
- # add noise in [-hgs, hgs]
467
- cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
468
- # query density
469
- sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
470
- sigmas *= self.density_scale
471
- # assign
472
- tmp_grid[cas, indices] = sigmas
473
-
474
- # partial update (half the computation)
475
- # TODO: why no need of maxpool ?
476
- else:
477
- N = self.grid_size ** 3 // 4 # H * H * H / 4
478
- for cas in range(self.cascade):
479
- # random sample some positions
480
- coords = torch.randint(0, self.grid_size, (N, 3), device=self.density_bitfield.device) # [N, 3], in [0, 128)
481
- indices = raymarching.morton3D(coords).long() # [N]
482
- # random sample occupied positions
483
- occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(-1) # [Nz]
484
- rand_mask = torch.randint(0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_bitfield.device)
485
- occ_indices = occ_indices[rand_mask] # [Nz] --> [N], allow for duplication
486
- occ_coords = raymarching.morton3D_invert(occ_indices) # [N, 3]
487
- # concat
488
- indices = torch.cat([indices, occ_indices], dim=0)
489
- coords = torch.cat([coords, occ_coords], dim=0)
490
- # same below
491
- xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
492
- bound = min(2 ** cas, self.bound)
493
- half_grid_size = bound / self.grid_size
494
- # scale to current cascade's resolution
495
- cas_xyzs = xyzs * (bound - half_grid_size)
496
- # add noise in [-hgs, hgs]
497
- cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
498
- # query density
499
- sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
500
- sigmas *= self.density_scale
501
- # assign
502
- tmp_grid[cas, indices] = sigmas
503
-
504
- ## max-pool on tmp_grid for less aggressive culling [No significant improvement...]
505
- # invalid_mask = tmp_grid < 0
506
- # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1)
507
- # tmp_grid[invalid_mask] = -1
508
-
509
- # ema update
510
- valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
511
- self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
512
- self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 regions are viewed as 0 density.
513
- #self.mean_density = torch.mean(self.density_grid[self.density_grid > 0]).item() # do not count -1 regions
514
- self.iter_density += 1
515
-
516
- # convert to bitfield
517
- density_thresh = min(self.mean_density, self.density_thresh)
518
- self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
519
-
520
- ### update step counter
521
- total_step = min(16, self.local_step)
522
- if total_step > 0:
523
- self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
524
- self.local_step = 0
525
-
526
- #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
527
-
528
-
529
- def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
530
- # rays_o, rays_d: [B, N, 3], assumes B == 1
531
- # return: pred_rgb: [B, N, 3]
532
-
533
- if self.cuda_ray:
534
- _run = self.run_cuda
535
- else:
536
- _run = self.run
537
-
538
- results = _run(rays_o, rays_d, **kwargs)
539
- return results
540
-
541
-
542
-
543
- class _trunc_exp(Function):
544
- @staticmethod
545
- @custom_fwd(cast_inputs=torch.float32) # cast to float32
546
- def forward(ctx, x):
547
- ctx.save_for_backward(x)
548
- return torch.exp(x)
549
-
550
- @staticmethod
551
- @custom_bwd
552
- def backward(ctx, g):
553
- x = ctx.saved_tensors[0]
554
- return g * torch.exp(x.clamp(-15, 15))
555
-
556
- trunc_exp = _trunc_exp.apply
557
-
558
- class NGPNetwork(NGPRenderer):
559
- def __init__(self,
560
- num_layers=2,
561
- hidden_dim=64,
562
- geo_feat_dim=15,
563
- num_layers_color=3,
564
- hidden_dim_color=64,
565
- bound=0.5,
566
- max_resolution=128,
567
- base_resolution=16,
568
- n_levels=16,
569
- **kwargs
570
- ):
571
- super().__init__(bound, **kwargs)
572
-
573
- # sigma network
574
- self.num_layers = num_layers
575
- self.hidden_dim = hidden_dim
576
- self.geo_feat_dim = geo_feat_dim
577
- self.bound = bound
578
-
579
- log2_hashmap_size = 19
580
- n_features_per_level = 2
581
-
582
-
583
- per_level_scale = np.exp2(np.log2(max_resolution / base_resolution) / (n_levels - 1))
584
-
585
- self.encoder = tcnn.Encoding(
586
- n_input_dims=3,
587
- encoding_config={
588
- "otype": "HashGrid",
589
- "n_levels": n_levels,
590
- "n_features_per_level": n_features_per_level,
591
- "log2_hashmap_size": log2_hashmap_size,
592
- "base_resolution": base_resolution,
593
- "per_level_scale": per_level_scale,
594
- },
595
- )
596
-
597
- self.sigma_net = tcnn.Network(
598
- n_input_dims = n_levels * 2,
599
- n_output_dims=1 + self.geo_feat_dim,
600
- network_config={
601
- "otype": "FullyFusedMLP",
602
- "activation": "ReLU",
603
- "output_activation": "None",
604
- "n_neurons": hidden_dim,
605
- "n_hidden_layers": num_layers - 1,
606
- },
607
- )
608
-
609
- # color network
610
- self.num_layers_color = num_layers_color
611
- self.hidden_dim_color = hidden_dim_color
612
-
613
- self.encoder_dir = tcnn.Encoding(
614
- n_input_dims=3,
615
- encoding_config={
616
- "otype": "SphericalHarmonics",
617
- "degree": 4,
618
- },
619
- )
620
-
621
- self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim
622
-
623
- self.color_net = tcnn.Network(
624
- n_input_dims = self.in_dim_color,
625
- n_output_dims=3,
626
- network_config={
627
- "otype": "FullyFusedMLP",
628
- "activation": "ReLU",
629
- "output_activation": "None",
630
- "n_neurons": hidden_dim_color,
631
- "n_hidden_layers": num_layers_color - 1,
632
- },
633
- )
634
- self.density_scale, self.density_std = 10.0, 0.25
635
-
636
- def forward(self, x, d):
637
- # x: [N, 3], in [-bound, bound]
638
- # d: [N, 3], normalized in [-1, 1]
639
-
640
-
641
- # sigma
642
- x_raw = x
643
- x = (x + self.bound) / (2 * self.bound) # to [0, 1]
644
- x = self.encoder(x)
645
- h = self.sigma_net(x)
646
-
647
- # sigma = F.relu(h[..., 0])
648
- density = h[..., 0]
649
- # add density bias
650
- dist = torch.norm(x_raw, dim=-1)
651
- density_bias = (1 - dist / self.density_std) * self.density_scale
652
- density = density_bias + density
653
- sigma = F.softplus(density)
654
- geo_feat = h[..., 1:]
655
-
656
- # color
657
- d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
658
- d = self.encoder_dir(d)
659
-
660
- # p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
661
- h = torch.cat([d, geo_feat], dim=-1)
662
- h = self.color_net(h)
663
-
664
- # sigmoid activation for rgb
665
- color = torch.sigmoid(h)
666
-
667
- return sigma, color
668
-
669
- def density(self, x):
670
- # x: [N, 3], in [-bound, bound]
671
- x_raw = x
672
- x = (x + self.bound) / (2 * self.bound) # to [0, 1]
673
- x = self.encoder(x)
674
- h = self.sigma_net(x)
675
-
676
- # sigma = F.relu(h[..., 0])
677
- density = h[..., 0]
678
- # add density bias
679
- dist = torch.norm(x_raw, dim=-1)
680
- density_bias = (1 - dist / self.density_std) * self.density_scale
681
- density = density_bias + density
682
- sigma = F.softplus(density)
683
- geo_feat = h[..., 1:]
684
-
685
- return {
686
- 'sigma': sigma,
687
- 'geo_feat': geo_feat,
688
- }
689
-
690
- # allow masked inference
691
- def color(self, x, d, mask=None, geo_feat=None, **kwargs):
692
- # x: [N, 3] in [-bound, bound]
693
- # mask: [N,], bool, indicates where we actually need to compute rgb.
694
-
695
- x = (x + self.bound) / (2 * self.bound) # to [0, 1]
696
-
697
- if mask is not None:
698
- rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
699
- # in case of empty mask
700
- if not mask.any():
701
- return rgbs
702
- x = x[mask]
703
- d = d[mask]
704
- geo_feat = geo_feat[mask]
705
-
706
- # color
707
- d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
708
- d = self.encoder_dir(d)
709
-
710
- h = torch.cat([d, geo_feat], dim=-1)
711
- h = self.color_net(h)
712
-
713
- # sigmoid activation for rgb
714
- h = torch.sigmoid(h)
715
-
716
- if mask is not None:
717
- rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
718
- else:
719
- rgbs = h
720
-
721
- return rgbs
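
A minimal, self-contained sketch of the density-bias trick used in NGPNetwork.forward/density above (PyTorch only; the helper name biased_sigma is hypothetical and not part of the deleted file). Points near the origin receive a positive bias before the softplus, so the density field starts out as a blob centred on the object rather than as empty space:

import torch
import torch.nn.functional as F

def biased_sigma(raw_density, points, density_scale=10.0, density_std=0.25):
    # raw_density: [N] raw network output; points: [N, 3] in [-bound, bound]
    dist = torch.norm(points, dim=-1)                        # distance to the origin
    density_bias = (1 - dist / density_std) * density_scale  # positive near the centre, negative far away
    return F.softplus(raw_density + density_bias)

pts = torch.tensor([[0.0, 0.0, 0.0], [0.4, 0.0, 0.0]])
print(biased_sigma(torch.zeros(2), pts))  # sigma is ~10 at the origin, near 0 at dist 0.4
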
renderer/renderer.py DELETED
@@ -1,604 +0,0 @@
1
- import abc
2
- import os
3
- from pathlib import Path
4
-
5
- import cv2
6
- import numpy as np
7
- import pytorch_lightning as pl
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from omegaconf import OmegaConf
12
-
13
- from skimage.io import imread, imsave
14
- from PIL import Image
15
- from torch.optim.lr_scheduler import LambdaLR
16
-
17
- from ldm.base_utils import read_pickle, concat_images_list
18
- from renderer.neus_networks import SDFNetwork, RenderingNetwork, SingleVarianceNetwork, SDFHashGridNetwork, RenderingFFNetwork
19
- from renderer.ngp_renderer import NGPNetwork
20
- from ldm.util import instantiate_from_config
21
-
22
- DEFAULT_RADIUS = np.sqrt(3)/2
23
- DEFAULT_SIDE_LENGTH = 0.6
24
-
25
- def sample_pdf(bins, weights, n_samples, det=True):
26
- device = bins.device
27
- dtype = bins.dtype
28
- # This implementation is from NeRF
29
- # Get pdf
30
- weights = weights + 1e-5 # prevent nans
31
- pdf = weights / torch.sum(weights, -1, keepdim=True)
32
- cdf = torch.cumsum(pdf, -1)
33
- cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
34
- # Take uniform samples
35
- if det:
36
- u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples, dtype=dtype, device=device)
37
- u = u.expand(list(cdf.shape[:-1]) + [n_samples])
38
- else:
39
- u = torch.rand(list(cdf.shape[:-1]) + [n_samples], dtype=dtype, device=device)
40
-
41
- # Invert CDF
42
- u = u.contiguous()
43
- inds = torch.searchsorted(cdf, u, right=True)
44
- below = torch.max(torch.zeros_like(inds - 1), inds - 1)
45
- above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
46
- inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
47
-
48
- matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
49
- cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
50
- bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
51
-
52
- denom = (cdf_g[..., 1] - cdf_g[..., 0])
53
- denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
54
- t = (u - cdf_g[..., 0]) / denom
55
- samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
56
-
57
- return samples
58
-
59
- def near_far_from_sphere(rays_o, rays_d, radius=DEFAULT_RADIUS):
60
- a = torch.sum(rays_d ** 2, dim=-1, keepdim=True)
61
- b = torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
62
- mid = -b / a
63
- near = mid - radius
64
- far = mid + radius
65
- return near, far
66
-
67
- class BackgroundRemoval:
68
- def __init__(self, device='cuda'):
69
- from carvekit.api.high import HiInterface
70
- self.interface = HiInterface(
71
- object_type="object", # Can be "object" or "hairs-like".
72
- batch_size_seg=5,
73
- batch_size_matting=1,
74
- device=device,
75
- seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
76
- matting_mask_size=2048,
77
- trimap_prob_threshold=231,
78
- trimap_dilation=30,
79
- trimap_erosion_iters=5,
80
- fp16=True,
81
- )
82
-
83
- @torch.no_grad()
84
- def __call__(self, image):
85
- # image: [H, W, 3] array in [0, 255].
86
- image = Image.fromarray(image)
87
- image = self.interface([image])[0]
88
- image = np.array(image)
89
- return image
90
-
91
-
92
- class BaseRenderer(nn.Module):
93
- def __init__(self, train_batch_num, test_batch_num):
94
- super().__init__()
95
- self.train_batch_num = train_batch_num
96
- self.test_batch_num = test_batch_num
97
-
98
- @abc.abstractmethod
99
- def render_impl(self, ray_batch, is_train, step):
100
- pass
101
-
102
- @abc.abstractmethod
103
- def render_with_loss(self, ray_batch, is_train, step):
104
- pass
105
-
106
- def render(self, ray_batch, is_train, step):
107
- batch_num = self.train_batch_num if is_train else self.test_batch_num
108
- ray_num = ray_batch['rays_o'].shape[0]
109
- outputs = {}
110
- for ri in range(0, ray_num, batch_num):
111
- cur_ray_batch = {}
112
- for k, v in ray_batch.items():
113
- cur_ray_batch[k] = v[ri:ri + batch_num]
114
- cur_outputs = self.render_impl(cur_ray_batch, is_train, step)
115
- for k, v in cur_outputs.items():
116
- if k not in outputs: outputs[k] = []
117
- outputs[k].append(v)
118
-
119
- for k, v in outputs.items():
120
- outputs[k] = torch.cat(v, 0)
121
- return outputs
122
-
123
-
124
- class NeuSRenderer(BaseRenderer):
125
- def __init__(self, train_batch_num, test_batch_num, lambda_eikonal_loss=0.1, use_mask=True,
126
- lambda_rgb_loss=1.0, lambda_mask_loss=0.0, rgb_loss='soft_l1', coarse_sn=64, fine_sn=64):
127
- super().__init__(train_batch_num, test_batch_num)
128
- self.n_samples = coarse_sn
129
- self.n_importance = fine_sn
130
- self.up_sample_steps = 4
131
- self.anneal_end = 200
132
- self.use_mask = use_mask
133
- self.lambda_eikonal_loss = lambda_eikonal_loss
134
- self.lambda_rgb_loss = lambda_rgb_loss
135
- self.lambda_mask_loss = lambda_mask_loss
136
- self.rgb_loss = rgb_loss
137
-
138
- self.sdf_network = SDFNetwork(d_out=257, d_in=3, d_hidden=256, n_layers=8, skip_in=[4], multires=6, bias=0.5, scale=1.0, geometric_init=True, weight_norm=True)
139
- self.color_network = RenderingNetwork(d_feature=256, d_in=9, d_out=3, d_hidden=256, n_layers=4, weight_norm=True, multires_view=4, squeeze_out=True)
140
- self.default_dtype = torch.float32
141
- self.deviation_network = SingleVarianceNetwork(0.3)
142
-
143
- @torch.no_grad()
144
- def get_vertex_colors(self, vertices):
145
- """
146
- @param vertices: n,3
147
- @return:
148
- """
149
- V = vertices.shape[0]
150
- bn = 20480
151
- verts_colors = []
152
- with torch.no_grad():
153
- for vi in range(0, V, bn):
154
- verts = torch.from_numpy(vertices[vi:vi+bn].astype(np.float32)).cuda()
155
- feats = self.sdf_network(verts)[..., 1:]
156
- gradients = self.sdf_network.gradient(verts) # ...,3
157
- gradients = F.normalize(gradients, dim=-1)
158
- colors = self.color_network(verts, gradients, gradients, feats)
159
- colors = torch.clamp(colors,min=0,max=1).cpu().numpy()
160
- verts_colors.append(colors)
161
-
162
- verts_colors = (np.concatenate(verts_colors, 0)*255).astype(np.uint8)
163
- return verts_colors
164
-
165
- def upsample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
166
- """
167
- Up sampling given a fixed inv_s
168
- """
169
- device = rays_o.device
170
- batch_size, n_samples = z_vals.shape
171
- pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
172
- inner_mask = self.get_inner_mask(pts)
173
- # radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
174
- inside_sphere = inner_mask[:, :-1] | inner_mask[:, 1:]
175
- sdf = sdf.reshape(batch_size, n_samples)
176
- prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
177
- prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
178
- mid_sdf = (prev_sdf + next_sdf) * 0.5
179
- cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
180
-
181
- prev_cos_val = torch.cat([torch.zeros([batch_size, 1], dtype=self.default_dtype, device=device), cos_val[:, :-1]], dim=-1)
182
- cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
183
- cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
184
- cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
185
-
186
- dist = (next_z_vals - prev_z_vals)
187
- prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
188
- next_esti_sdf = mid_sdf + cos_val * dist * 0.5
189
- prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
190
- next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
191
- alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
192
- weights = alpha * torch.cumprod(
193
- torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[:, :-1]
194
-
195
- z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
196
- return z_samples
197
-
198
- def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
199
- batch_size, n_samples = z_vals.shape
200
- _, n_importance = new_z_vals.shape
201
- pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
202
- z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
203
- z_vals, index = torch.sort(z_vals, dim=-1)
204
-
205
- if not last:
206
- device = pts.device
207
- new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
208
- sdf = torch.cat([sdf, new_sdf], dim=-1)
209
- xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1).to(device)
210
- index = index.reshape(-1)
211
- sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
212
-
213
- return z_vals, sdf
214
-
215
- def sample_depth(self, rays_o, rays_d, near, far, perturb):
216
- n_samples = self.n_samples
217
- n_importance = self.n_importance
218
- up_sample_steps = self.up_sample_steps
219
- device = rays_o.device
220
-
221
- # sample points
222
- batch_size = len(rays_o)
223
- z_vals = torch.linspace(0.0, 1.0, n_samples, dtype=self.default_dtype, device=device) # sn
224
- z_vals = near + (far - near) * z_vals[None, :] # rn,sn
225
-
226
- if perturb > 0:
227
- t_rand = (torch.rand([batch_size, 1]).to(device) - 0.5)
228
- z_vals = z_vals + t_rand * 2.0 / n_samples
229
-
230
- # Up sample
231
- with torch.no_grad():
232
- pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
233
- sdf = self.sdf_network.sdf(pts).reshape(batch_size, n_samples)
234
-
235
- for i in range(up_sample_steps):
236
- rn, sn = z_vals.shape
237
- inv_s = torch.ones(rn, sn - 1, dtype=self.default_dtype, device=device) * 64 * 2 ** i
238
- new_z_vals = self.upsample(rays_o, rays_d, z_vals, sdf, n_importance // up_sample_steps, inv_s)
239
- z_vals, sdf = self.cat_z_vals(rays_o, rays_d, z_vals, new_z_vals, sdf, last=(i + 1 == up_sample_steps))
240
-
241
- return z_vals
242
-
243
- def compute_sdf_alpha(self, points, dists, dirs, cos_anneal_ratio, step):
244
- # points [...,3] dists [...] dirs[...,3]
245
- sdf_nn_output = self.sdf_network(points)
246
- sdf = sdf_nn_output[..., 0]
247
- feature_vector = sdf_nn_output[..., 1:]
248
-
249
- gradients = self.sdf_network.gradient(points) # ...,3
250
- inv_s = self.deviation_network(points).clip(1e-6, 1e6) # ...,1
251
- inv_s = inv_s[..., 0]
252
-
253
- true_cos = (dirs * gradients).sum(-1) # [...]
254
- iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
255
- F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
256
-
257
- # Estimate signed distances at section points
258
- estimated_next_sdf = sdf + iter_cos * dists * 0.5
259
- estimated_prev_sdf = sdf - iter_cos * dists * 0.5
260
-
261
- prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
262
- next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
263
-
264
- p = prev_cdf - next_cdf
265
- c = prev_cdf
266
-
267
- alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0) # [...]
268
- return alpha, gradients, feature_vector, inv_s, sdf
269
-
270
- def get_anneal_val(self, step):
271
- if self.anneal_end < 0:
272
- return 1.0
273
- else:
274
- return np.min([1.0, step / self.anneal_end])
275
-
276
- def get_inner_mask(self, points):
277
- return torch.sum(torch.abs(points)<=DEFAULT_SIDE_LENGTH,-1)==3
278
-
279
- def render_impl(self, ray_batch, is_train, step):
280
- near, far = near_far_from_sphere(ray_batch['rays_o'], ray_batch['rays_d'])
281
- rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
282
- z_vals = self.sample_depth(rays_o, rays_d, near, far, is_train)
283
-
284
- batch_size, n_samples = z_vals.shape
285
-
286
- # section length in original space
287
- dists = z_vals[..., 1:] - z_vals[..., :-1] # rn,sn-1
288
- dists = torch.cat([dists, dists[..., -1:]], -1) # rn,sn
289
- mid_z_vals = z_vals + dists * 0.5
290
-
291
- points = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * mid_z_vals.unsqueeze(-1) # rn, sn, 3
292
- inner_mask = self.get_inner_mask(points)
293
-
294
- dirs = rays_d.unsqueeze(-2).expand(batch_size, n_samples, 3)
295
- dirs = F.normalize(dirs, dim=-1)
296
- device = rays_o.device
297
- alpha, sampled_color, gradient_error, normal = torch.zeros(batch_size, n_samples, dtype=self.default_dtype, device=device), \
298
- torch.zeros(batch_size, n_samples, 3, dtype=self.default_dtype, device=device), \
299
- torch.zeros([batch_size, n_samples], dtype=self.default_dtype, device=device), \
300
- torch.zeros([batch_size, n_samples, 3], dtype=self.default_dtype, device=device)
301
- if torch.sum(inner_mask) > 0:
302
- cos_anneal_ratio = self.get_anneal_val(step) if is_train else 1.0
303
- alpha[inner_mask], gradients, feature_vector, inv_s, sdf = self.compute_sdf_alpha(points[inner_mask], dists[inner_mask], dirs[inner_mask], cos_anneal_ratio, step)
304
- sampled_color[inner_mask] = self.color_network(points[inner_mask], gradients, -dirs[inner_mask], feature_vector)
305
- # Eikonal loss
306
- gradient_error[inner_mask] = (torch.linalg.norm(gradients, ord=2, dim=-1) - 1.0) ** 2 # rn,sn
307
- normal[inner_mask] = F.normalize(gradients, dim=-1)
308
-
309
- weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[..., :-1] # rn,sn
310
- mask = torch.sum(weights,dim=1).unsqueeze(-1) # rn,1
311
- color = (sampled_color * weights[..., None]).sum(dim=1) + (1 - mask) # add white background
312
- normal = (normal * weights[..., None]).sum(dim=1)
313
-
314
- outputs = {
315
- 'rgb': color, # rn,3
316
- 'gradient_error': gradient_error, # rn,sn
317
- 'inner_mask': inner_mask, # rn,sn
318
- 'normal': normal, # rn,3
319
- 'mask': mask, # rn,1
320
- }
321
- return outputs
322
-
323
- def render_with_loss(self, ray_batch, is_train, step):
324
- render_outputs = self.render(ray_batch, is_train, step)
325
-
326
- rgb_gt = ray_batch['rgb']
327
- rgb_pr = render_outputs['rgb']
328
- if self.rgb_loss == 'soft_l1':
329
- epsilon = 0.001
330
- rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
331
- elif self.rgb_loss =='mse':
332
- rgb_loss = F.mse_loss(rgb_pr, rgb_gt, reduction='none')
333
- else:
334
- raise NotImplementedError
335
- rgb_loss = torch.mean(rgb_loss)
336
-
337
- eikonal_loss = torch.sum(render_outputs['gradient_error'] * render_outputs['inner_mask']) / torch.sum(render_outputs['inner_mask'] + 1e-5)
338
- loss = rgb_loss * self.lambda_rgb_loss + eikonal_loss * self.lambda_eikonal_loss
339
- loss_batch = {
340
- 'eikonal': eikonal_loss,
341
- 'rendering': rgb_loss,
342
- # 'mask': mask_loss,
343
- }
344
- if self.lambda_mask_loss>0 and self.use_mask:
345
- mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none').mean()
346
- loss += mask_loss * self.lambda_mask_loss
347
- loss_batch['mask'] = mask_loss
348
- return loss, loss_batch
349
-
350
-
351
- class NeRFRenderer(BaseRenderer):
352
- def __init__(self, train_batch_num, test_batch_num, bound=0.5, use_mask=False, lambda_rgb_loss=1.0, lambda_mask_loss=0.0):
353
- super().__init__(train_batch_num, test_batch_num)
354
- self.train_batch_num = train_batch_num
355
- self.test_batch_num = test_batch_num
356
- self.use_mask = use_mask
357
- self.field = NGPNetwork(bound=bound)
358
-
359
- self.update_interval = 16
360
- self.fp16 = True
361
- self.lambda_rgb_loss = lambda_rgb_loss
362
- self.lambda_mask_loss = lambda_mask_loss
363
-
364
- def render_impl(self, ray_batch, is_train, step):
365
- rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
366
- with torch.cuda.amp.autocast(enabled=self.fp16):
367
- if step % self.update_interval==0:
368
- self.field.update_extra_state()
369
-
370
- outputs = self.field.render(rays_o, rays_d,)
371
-
372
- renderings={
373
- 'rgb': outputs['image'],
374
- 'depth': outputs['depth'],
375
- 'mask': outputs['weights_sum'].unsqueeze(-1),
376
- }
377
- return renderings
378
-
379
- def render_with_loss(self, ray_batch, is_train, step):
380
- render_outputs = self.render(ray_batch, is_train, step)
381
-
382
- rgb_gt = ray_batch['rgb']
383
- rgb_pr = render_outputs['rgb']
384
- epsilon = 0.001
385
- rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
386
- rgb_loss = torch.mean(rgb_loss)
387
- loss = rgb_loss * self.lambda_rgb_loss
388
- loss_batch = {'rendering': rgb_loss}
389
-
390
- if self.use_mask:
391
- mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none')
392
- mask_loss = torch.mean(mask_loss)
393
- loss = loss + mask_loss * self.lambda_mask_loss
394
- loss_batch['mask'] = mask_loss
395
- return loss, loss_batch
396
-
397
-
398
- class RendererTrainer(pl.LightningModule):
399
- def __init__(self, image_path, total_steps, warm_up_steps, log_dir, train_batch_fg_num=0,
400
- use_cube_feats=False, cube_ckpt=None, cube_cfg=None, cube_bound=0.5,
401
- train_batch_num=4096, test_batch_num=8192, use_warm_up=True, use_mask=True,
402
- lambda_rgb_loss=1.0, lambda_mask_loss=0.0, renderer='neus',
403
- # used in neus
404
- lambda_eikonal_loss=0.1,
405
- coarse_sn=64, fine_sn=64):
406
- super().__init__()
407
- self.num_images = 16
408
- self.image_size = 256
409
- self.log_dir = log_dir
410
- (Path(log_dir)/'images').mkdir(exist_ok=True, parents=True)
411
- self.train_batch_num = train_batch_num
412
- self.train_batch_fg_num = train_batch_fg_num
413
- self.test_batch_num = test_batch_num
414
- self.image_path = image_path
415
- self.total_steps = total_steps
416
- self.warm_up_steps = warm_up_steps
417
- self.use_mask = use_mask
418
- self.lambda_eikonal_loss = lambda_eikonal_loss
419
- self.lambda_rgb_loss = lambda_rgb_loss
420
- self.lambda_mask_loss = lambda_mask_loss
421
- self.use_warm_up = use_warm_up
422
-
423
- self.use_cube_feats, self.cube_cfg, self.cube_ckpt = use_cube_feats, cube_cfg, cube_ckpt
424
-
425
- self._init_dataset()
426
- if renderer=='neus':
427
- self.renderer = NeuSRenderer(train_batch_num, test_batch_num,
428
- lambda_rgb_loss=lambda_rgb_loss,
429
- lambda_eikonal_loss=lambda_eikonal_loss,
430
- lambda_mask_loss=lambda_mask_loss,
431
- coarse_sn=coarse_sn, fine_sn=fine_sn)
432
- elif renderer=='ngp':
433
- self.renderer = NeRFRenderer(train_batch_num, test_batch_num, bound=cube_bound, use_mask=use_mask, lambda_mask_loss=lambda_mask_loss, lambda_rgb_loss=lambda_rgb_loss,)
434
- else:
435
- raise NotImplementedError
436
- self.validation_index = 0
437
-
438
- def _construct_ray_batch(self, images_info):
439
- image_num = images_info['images'].shape[0]
440
- _, h, w, _ = images_info['images'].shape
441
- coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)), -1)[:, :, (1, 0)] # h,w,2
442
- coords = coords.float()[None, :, :, :].repeat(image_num, 1, 1, 1) # imn,h,w,2
443
- coords = coords.reshape(image_num, h * w, 2)
444
- coords = torch.cat([coords, torch.ones(image_num, h * w, 1, dtype=torch.float32)], 2) # imn,h*w,3
445
-
446
- # imn,h*w,3 @ imn,3,3 => imn,h*w,3
447
- rays_d = coords @ torch.inverse(images_info['Ks']).permute(0, 2, 1)
448
- poses = images_info['poses'] # imn,3,4
449
- R, t = poses[:, :, :3], poses[:, :, 3:]
450
- rays_d = rays_d @ R
451
- rays_d = F.normalize(rays_d, dim=-1)
452
- rays_o = -R.permute(0,2,1) @ t # imn,3,3 @ imn,3,1
453
- rays_o = rays_o.permute(0, 2, 1).repeat(1, h*w, 1) # imn,h*w,3
454
-
455
- ray_batch = {
456
- 'rgb': images_info['images'].reshape(image_num*h*w,3),
457
- 'mask': images_info['masks'].reshape(image_num*h*w,1),
458
- 'rays_o': rays_o.reshape(image_num*h*w,3).float(),
459
- 'rays_d': rays_d.reshape(image_num*h*w,3).float(),
460
- }
461
- return ray_batch
462
-
463
- @staticmethod
464
- def load_model(cfg, ckpt):
465
- config = OmegaConf.load(cfg)
466
- model = instantiate_from_config(config.model)
467
- print(f'loading model from {ckpt} ...')
468
- ckpt = torch.load(ckpt)
469
- model.load_state_dict(ckpt['state_dict'])
470
- model = model.cuda().eval()
471
- return model
472
-
473
- def _init_dataset(self):
474
- mask_predictor = BackgroundRemoval()
475
- self.K, self.azs, self.els, self.dists, self.poses = read_pickle(f'meta_info/camera-{self.num_images}.pkl')
476
-
477
- self.images_info = {'images': [] ,'masks': [], 'Ks': [], 'poses':[]}
478
-
479
- img = imread(self.image_path)
480
-
481
- for index in range(self.num_images):
482
- rgb = np.copy(img[:,index*self.image_size:(index+1)*self.image_size,:])
483
- # predict mask
484
- if self.use_mask:
485
- imsave(f'{self.log_dir}/input-{index}.png', rgb)
486
- masked_image = mask_predictor(rgb)
487
- imsave(f'{self.log_dir}/masked-{index}.png', masked_image)
488
- mask = masked_image[:,:,3].astype(np.float32)/255
489
- else:
490
- h, w, _ = rgb.shape
491
- mask = np.zeros([h,w], np.float32)
492
-
493
- rgb = rgb.astype(np.float32)/255
494
- K, pose = np.copy(self.K), self.poses[index]
495
- self.images_info['images'].append(torch.from_numpy(rgb.astype(np.float32))) # h,w,3
496
- self.images_info['masks'].append(torch.from_numpy(mask.astype(np.float32))) # h,w
497
- self.images_info['Ks'].append(torch.from_numpy(K.astype(np.float32)))
498
- self.images_info['poses'].append(torch.from_numpy(pose.astype(np.float32)))
499
-
500
- for k, v in self.images_info.items(): self.images_info[k] = torch.stack(v, 0) # stack all values
501
-
502
- self.train_batch = self._construct_ray_batch(self.images_info)
503
- self.train_batch_pseudo_fg = {}
504
- pseudo_fg_mask = torch.sum(self.train_batch['rgb']>0.99,1)!=3
505
- for k, v in self.train_batch.items():
506
- self.train_batch_pseudo_fg[k] = v[pseudo_fg_mask]
507
- self.train_ray_fg_num = int(torch.sum(pseudo_fg_mask).cpu().numpy())
508
- self.train_ray_num = self.num_images * self.image_size ** 2
509
- self._shuffle_train_batch()
510
- self._shuffle_train_fg_batch()
511
-
512
- def _shuffle_train_batch(self):
513
- self.train_batch_i = 0
514
- shuffle_idxs = torch.randperm(self.train_ray_num, device='cpu') # shuffle
515
- for k, v in self.train_batch.items():
516
- self.train_batch[k] = v[shuffle_idxs]
517
-
518
- def _shuffle_train_fg_batch(self):
519
- self.train_batch_fg_i = 0
520
- shuffle_idxs = torch.randperm(self.train_ray_fg_num, device='cpu') # shuffle
521
- for k, v in self.train_batch_pseudo_fg.items():
522
- self.train_batch_pseudo_fg[k] = v[shuffle_idxs]
523
-
524
-
525
- def training_step(self, batch, batch_idx):
526
- train_ray_batch = {k: v[self.train_batch_i:self.train_batch_i + self.train_batch_num].cuda() for k, v in self.train_batch.items()}
527
- self.train_batch_i += self.train_batch_num
528
- if self.train_batch_i + self.train_batch_num >= self.train_ray_num: self._shuffle_train_batch()
529
-
530
- if self.train_batch_fg_num>0:
531
- train_ray_batch_fg = {k: v[self.train_batch_fg_i:self.train_batch_fg_i+self.train_batch_fg_num].cuda() for k, v in self.train_batch_pseudo_fg.items()}
532
- self.train_batch_fg_i += self.train_batch_fg_num
533
- if self.train_batch_fg_i + self.train_batch_fg_num >= self.train_ray_fg_num: self._shuffle_train_fg_batch()
534
- for k, v in train_ray_batch_fg.items():
535
- train_ray_batch[k] = torch.cat([train_ray_batch[k], v], 0)
536
-
537
- loss, loss_batch = self.renderer.render_with_loss(train_ray_batch, is_train=True, step=self.global_step)
538
- self.log_dict(loss_batch, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
539
-
540
- self.log('step', self.global_step, prog_bar=True, on_step=True, on_epoch=False, logger=False, rank_zero_only=True)
541
- lr = self.optimizers().param_groups[0]['lr']
542
- self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
543
- return loss
544
-
545
- def _slice_images_info(self, index):
546
- return {k:v[index:index+1] for k, v in self.images_info.items()}
547
-
548
- @torch.no_grad()
549
- def validation_step(self, batch, batch_idx):
550
- with torch.no_grad():
551
- if self.global_rank==0:
552
- # we output a rendered image
553
- images_info = self._slice_images_info(self.validation_index)
554
- self.validation_index += 1
555
- self.validation_index %= self.num_images
556
-
557
- test_ray_batch = self._construct_ray_batch(images_info)
558
- test_ray_batch = {k: v.cuda() for k,v in test_ray_batch.items()}
559
- test_ray_batch['near'], test_ray_batch['far'] = near_far_from_sphere(test_ray_batch['rays_o'], test_ray_batch['rays_d'])
560
- render_outputs = self.renderer.render(test_ray_batch, False, self.global_step)
561
-
562
- process = lambda x: (x.cpu().numpy() * 255).astype(np.uint8)
563
- h, w = self.image_size, self.image_size
564
- rgb = torch.clamp(render_outputs['rgb'].reshape(h, w, 3), max=1.0, min=0.0)
565
- mask = torch.clamp(render_outputs['mask'].reshape(h, w, 1), max=1.0, min=0.0)
566
- mask_ = torch.repeat_interleave(mask, 3, dim=-1)
567
- output_image = concat_images_list(process(rgb), process(mask_))
568
- if 'normal' in render_outputs:
569
- normal = torch.clamp((render_outputs['normal'].reshape(h, w, 3) + 1) / 2, max=1.0, min=0.0)
570
- normal = normal * mask # we only show the foreground normal
571
- output_image = concat_images_list(output_image, process(normal))
572
-
573
- # save images
574
- imsave(f'{self.log_dir}/images/{self.global_step}.jpg', output_image)
575
-
576
- def configure_optimizers(self):
577
- lr = self.learning_rate
578
- opt = torch.optim.AdamW([{"params": self.renderer.parameters(), "lr": lr},], lr=lr)
579
-
580
- def schedule_fn(step):
581
- total_step = self.total_steps
582
- warm_up_step = self.warm_up_steps
583
- warm_up_init = 0.02
584
- warm_up_end = 1.0
585
- final_lr = 0.02
586
- interval = 1000
587
- times = total_step // interval
588
- ratio = np.power(final_lr, 1/times)
589
- if step<warm_up_step:
590
- learning_rate = (step / warm_up_step) * (warm_up_end - warm_up_init) + warm_up_init
591
- else:
592
- learning_rate = ratio ** (step // interval) * warm_up_end
593
- return learning_rate
594
-
595
- if self.use_warm_up:
596
- scheduler = [{
597
- 'scheduler': LambdaLR(opt, lr_lambda=schedule_fn),
598
- 'interval': 'step',
599
- 'frequency': 1
600
- }]
601
- else:
602
- scheduler = []
603
- return [opt], scheduler
604
-
requirements.txt CHANGED
@@ -19,5 +19,4 @@ trimesh
19
  easydict
20
  nerfacc
21
  imageio-ffmpeg==0.4.7
22
- git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
23
  git+https://github.com/openai/CLIP.git
 
19
  easydict
20
  nerfacc
21
  imageio-ffmpeg==0.4.7
 
22
  git+https://github.com/openai/CLIP.git
train_renderer.py DELETED
@@ -1,187 +0,0 @@
1
- import argparse
2
-
3
- import imageio
4
- import numpy as np
5
- import torch
6
- import torch.nn.functional as F
7
- from pathlib import Path
8
-
9
- import trimesh
10
- from omegaconf import OmegaConf
11
- from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback
12
- from pytorch_lightning.loggers import TensorBoardLogger
13
- from pytorch_lightning import Trainer
14
- from skimage.io import imsave
15
- from tqdm import tqdm
16
-
17
- import mcubes
18
-
19
- from ldm.base_utils import read_pickle, output_points
20
- from renderer.renderer import NeuSRenderer, DEFAULT_SIDE_LENGTH
21
- from ldm.util import instantiate_from_config
22
-
23
- class ResumeCallBacks(Callback):
24
- def __init__(self):
25
- pass
26
-
27
- def on_train_start(self, trainer, pl_module):
28
- pl_module.optimizers().param_groups = pl_module.optimizers()._optimizer.param_groups
29
-
30
- def render_images(model, output,):
31
- # render from model
32
- n = 180
33
- azimuths = (np.arange(n) / n * np.pi * 2).astype(np.float32)
34
- elevations = np.deg2rad(np.asarray([30] * n).astype(np.float32))
35
- K, _, _, _, poses = read_pickle(f'meta_info/camera-16.pkl')
36
- output_points
37
- h, w = 256, 256
38
- default_size = 256
39
- K = np.diag([w/default_size,h/default_size,1.0]) @ K
40
- imgs = []
41
- for ni in tqdm(range(n)):
42
- # R = euler2mat(azimuths[ni], elevations[ni], 0, 'szyx')
43
- # R = np.asarray([[0,-1,0],[0,0,-1],[1,0,0]]) @ R
44
- e, a = elevations[ni], azimuths[ni]
45
- row1 = np.asarray([np.sin(e)*np.cos(a),np.sin(e)*np.sin(a),-np.cos(e)])
46
- row0 = np.asarray([-np.sin(a),np.cos(a), 0])
47
- row2 = np.cross(row0, row1)
48
- R = np.stack([row0,row1,row2],0)
49
- t = np.asarray([0,0,1.5])
50
- pose = np.concatenate([R,t[:,None]],1)
51
- pose_ = torch.from_numpy(pose.astype(np.float32)).unsqueeze(0)
52
- K_ = torch.from_numpy(K.astype(np.float32)).unsqueeze(0) # [1,3,3]
53
-
54
- coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)), -1)[:, :, (1, 0)] # h,w,2
55
- coords = coords.float()[None, :, :, :].repeat(1, 1, 1, 1) # imn,h,w,2
56
- coords = coords.reshape(1, h * w, 2)
57
- coords = torch.cat([coords, torch.ones(1, h * w, 1, dtype=torch.float32)], 2) # imn,h*w,3
58
-
59
- # imn,h*w,3 @ imn,3,3 => imn,h*w,3
60
- rays_d = coords @ torch.inverse(K_).permute(0, 2, 1)
61
- R, t = pose_[:, :, :3], pose_[:, :, 3:]
62
- rays_d = rays_d @ R
63
- rays_d = F.normalize(rays_d, dim=-1)
64
- rays_o = -R.permute(0, 2, 1) @ t # imn,3,3 @ imn,3,1
65
- rays_o = rays_o.permute(0, 2, 1).repeat(1, h * w, 1) # imn,h*w,3
66
-
67
- ray_batch = {
68
- 'rays_o': rays_o.reshape(-1,3).cuda(),
69
- 'rays_d': rays_d.reshape(-1,3).cuda(),
70
- }
71
- with torch.no_grad():
72
- image = model.renderer.render(ray_batch,False,5000)['rgb'].reshape(h,w,3)
73
- image = (image.cpu().numpy() * 255).astype(np.uint8)
74
- imgs.append(image)
75
-
76
- imageio.mimsave(f'{output}/rendering.mp4', imgs, fps=30)
77
-
78
- def extract_fields(bound_min, bound_max, resolution, query_func, batch_size=64, outside_val=1.0):
79
- N = batch_size
80
- X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
81
- Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
82
- Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
83
-
84
- u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
85
- with torch.no_grad():
86
- for xi, xs in enumerate(X):
87
- for yi, ys in enumerate(Y):
88
- for zi, zs in enumerate(Z):
89
- xx, yy, zz = torch.meshgrid(xs, ys, zs)
90
- pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).cuda()
91
- val = query_func(pts).detach()
92
- outside_mask = torch.norm(pts,dim=-1)>=1.0
93
- val[outside_mask]=outside_val
94
- val = val.reshape(len(xs), len(ys), len(zs)).cpu().numpy()
95
- u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
96
- return u
97
-
98
- def extract_geometry(bound_min, bound_max, resolution, threshold, query_func, color_func, outside_val=1.0):
99
- u = extract_fields(bound_min, bound_max, resolution, query_func, outside_val=outside_val)
100
- vertices, triangles = mcubes.marching_cubes(u, threshold)
101
- b_max_np = bound_max.detach().cpu().numpy()
102
- b_min_np = bound_min.detach().cpu().numpy()
103
-
104
- vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
105
- vertex_colors = color_func(vertices)
106
- return vertices, triangles, vertex_colors
107
-
108
- def extract_mesh(model, output, resolution=512):
109
- if not isinstance(model.renderer, NeuSRenderer): return
110
- bbox_min = -torch.ones(3)*DEFAULT_SIDE_LENGTH
111
- bbox_max = torch.ones(3)*DEFAULT_SIDE_LENGTH
112
- with torch.no_grad():
113
- vertices, triangles, vertex_colors = extract_geometry(bbox_min, bbox_max, resolution, 0, lambda x: model.renderer.sdf_network.sdf(x), lambda x: model.renderer.get_vertex_colors(x))
114
-
115
- # output geometry
116
- mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors)
117
- mesh.export(str(f'{output}/mesh.ply'))
118
-
119
- def main():
120
- parser = argparse.ArgumentParser()
121
- parser.add_argument('-i', '--image_path', type=str, required=True)
122
- parser.add_argument('-n', '--name', type=str, required=True)
123
- parser.add_argument('-b', '--base', type=str, default='configs/neus.yaml')
124
- parser.add_argument('-l', '--log', type=str, default='output/renderer')
125
- parser.add_argument('-s', '--seed', type=int, default=6033)
126
- parser.add_argument('-g', '--gpus', type=str, default='0,')
127
- parser.add_argument('-r', '--resume', action='store_true', default=False, dest='resume')
128
- parser.add_argument('--fp16', action='store_true', default=False, dest='fp16')
129
- opt = parser.parse_args()
130
- # seed_everything(opt.seed)
131
-
132
- # configs
133
- cfg = OmegaConf.load(opt.base)
134
- name = opt.name
135
- log_dir, ckpt_dir = Path(opt.log) / name, Path(opt.log) / name / 'ckpt'
136
- cfg.model.params['image_path'] = opt.image_path
137
- cfg.model.params['log_dir'] = log_dir
138
-
139
- # setup
140
- log_dir.mkdir(exist_ok=True, parents=True)
141
- ckpt_dir.mkdir(exist_ok=True, parents=True)
142
- trainer_config = cfg.trainer
143
- callback_config = cfg.callbacks
144
- model_config = cfg.model
145
- data_config = cfg.data
146
-
147
- data_config.params.seed = opt.seed
148
- data = instantiate_from_config(data_config)
149
- data.prepare_data()
150
- data.setup('fit')
151
-
152
- model = instantiate_from_config(model_config,)
153
- model.cpu()
154
- model.learning_rate = model_config.base_lr
155
-
156
- # logger
157
- logger = TensorBoardLogger(save_dir=log_dir, name='tensorboard_logs')
158
- callbacks=[]
159
- callbacks.append(LearningRateMonitor(logging_interval='step'))
160
- callbacks.append(ModelCheckpoint(dirpath=ckpt_dir, filename="{epoch:06}", verbose=True, save_last=True, every_n_train_steps=callback_config.save_interval))
161
-
162
- # trainer
163
- trainer_config.update({
164
- "accelerator": "cuda", "check_val_every_n_epoch": None,
165
- "benchmark": True, "num_sanity_val_steps": 0,
166
- "devices": 1, "gpus": opt.gpus,
167
- })
168
- if opt.fp16:
169
- trainer_config['precision']=16
170
-
171
- if opt.resume:
172
- callbacks.append(ResumeCallBacks())
173
- trainer_config['resume_from_checkpoint'] = str(ckpt_dir / 'last.ckpt')
174
- else:
175
- if (ckpt_dir / 'last.ckpt').exists():
176
- raise RuntimeError(f"checkpoint {ckpt_dir / 'last.ckpt'} existing ...")
177
- trainer = Trainer.from_argparse_args(args=argparse.Namespace(), **trainer_config, logger=logger, callbacks=callbacks)
178
-
179
- trainer.fit(model, data)
180
-
181
- model = model.cuda().eval()
182
-
183
- render_images(model, log_dir)
184
- extract_mesh(model, log_dir)
185
-
186
- if __name__=="__main__":
187
- main()
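
A short sketch of the camera rotation built inside render_images above from an elevation e and azimuth a (numpy only; the helper name pose_rotation is hypothetical). The three rows form an orthonormal world-to-camera rotation, which the assertion below checks:

import numpy as np

def pose_rotation(elevation, azimuth):
    e, a = elevation, azimuth
    row1 = np.array([np.sin(e) * np.cos(a), np.sin(e) * np.sin(a), -np.cos(e)])
    row0 = np.array([-np.sin(a), np.cos(a), 0.0])
    row2 = np.cross(row0, row1)              # completes a right-handed basis
    return np.stack([row0, row1, row2], 0)   # [3, 3] world-to-camera rotation

R = pose_rotation(np.deg2rad(30), np.deg2rad(45))
assert np.allclose(R @ R.T, np.eye(3))       # rows are orthonormal
print(R)
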
train_syncdreamer.py DELETED
@@ -1,307 +0,0 @@
1
- import argparse, os, sys
2
- import numpy as np
3
- import time
4
- import torch
5
- import torch.nn as nn
6
- import torchvision
7
- import pytorch_lightning as pl
8
-
9
- from omegaconf import OmegaConf
10
- from PIL import Image
11
-
12
- from pytorch_lightning import seed_everything
13
- from pytorch_lightning.strategies import DDPStrategy
14
- from pytorch_lightning.trainer import Trainer
15
- from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
16
- from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
17
-
18
- from ldm.util import instantiate_from_config
19
-
20
-
21
- @rank_zero_only
22
- def rank_zero_print(*args):
23
- print(*args)
24
-
25
- def get_parser(**parser_kwargs):
26
- def str2bool(v):
27
- if isinstance(v, bool):
28
- return v
29
- if v.lower() in ("yes", "true", "t", "y", "1"):
30
- return True
31
- elif v.lower() in ("no", "false", "f", "n", "0"):
32
- return False
33
- else:
34
- raise argparse.ArgumentTypeError("Boolean value expected.")
35
-
36
- parser = argparse.ArgumentParser(**parser_kwargs)
37
- parser.add_argument("-r", "--resume", dest='resume', action='store_true', default=False)
38
- parser.add_argument("-b", "--base", type=str, default='configs/syncdreamer-training.yaml',)
39
- parser.add_argument("-l", "--logdir", type=str, default="ckpt/logs", help="directory for logging data", )
40
- parser.add_argument("-c", "--ckptdir", type=str, default="ckpt/models", help="directory for checkpoint data", )
41
- parser.add_argument("-s", "--seed", type=int, default=6033, help="seed for seed_everything", )
42
- parser.add_argument("--finetune_from", type=str, default="/cfs-cq-dcc/rondyliu/models/sd-image-conditioned-v2.ckpt", help="path to checkpoint to load model state from" )
43
- parser.add_argument("--gpus", type=str, default='0,')
44
- return parser
45
-
46
- def trainer_args(opt):
47
- parser = argparse.ArgumentParser()
48
- parser = Trainer.add_argparse_args(parser)
49
- args = parser.parse_args([])
50
- return sorted(k for k in vars(args) if hasattr(opt, k))
51
-
52
- class SetupCallback(Callback):
53
- def __init__(self, resume, logdir, ckptdir, cfgdir, config):
54
- super().__init__()
55
- self.resume = resume
56
- self.logdir = logdir
57
- self.ckptdir = ckptdir
58
- self.cfgdir = cfgdir
59
- self.config = config
60
-
61
- def on_fit_start(self, trainer, pl_module):
62
- if trainer.global_rank == 0:
63
- # Create logdirs and save configs
64
- os.makedirs(self.logdir, exist_ok=True)
65
- os.makedirs(self.ckptdir, exist_ok=True)
66
- os.makedirs(self.cfgdir, exist_ok=True)
67
-
68
- rank_zero_print(OmegaConf.to_yaml(self.config))
69
- OmegaConf.save(self.config, os.path.join(self.cfgdir, "configs.yaml"))
70
-
71
- if not self.resume and os.path.exists(os.path.join(self.logdir,'checkpoints','last.ckpt')):
72
- raise RuntimeError(f"checkpoint {os.path.join(self.logdir,'checkpoints','last.ckpt')} existing")
73
-
74
- class ImageLogger(Callback):
75
- def __init__(self, batch_frequency, max_images, log_images_kwargs=None):
76
- super().__init__()
77
- self.batch_freq = batch_frequency
78
- self.max_images = max_images
79
- self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
80
-
81
- @rank_zero_only
82
- def log_to_logger(self, pl_module, images, split):
83
- for k in images:
84
- grid = torchvision.utils.make_grid(images[k])
85
- grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
86
-
87
- tag = f"{split}/{k}"
88
- pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)
89
-
90
- @rank_zero_only
91
- def log_to_file(self, save_dir, split, images, global_step, current_epoch):
92
- root = os.path.join(save_dir, "images", split)
93
- for k in images:
94
- grid = torchvision.utils.make_grid(images[k], nrow=4)
95
- grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
96
- grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
97
- grid = grid.numpy()
98
- grid = (grid * 255).astype(np.uint8)
99
- filename = "{:06}-{:06}-{}.jpg".format(global_step, current_epoch, k)
100
- path = os.path.join(root, filename)
101
- os.makedirs(os.path.split(path)[0], exist_ok=True)
102
- Image.fromarray(grid).save(path)
103
-
104
- @rank_zero_only
105
- def log_img(self, pl_module, batch, split="train"):
106
- if split == "val": should_log = True
107
- else: should_log = self.check_frequency(pl_module.global_step)
108
-
109
- if should_log:
110
- is_train = pl_module.training
111
- if is_train: pl_module.eval()
112
-
113
- with torch.no_grad():
114
- images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
115
-
116
- for k in images:
117
- N = min(images[k].shape[0], self.max_images)
118
- images[k] = images[k][:N]
119
- if isinstance(images[k], torch.Tensor):
120
- images[k] = images[k].detach().cpu()
121
- images[k] = torch.clamp(images[k], -1., 1.)
122
-
123
- self.log_to_file(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch)
124
- # self.log_to_logger(pl_module, images, split)
125
-
126
- if is_train: pl_module.train()
127
-
128
- def check_frequency(self, check_idx):
129
- if (check_idx % self.batch_freq) == 0 and check_idx > 0:
130
- return True
131
- else:
132
- return False
133
-
134
- def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
135
- self.log_img(pl_module, batch, split="train")
136
-
137
- @rank_zero_only
138
- def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
139
- # print('validation ....')
140
- # print(dataloader_idx)
141
- # print(batch_idx)
142
- if batch_idx==0: self.log_img(pl_module, batch, split="val")
143
-
144
- class CUDACallback(Callback):
145
- # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
146
- def on_train_epoch_start(self, trainer, pl_module):
147
- # Reset the memory use counter
148
- torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
149
- torch.cuda.synchronize(trainer.strategy.root_device.index)
150
- self.start_time = time.time()
151
-
152
- def on_train_epoch_end(self, trainer, pl_module):
153
- torch.cuda.synchronize(trainer.strategy.root_device.index)
154
- max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20
155
- epoch_time = time.time() - self.start_time
156
-
157
- try:
158
- max_memory = trainer.strategy.reduce(max_memory)
159
- epoch_time = trainer.strategy.reduce(epoch_time)
160
-
161
- rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
162
- rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
163
- except AttributeError:
164
- pass
165
-
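- # get_node_name: if `name` starts with `parent_name` (and is longer), return (True, remaining suffix); otherwise (False, '')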
166
- def get_node_name(name, parent_name):
167
- if len(name) <= len(parent_name):
168
- return False, ''
169
- p = name[:len(parent_name)]
170
- if p != parent_name:
171
- return False, ''
172
- return True, name[len(parent_name):]
173
-
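- # ResumeCallBacks: when resuming, point the LightningOptimizer wrapper's param_groups at the underlying optimizer's param_groups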
174
- class ResumeCallBacks(Callback):
175
- def on_train_start(self, trainer, pl_module):
176
- pl_module.optimizers().param_groups = pl_module.optimizers()._optimizer.param_groups
177
-
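- # load_pretrain_stable_diffusion: load pretrained Stable Diffusion weights into new_model (non-strict); if the first conv layer
- # expects extra condition channels, zero-init them and copy the pretrained weights into the first 4 channels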
178
- def load_pretrain_stable_diffusion(new_model, finetune_from):
179
- rank_zero_print(f"Attempting to load state from {finetune_from}")
180
- old_state = torch.load(finetune_from, map_location="cpu")
181
- if "state_dict" in old_state: old_state = old_state["state_dict"]
182
-
183
- in_filters_load = old_state["model.diffusion_model.input_blocks.0.0.weight"]
184
- new_state = new_model.state_dict()
185
- if "model.diffusion_model.input_blocks.0.0.weight" in new_state:
186
- in_filters_current = new_state["model.diffusion_model.input_blocks.0.0.weight"]
187
- in_shape = in_filters_current.shape
188
- ## the model takes additional inputs as conditions, so its first conv layer has more input channels than the pretrained one.
189
- if in_shape != in_filters_load.shape:
190
- input_keys = ["model.diffusion_model.input_blocks.0.0.weight", "model_ema.diffusion_modelinput_blocks00weight",]
191
- for input_key in input_keys:
192
- if input_key not in old_state or input_key not in new_state:
193
- continue
194
- input_weight = new_state[input_key]
195
- if input_weight.size() != old_state[input_key].size():
196
- print(f"Manual init: {input_key}")
197
- input_weight.zero_()
198
- input_weight[:, :4, :, :].copy_(old_state[input_key])
199
- old_state[input_key] = torch.nn.parameter.Parameter(input_weight)
200
-
201
- new_model.load_state_dict(old_state, strict=False)
202
-
203
- def get_optional_dict(name, config):
204
- if name in config:
205
- cfg = config[name]
206
- else:
207
- cfg = OmegaConf.create()
208
- return cfg
209
-
210
- if __name__ == "__main__":
211
- # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
212
- sys.path.append(os.getcwd())
213
- opt = get_parser().parse_args()
214
-
215
- assert opt.base != ''
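- # the experiment name is the config filename without its extension; it names the log and checkpoint directories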
216
- name = os.path.split(opt.base)[-1]
217
- name = os.path.splitext(name)[0]
218
- logdir = os.path.join(opt.logdir, name)
219
-
220
- # logdir stores configs and logs; checkpoints go to a separate ckptdir
221
- ckptdir = os.path.join(opt.ckptdir, name)
222
- cfgdir = os.path.join(logdir, "configs")
223
-
224
- if opt.resume:
225
- ckpt = os.path.join(ckptdir, "last.ckpt")
226
- opt.resume_from_checkpoint = ckpt
227
- opt.finetune_from = "" # disable finetune checkpoint
228
-
229
- seed_everything(opt.seed)
230
-
231
- ###################config#####################
232
- config = OmegaConf.load(opt.base) # load default configs
233
- lightning_config = config.lightning
234
- trainer_config = config.lightning.trainer
235
- for k in trainer_args(opt): # overwrite trainer configs
236
- trainer_config[k] = getattr(opt, k)
237
-
238
- ###################trainer#####################
239
- # training framework
240
- gpuinfo = trainer_config["gpus"]
241
- rank_zero_print(f"Running on GPUs {gpuinfo}")
242
- ngpu = len(trainer_config.gpus.strip(",").split(','))
243
- trainer_config['devices'] = ngpu
244
-
245
- ###################model#####################
246
- model = instantiate_from_config(config.model)
247
- model.cpu()
248
- # load stable diffusion parameters
249
- if opt.finetune_from != "":
250
- load_pretrain_stable_diffusion(model, opt.finetune_from)
251
-
252
- ###################logger#####################
253
- # default logger configs
254
- default_logger_cfg = {"target": "pytorch_lightning.loggers.TensorBoardLogger",
255
- "params": {"save_dir": logdir, "name": "tensorboard_logs", }}
256
- logger_cfg = OmegaConf.create(default_logger_cfg)
257
- logger = instantiate_from_config(logger_cfg)
258
-
259
- ###################callbacks#####################
260
- # default ckpt callbacks
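- # two ModelCheckpoint callbacks: the default one tracks last.ckpt, while the "repeat" one keeps a separate snapshot every 5000 steps (save_top_k=-1)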
261
- default_modelckpt_cfg = {"target": "pytorch_lightning.callbacks.ModelCheckpoint",
262
- "params": {"dirpath": ckptdir, "filename": "{epoch:06}", "verbose": True, "save_last": True, "every_n_train_steps": 5000}}
263
- modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, get_optional_dict("modelcheckpoint", lightning_config)) # overwrite checkpoint configs
264
- default_modelckpt_cfg_repeat = {"target": "pytorch_lightning.callbacks.ModelCheckpoint",
265
- "params": {"dirpath": ckptdir, "filename": "{step:08}", "verbose": True, "save_last": False, "every_n_train_steps": 5000, "save_top_k": -1}}
266
- modelckpt_cfg_repeat = OmegaConf.merge(default_modelckpt_cfg_repeat)
267
-
268
- # default callbacks: set up the log directory, monitor the learning rate, and report CUDA stats
269
- default_callbacks_cfg = {
270
- "setup_callback": {
271
- "target": "train_syncdreamer.SetupCallback",
272
- "params": {"resume": opt.resume, "logdir": logdir, "ckptdir": ckptdir, "cfgdir": cfgdir, "config": config}
273
- },
274
- "learning_rate_logger": {
275
- "target": "train_syncdreamer.LearningRateMonitor",
276
- "params": {"logging_interval": "step"}
277
- },
278
- "cuda_callback": {"target": "train_syncdreamer.CUDACallback"},
279
- }
280
- callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, get_optional_dict("callbacks", lightning_config))
281
- callbacks_cfg['model_ckpt'] = modelckpt_cfg # add checkpoint
282
- callbacks_cfg['model_ckpt_repeat'] = modelckpt_cfg_repeat # add checkpoint
283
- callbacks = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] # construct all callbacks
284
- if opt.resume:
285
- callbacks.append(ResumeCallBacks())
286
-
287
- trainer = Trainer.from_argparse_args(args=argparse.Namespace(), **trainer_config,
288
- accelerator='cuda', strategy=DDPStrategy(find_unused_parameters=False), logger=logger, callbacks=callbacks)
289
- trainer.logdir = logdir
290
-
291
- ###################data#####################
292
- config.data.params.seed = opt.seed
293
- data = instantiate_from_config(config.data)
294
- data.prepare_data()
295
- data.setup('fit')
296
-
297
- ####################lr#####################
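- # the base learning rate from the config is used directly; no scaling by GPU count, batch size, or gradient accumulation is applied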
298
- bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
299
- accumulate_grad_batches = trainer_config.accumulate_grad_batches if hasattr(trainer_config, "accumulate_grad_batches") else 1
300
- rank_zero_print(f"accumulate_grad_batches = {accumulate_grad_batches}")
301
- model.learning_rate = base_lr
302
- rank_zero_print("++++ NOT USING LR SCALING ++++")
303
- rank_zero_print(f"Setting learning rate to {model.learning_rate:.2e}")
304
- model.image_dir = logdir # directory used to save images output during training
305
-
306
- # run
307
- trainer.fit(model, data)