SerdarHelli committed on
Commit
b440279
1 Parent(s): b1809a9

Upload 2 files

Files changed (2)
  1. app.py +252 -0
  2. requirements.txt +19 -0
app.py ADDED
@@ -0,0 +1,252 @@
+ import sys
+ import os
+ import re
+ from typing import List, Optional, Tuple, Union
+
+ import click
+ import dnnlib
+ import numpy as np
+ import PIL.Image
+ import PIL.ImageOps
+ import torch
+ from tqdm import tqdm
+
+ import legacy
+ from camera_utils import LookAtPoseSampler
+ from huggingface_hub import hf_hub_download
+
+ from matplotlib import pyplot as plt
+
+ from pathlib import Path
+
+ import json
+ import gradio as gr
+
+ from training.utils import color_mask, color_list
+ import plotly.graph_objects as go
+
+ import imageio
+
+ import argparse
+
+ import trimesh
+ import pyrender
+ import mcubes
+
+ # Render off-screen with EGL so pyrender works on headless machines.
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
+
+
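+ # get_sigma_field_np evaluates the generator's density (sigma) on a regular
+ # resolution^3 grid inside the 'box_warp' cube, in block_resolution-sized
+ # chunks to keep GPU memory bounded.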
+ def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
+     # Return a numpy array of the forwarded sigma values.
+     # bound = (nerf.rendering_kwargs['ray_end'] - nerf.rendering_kwargs['ray_start']) * 0.5
+     bound = nerf.rendering_kwargs['box_warp'] * 0.5
+     X = torch.linspace(-bound, bound, resolution).split(block_resolution)
+
+     sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)
+
+     for xi, xs in enumerate(X):
+         for yi, ys in enumerate(X):
+             for zi, zs in enumerate(X):
+                 xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
+                 pts = torch.stack([xx, yy, zz], dim=-1).unsqueeze(0).to(styles.device)  # B, H, H, H, C
+                 block_shape = [1, len(xs), len(ys), len(zs)]
+                 out = nerf.sample_mixed(pts.reshape(1, -1, 3), None, ws=styles, noise_mode='const')
+                 feat_out, sigma_out = out['rgb'], out['sigma']
+                 sigma_np[xi * block_resolution: xi * block_resolution + len(xs),
+                          yi * block_resolution: yi * block_resolution + len(ys),
+                          zi * block_resolution: zi * block_resolution + len(zs)] = sigma_out.reshape(block_shape[1:]).detach().cpu().numpy()
+
+     return sigma_np, bound
+
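+ # extract_geometry runs marching cubes (PyMCubes) on the sampled density grid
+ # and rescales the vertices from grid indices back to world coordinates.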
+
+ def extract_geometry(nerf, styles, resolution, threshold):
+     u, bound = get_sigma_field_np(nerf, styles, resolution)
+     vertices, faces = mcubes.marching_cubes(u, threshold)
+     b_min_np = np.array([-bound, -bound, -bound])
+     b_max_np = np.array([bound, bound, bound])
+
+     # Map vertex coordinates from [0, resolution - 1] grid indices into the world box.
+     vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
+     return vertices.astype('float32'), faces
+
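+ # render_video orbits the camera around the subject with LookAtPoseSampler,
+ # varying yaw and pitch sinusoidally over the frame index, and collects RGB
+ # frames plus colorized semantic-label frames.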
+ def render_video(G, ws, intrinsics, num_frames=120, pitch_range=0.25, yaw_range=0.35, neural_rendering_resolution=128, device='cuda'):
+     frames, frames_label = [], []
+
+     for frame_idx in tqdm(range(num_frames)):
+         cam2world_pose = LookAtPoseSampler.sample(3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
+                                                   3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
+                                                   torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
+                                                   radius=G.rendering_kwargs['avg_camera_radius'], device=device)
+         pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+         with torch.no_grad():
+             out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
+         frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
+         frames_label.append(color_mask(torch.argmax(out['semantic'], dim=1).cpu().numpy()[0]).astype(np.uint8))
+
+     return frames, frames_label
+
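+ # return_plot_go converts the colored trimesh into an interactive Plotly
+ # Mesh3d figure for display in the Gradio app.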
+ def return_plot_go(mesh_trimesh):
+     x = np.asarray(mesh_trimesh.vertices).T[0]
+     y = np.asarray(mesh_trimesh.vertices).T[1]
+     z = np.asarray(mesh_trimesh.vertices).T[2]
+
+     i = np.asarray(mesh_trimesh.faces).T[0]
+     j = np.asarray(mesh_trimesh.faces).T[1]
+     k = np.asarray(mesh_trimesh.faces).T[2]
+     fig = go.Figure(go.Mesh3d(x=x, y=y, z=z,
+                               i=i, j=j, k=k,
+                               vertexcolor=np.asarray(mesh_trimesh.visual.vertex_colors),
+                               lighting=dict(ambient=0.5,
+                                             diffuse=1,
+                                             fresnel=4,
+                                             specular=0.5,
+                                             roughness=0.05,
+                                             facenormalsepsilon=0,
+                                             vertexnormalsepsilon=0),
+                               lightposition=dict(x=100,
+                                                  y=100,
+                                                  z=1000)))
+     return fig
+
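+ # Download the pretrained seg2cat checkpoint from the Hugging Face Hub once at
+ # startup; additional models can be registered in the dict below.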
+
+ network_cat = hf_hub_download("SerdarHelli/pix2pix3d_seg2cat", filename="pix2pix3d_seg2cat.pkl", revision="main")
+
+ models = {"seg2cat": network_cat}
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ outdir = "/content/"
+
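+ # get_all is the main inference pipeline: load the generator, encode the input
+ # label map, synthesize the image and semantic map, extract and color a mesh,
+ # and render an orbiting video.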
+ def get_all(cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames):
+
+     network = models[cfg]
+
+     with dnnlib.util.open_url(network) as f:
+         G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)
+
+     if cfg == 'seg2cat' or cfg == 'seg2face':
+         neural_rendering_resolution = 128
+         data_type = 'seg'
+         # Initialize pose sampler.
+         forward_cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2,
+                                                           torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
+                                                           radius=G.rendering_kwargs['avg_camera_radius'], device=device)
+         focal_length = 4.2647  # shapenet has higher FOV
+         intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+         forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+     elif cfg == 'edge2car':
+         neural_rendering_resolution = 64
+         data_type = 'edge'
+     else:
+         raise ValueError(f'Invalid cfg: {cfg}')
+
+     save_dir = Path(outdir)
+
+     input_label = PIL.Image.open(input)
+     input_label = PIL.ImageOps.grayscale(input_label)
+     input_label = np.asarray(input_label).astype(np.uint8)
+     input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)
+     input_pose = forward_pose.to(device)
+
+     # Sample a latent code from the user-provided seed and generate the image.
+     z = torch.from_numpy(np.random.RandomState(int(random_seed)).randn(1, G.z_dim).astype('float32')).to(device)
+
+     with torch.no_grad():
+         ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
+         out = G.synthesis(ws, input_pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
+
+     image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
+     image_seg = color_mask(torch.argmax(out['semantic'][0], dim=0).cpu().numpy()).astype(np.uint8)
+     mesh_trimesh = trimesh.Trimesh(*extract_geometry(G, ws, resolution=mesh_resolution, threshold=50.))
+
+     verts_np = np.array(mesh_trimesh.vertices)
+     colors = torch.zeros((verts_np.shape[0], 3), device=device)
+     semantic_colors = torch.zeros((verts_np.shape[0], 6), device=device)
+     samples_color = torch.tensor(verts_np, device=device).unsqueeze(0).float()
+
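+     # Color mesh vertices by querying the generator at each vertex location in
+     # batches; the six semantic channels live at feature indices 32:38 (see the
+     # slice below), and their argmax picks a color from color_list.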
+     head = 0
+     max_batch = 10000000
+     with tqdm(total=verts_np.shape[0]) as pbar:
+         with torch.no_grad():
+             while head < verts_np.shape[0]:
+                 torch.manual_seed(0)
+                 out = G.sample_mixed(samples_color[:, head:head + max_batch], None, ws, truncation_psi=truncation_psi, noise_mode='const')
+                 colors[head:head + max_batch, :] = out['rgb'][0, :, :3]
+                 seg = out['rgb'][0, :, 32:32 + 6]
+                 semantic_colors[head:head + max_batch, :] = seg
+                 head += max_batch
+                 pbar.update(max_batch)
+
+     semantic_colors = torch.tensor(color_list, device=device)[torch.argmax(semantic_colors, dim=-1)]
+
+     mesh_trimesh.visual.vertex_colors = semantic_colors.cpu().numpy().astype(np.uint8)
+     frames, frames_label = render_video(G, ws, intrinsics, num_frames=num_frames, pitch_range=0.25, yaw_range=0.35, neural_rendering_resolution=neural_rendering_resolution, device=device)
+
+     # Save the videos.
+     video = save_dir / f'{cfg}_color.mp4'
+     video_label = save_dir / f'{cfg}_label.mp4'
+     imageio.mimsave(video, frames, fps=fps)
+     imageio.mimsave(video_label, frames_label, fps=fps)
+     fig_mesh = return_plot_go(mesh_trimesh)
+     return fig_mesh, image_color, image_seg, video, video_label
+
+ markdown = f'''
+ # 3D-aware Conditional Image Synthesis
+
+ [Paper: "3D-aware Conditional Image Synthesis" (arXiv)](https://arxiv.org/abs/2302.08509)
+ [Project page](https://www.cs.cmu.edu/~pix2pix3D/)
+ [Official implementation](https://github.com/dunbar12138/pix2pix3D)
+
+ ### Future work, based on interest
+ - Adding new models for new object types
+ - New customization options
+
+ This demo is running on {device}.
+ Generation can take a long time, especially for videos; the duration depends on the number of frames, the mesh resolution, and the current device.
+ '''
+
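+ # Gradio UI: segmentation-map input and parameter sliders on top, then the
+ # generated image, label map, mesh plot, and orbit videos below.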
+
+ with gr.Blocks() as demo:
+     gr.Markdown(markdown)
+     with gr.Row():
+         with gr.Column():
+             input = gr.Image(type="filepath", shape=(512, 512))
+         with gr.Column():
+             cfg = gr.Dropdown(choices=["seg2cat"], label="Choose Model", value="seg2cat")
+             truncation_psi = gr.Slider(minimum=0, maximum=2, label='Truncation PSI', value=1)
+             mesh_resolution = gr.Slider(minimum=32, maximum=512, label='Mesh Resolution', value=32)
+             random_seed = gr.Slider(minimum=0, maximum=2**16, label='Seed', value=128)
+             fps = gr.Slider(minimum=10, maximum=120, label='FPS', value=30)
+             num_frames = gr.Slider(minimum=10, maximum=120, label='Number of Frames', value=30)
+
+     with gr.Row():
+         btn = gr.Button(value="Generate")
+
+     with gr.Row():
+         with gr.Column():
+             image_color = gr.Image(type="pil", shape=(256, 256))
+         with gr.Column():
+             image_label = gr.Image(type="pil", shape=(256, 256))
+     with gr.Row():
+         mesh = gr.Plot()
+     with gr.Row():
+         with gr.Column():
+             video_color = gr.Video()
+         with gr.Column():
+             video_label = gr.Video()
+
+     btn.click(get_all, [cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames], [mesh, image_color, image_label, video_color, video_label])
+
+ demo.launch(debug=True, share=True)
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ torch
+ trimesh
+ pyrender
+ PyMCubes
+ pycollada
+ einops
+ ninja
+ imageio-ffmpeg
+ imgui==1.3.0
+ glfw==2.2.0
+ pyopengl==3.1.5
+ pyspng
+ psutil
+ mrcfile
+ opencv-python
+ tqdm
+ scipy
+ pillow
+ numpy
+ # Also imported by app.py and not pulled in transitively:
+ imageio
+ matplotlib
+ plotly
+ gradio
+ huggingface_hub