Pix2Pix3D / app.py
SerdarHelli's picture
Update app.py
f80422c
raw
history blame contribute delete
No virus
13.6 kB
import sys
import os
# Fetch the upstream pix2pix3D repository at startup and put it on the import
# path; the project modules imported below (dnnlib, legacy, camera_utils,
# training, torch_utils) presumably come from that checkout.
# NOTE(review): os.system's exit status is ignored, so a failed clone only
# surfaces later as an ImportError on those modules.
os.system("git clone https://github.com/dunbar12138/pix2pix3D.git")
sys.path.append("pix2pix3D")
from typing import List, Optional, Tuple, Union
import dnnlib
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm
import legacy
from camera_utils import LookAtPoseSampler
from huggingface_hub import hf_hub_download
from matplotlib import pyplot as plt
from pathlib import Path
import gradio as gr
# NOTE: color_list imported here is rebound later in this module by a local
# palette of the same name, which is the one the code below actually uses.
from training.utils import color_mask, color_list
import plotly.graph_objects as go
from tqdm import tqdm
import imageio
import trimesh
import mcubes
import copy
import pickle
# NOTE: numpy, torch and dnnlib (and tqdm above) are imported a second time;
# re-importing an already-loaded module is a harmless no-op.
import numpy as np
import torch
import dnnlib
from torch_utils import misc
# Wildcard import: names starting with an underscore are NOT brought in
# unless legacy defines __all__, so such names must be reached as legacy.<name>.
from legacy import *
import io
# NOTE(review): PyOpenGL reads PYOPENGL_PLATFORM when it is first imported;
# confirm no module above has already imported it, otherwise this setting has
# no effect.
os.environ["PYOPENGL_PLATFORM"] = "egl"
def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
    """Sample the generator's density (sigma) on a dense cubic grid.

    The cube spans ``[-bound, bound]`` on every axis, where
    ``bound = nerf.rendering_kwargs['box_warp'] * 0.5``. The grid is queried in
    sub-cubes of at most ``block_resolution`` samples per axis, so only one
    block of query points is materialised at a time.

    Args:
        nerf: Generator exposing ``rendering_kwargs['box_warp']`` and
            ``sample_mixed(points, directions, ws=..., noise_mode=...)``; the
            returned dict must contain a ``'sigma'`` tensor with one value per
            query point.
        styles: Latent ``ws`` tensor forwarded to ``sample_mixed``; its
            ``.device`` decides where the query points are placed.
        resolution: Samples per axis. Cast to ``int`` so float values coming
            from UI sliders are accepted.
        block_resolution: Maximum samples per axis in one query block.

    Returns:
        ``(sigma, bound)``: ``sigma`` is a float32 array of shape
        ``(resolution, resolution, resolution)`` indexed ``[x, y, z]``.
    """
    resolution = int(resolution)
    block_resolution = int(block_resolution)
    bound = nerf.rendering_kwargs['box_warp'] * 0.5
    axis_blocks = torch.linspace(-bound, bound, resolution).split(block_resolution)

    sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    # Inference only: no autograd graph is needed for the density queries.
    with torch.no_grad():
        for xi, xs in enumerate(axis_blocks):
            x0 = xi * block_resolution
            for yi, ys in enumerate(axis_blocks):
                y0 = yi * block_resolution
                for zi, zs in enumerate(axis_blocks):
                    z0 = zi * block_resolution
                    # 'ij' indexing keeps the grid axes in (x, y, z) order.
                    xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
                    pts = torch.stack([xx, yy, zz], dim=-1).reshape(1, -1, 3).to(styles.device)
                    out = nerf.sample_mixed(pts, None, ws=styles, noise_mode='const')
                    # The last block per axis may be shorter, hence len(...).
                    sigma_np[x0:x0 + len(xs), y0:y0 + len(ys), z0:z0 + len(zs)] = (
                        out['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
                    )
    return sigma_np, bound
def extract_geometry(nerf, styles, resolution, threshold):
    """Extract an iso-surface mesh from the generator's density field.

    Samples sigma on a ``resolution``^3 grid via ``get_sigma_field_np``, runs
    marching cubes at ``threshold`` and maps the vertices from grid-index
    space back to world coordinates in ``[-bound, bound]^3``.

    Args:
        nerf: Generator forwarded to ``get_sigma_field_np``.
        styles: Latent ``ws`` tensor forwarded to ``get_sigma_field_np``.
        resolution: Grid samples per axis. Cast to ``int`` so float values
            coming from UI sliders are accepted.
        threshold: Density level at which the surface is extracted.

    Returns:
        ``(vertices, faces)``: float32 vertex positions (last axis of size 3)
        and the face index array produced by ``mcubes.marching_cubes``.
    """
    resolution = int(resolution)
    u, bound = get_sigma_field_np(nerf, styles, resolution)
    vertices, faces = mcubes.marching_cubes(u, threshold)
    # Grid index i in [0, resolution - 1] corresponds to the linspace sample
    # -bound + i * (2 * bound) / (resolution - 1); apply that map per axis.
    b_min_np = np.array([-bound, -bound, -bound])
    b_max_np = np.array([bound, bound, bound])
    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    return vertices.astype('float32'), faces
def render_video(G, ws, intrinsics, num_frames = 120, pitch_range = 0.25, yaw_range = 0.35, neural_rendering_resolution = 128, device='cuda'):
    """Render an orbit of RGB frames and colour-coded semantic frames.

    For frame ``k`` the phase is ``2 * 3.14 * k / num_frames``. The camera
    looks at ``G.rendering_kwargs['avg_camera_pivot']`` from
    ``avg_camera_radius``, with yaw ``3.14/2 + yaw_range * sin(phase)`` and
    pitch ``3.14/2 - 0.05 + pitch_range * cos(phase)``.

    Args:
        G: Generator with ``rendering_kwargs`` and ``synthesis``.
        ws: Latent codes passed to ``G.synthesis``.
        intrinsics: Camera intrinsics with 9 elements (reshaped to ``(-1, 9)``).
        num_frames: Number of frames. Cast to ``int`` so float slider values
            are accepted.
        pitch_range: Amplitude of the pitch oscillation.
        yaw_range: Amplitude of the yaw oscillation.
        neural_rendering_resolution: Forwarded to ``G.synthesis``.
        device: Device for the camera tensors.

    Returns:
        ``(frames, frames_label)``: lists of uint8 numpy images. ``frames``
        maps ``out['image']`` (presumably channels-first in [-1, 1]) to HWC.
        ``frames_label`` holds ``color_mask`` renderings of the per-pixel
        argmax of ``out['semantic']``.
    """
    num_frames = int(num_frames)
    # Camera parameters that do not change between frames.
    pivot = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device)
    radius = G.rendering_kwargs['avg_camera_radius']
    intrinsics_flat = intrinsics.reshape(-1, 9)

    frames, frames_label = [], []
    with torch.no_grad():
        for frame_idx in tqdm(range(num_frames)):
            phase = 2 * 3.14 * frame_idx / num_frames
            cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(phase),
                                                      3.14/2 - 0.05 + pitch_range * np.cos(phase),
                                                      pivot, radius=radius, device=device)
            # Pose layout expected by G.synthesis: flattened 4x4 extrinsics + 3x3 intrinsics.
            pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics_flat], 1)
            out = G.synthesis(ws, pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
            frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
            frames_label.append(color_mask(torch.argmax(out['semantic'], dim=1).cpu().numpy()[0]).astype(np.uint8))
    return frames, frames_label
def return_plot_go(mesh_trimesh):
    """Build a Plotly figure showing ``mesh_trimesh`` with per-vertex colours.

    Args:
        mesh_trimesh: Mesh providing ``vertices``, ``faces`` and
            ``visual.vertex_colors``.

    Returns:
        ``plotly.graph_objects.Figure`` holding a single ``Mesh3d`` trace.
    """
    # Transposed views: row 0/1/2 are the x/y/z coordinates (resp. the three
    # vertex indices of every face).
    verts_t = np.asarray(mesh_trimesh.vertices).T
    faces_t = np.asarray(mesh_trimesh.faces).T
    shading = dict(
        ambient=0.5,
        diffuse=1,
        fresnel=4,
        specular=0.5,
        roughness=0.05,
        facenormalsepsilon=0,
        vertexnormalsepsilon=0,
    )
    trace = go.Mesh3d(
        x=verts_t[0], y=verts_t[1], z=verts_t[2],
        i=faces_t[0], j=faces_t[1], k=faces_t[2],
        vertexcolor=np.asarray(mesh_trimesh.visual.vertex_colors),
        lighting=shading,
        lightposition=dict(x=100, y=100, z=1000),
    )
    return go.Figure(trace)
# Fetch the seg2cat checkpoint from the Hugging Face Hub; the returned local
# path is what get_all later opens.
network_cat = hf_hub_download(
    "SerdarHelli/pix2pix3d_seg2cat",
    filename="pix2pix3d_seg2cat.pkl",
    revision="main",
)
# Model key (as offered in the UI dropdown) -> checkpoint path.
models = dict(seg2cat=network_cat)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Directory the rendered videos are written to.
outdir = "./"
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that materialises serialized torch storages on the CPU.

    The ``torch.storage._load_from_bytes`` global is intercepted and replaced
    by a loader that calls ``torch.load(..., map_location='cpu')``. This lets a
    checkpoint be read when no CUDA device is available. Every other global
    is resolved by the standard ``pickle.Unpickler`` logic.

    Security: unpickling executes code chosen by the pickle's author; only
    use this on trusted checkpoints.
    """

    def find_class(self, module, name):
        wants_storage_loader = module == 'torch.storage' and name == '_load_from_bytes'
        if not wants_storage_loader:
            return super().find_class(module, name)

        def _load_storage_on_cpu(b):
            return torch.load(io.BytesIO(b), map_location='cpu')

        return _load_storage_on_cpu
def load_network_pkl_cpu(f, force_fp16=False):
    """Load a network pickle onto the CPU; mirrors ``legacy.load_network_pkl``.

    Unpickles with ``CPU_Unpickler`` so serialized torch storages are loaded
    with ``map_location='cpu'``. A legacy TensorFlow ``(G, D, Gs)`` triple is
    converted to PyTorch modules, and missing optional fields are filled in.
    Optionally, the networks are rebuilt with FP16 settings.

    Args:
        f: Binary file object containing the pickle.
        force_fp16: If true, rebuild ``G``, ``D`` and ``G_ema`` with
            ``num_fp16_res = 4`` and ``conv_clamp = 256`` (only when that
            actually changes their ``init_kwargs``).

    Returns:
        dict with at least the keys ``'G'``, ``'D'``, ``'G_ema'``,
        ``'training_set_kwargs'`` and ``'augment_pipe'``.

    Raises:
        AssertionError: if the loaded contents have unexpected types.
    """
    # SECURITY: unpickling executes code chosen by whoever produced the file;
    # only call this on trusted checkpoints.
    data = CPU_Unpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    # `from legacy import *` does not export underscore-prefixed names, so the
    # stub class has to be reached through the module; a bare `_TFNetworkStub`
    # raises NameError as soon as a 3-tuple pickle is loaded.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, legacy._TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = legacy.convert_tf_generator(tf_G)
        D = legacy.convert_tf_discriminator(tf_D)
        G_ema = legacy.convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            # Attribute-style assignment below presumes init_kwargs (and any
            # nested synthesis_kwargs) is a dnnlib.EasyDict - TODO confirm.
            fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
            fp16_kwargs.num_fp16_res = 4
            fp16_kwargs.conv_clamp = 256
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data
# RGB palette for semantic labels: label ``idx`` is drawn as ``color_list[idx]``.
# This rebinds the ``color_list`` imported from training.utils for the rest of
# this module.
color_list = [[255, 255, 255], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]


def colormap2labelmap(color_img):
    """Convert a colour-coded mask image into a map of label indices.

    A pixel whose RGB value exactly equals ``color_list[idx]`` gets label
    ``idx``; pixels matching no palette entry keep label 0. Only the first three
    channels are compared, so RGBA images are accepted (alpha is ignored).

    Args:
        color_img: Array-like of shape ``(H, W, C)`` with ``C >= 3``.

    Returns:
        float64 array of shape ``(H, W)`` holding the label indices.
    """
    rgb = np.asarray(color_img)[:, :, :3]
    im_base = np.zeros((rgb.shape[0], rgb.shape[1]))
    for idx, color in enumerate(color_list):
        # A pixel matches only if all three channels are equal.
        im_base[np.all(rgb == np.asarray(color), axis=-1)] = idx
    return im_base
def checklabelmap(img):
    """Relabel ``img`` in place so its distinct values become 0, 1, ..., K-1.

    Relative order is kept: the smallest value becomes 0, the next one 1, and
    so on. The array is modified in place (its dtype is unchanged) and also
    returned.

    Args:
        img: Writable numpy array of label values.

    Returns:
        ``img`` itself, after relabelling.
    """
    # np.unique's inverse gives every element the rank of its value. Writing
    # all ranks in one step avoids the sequential `img[img == label] = idx`
    # remap, where a freshly written index can collide with a later label of
    # the same value (e.g. labels 0.1, 0.2, 1.0 became 0, 2, 2).
    _, ranks = np.unique(img, return_inverse=True)
    img[...] = ranks.reshape(img.shape)
    return img
# Generators loaded so far, keyed by model name, so each checkpoint is
# unpickled once per process instead of on every request.
_generator_cache = {}


def _load_generator(cfg):
    """Return the eval-mode ``G_ema`` network for model key ``cfg``, loading it on first use."""
    G = _generator_cache.get(cfg)
    if G is None:
        network = models[cfg]
        # Without CUDA, use the unpickler that remaps tensor storages to CPU.
        load_pkl = load_network_pkl_cpu if device == "cpu" else legacy.load_network_pkl
        with dnnlib.util.open_url(network) as f:
            G = load_pkl(f)['G_ema'].eval().to(device)
        _generator_cache[cfg] = G
    return G


def _mask_to_label_tensor(mask):
    """Turn the mask input (file path or array) into a ``(1, 1, H, W)`` uint8 label tensor on ``device``."""
    if isinstance(mask, str):
        # convert('RGB') drops alpha and expands palette images so the colour
        # lookup in colormap2labelmap always sees three channels; the `with`
        # closes the file handle.
        with PIL.Image.open(mask) as im:
            label_map = np.asarray(im.convert('RGB'))
    else:
        label_map = np.asarray(mask)
    label_map = colormap2labelmap(label_map)
    label_map = checklabelmap(label_map)
    label_map = np.asarray(label_map).astype(np.uint8)
    return torch.from_numpy(label_map).unsqueeze(0).unsqueeze(0).to(device)


def _semantic_vertex_colors(G, ws, vertices, truncation_psi, max_batch=10000000):
    """Return a ``(V, 3)`` uint8 array colouring each vertex by its argmax semantic class in ``color_list``."""
    samples = torch.tensor(np.array(vertices), device=device).unsqueeze(0).float()
    num_verts = samples.shape[1]
    # NOTE(review): the 6 semantic logits are read from channels 32..37 of
    # sample_mixed's 'rgb' output; confirm against the model definition.
    semantic_logits = torch.zeros((num_verts, 6), device=device)
    with tqdm(total=num_verts) as pbar, torch.no_grad():
        for head in range(0, num_verts, max_batch):
            torch.manual_seed(0)
            out = G.sample_mixed(samples[:, head:head + max_batch], None, ws, truncation_psi=truncation_psi, noise_mode='const')
            chunk = out['rgb'][0, :, 32:32 + 6]
            semantic_logits[head:head + max_batch, :] = chunk
            # Advance by the real chunk size so the bar never overshoots.
            pbar.update(chunk.shape[0])
    palette = torch.tensor(color_list, device=device)
    return palette[torch.argmax(semantic_logits, dim=-1)].cpu().numpy().astype(np.uint8)


def get_all(cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames):
    """Run the pix2pix3D pipeline for one mask and return every demo output.

    Args:
        cfg: Model key into ``models`` (the UI offers only ``"seg2cat"``).
        input: Colour-coded segmentation mask, as a file path or an array.
        truncation_psi: Forwarded to ``G.sample_mixed`` when colouring the mesh.
        mesh_resolution: Marching-cubes grid samples per axis (cast to int).
        random_seed: Seed for the latent code ``z`` (cast to int).
        fps: Frame rate of the saved videos.
        num_frames: Frames per rendered orbit video (cast to int).

    Returns:
        ``(fig_mesh, image_color, image_seg, video, video_label)``: a Plotly
        mesh figure, the frontal RGB image and colour-coded label map (uint8
        arrays), and the file paths of the colour and label videos.

    Raises:
        KeyError: if ``cfg`` is not a key of ``models``.
        ValueError: if ``cfg`` has no camera setup here (anything other than
            ``'seg2cat'`` / ``'seg2face'``).
    """
    G = _load_generator(cfg)

    if cfg == 'seg2cat' or cfg == 'seg2face':
        neural_rendering_resolution = 128
        # Initialize pose sampler: frontal view of the generator's average camera.
        forward_cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
                                                          radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647  # shapenet has higher FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    else:
        # No camera/intrinsics setup exists for other configs (e.g. 'edge2car');
        # fail fast instead of hitting a NameError on forward_pose below.
        raise ValueError(f'Invalid cfg: {cfg!r}')

    save_dir = Path(outdir)
    input_label = _mask_to_label_tensor(input)
    input_pose = forward_pose.to(device)

    # Latent code, deterministic for a given seed.
    z = torch.from_numpy(np.random.RandomState(int(random_seed)).randn(1, G.z_dim).astype('float32')).to(device)
    with torch.no_grad():
        # NOTE(review): truncation_psi is not passed to G.mapping, so ws and the
        # rendered image/video do not depend on it; it only reaches
        # sample_mixed below. Confirm whether the "Truncation PSI" slider is
        # meant to affect the output.
        ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
        out = G.synthesis(ws, input_pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)
    image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
    image_seg = color_mask(torch.argmax(out['semantic'][0], dim=0).cpu().numpy()).astype(np.uint8)

    # Mesh from the density field (threshold 50), coloured by semantic class.
    with torch.no_grad():
        mesh_trimesh = trimesh.Trimesh(*extract_geometry(G, ws, resolution=int(mesh_resolution), threshold=50.))
    mesh_trimesh.visual.vertex_colors = _semantic_vertex_colors(G, ws, mesh_trimesh.vertices, truncation_psi)

    frames, frames_label = render_video(G, ws, intrinsics, num_frames=int(num_frames), pitch_range=0.25, yaw_range=0.35,
                                        neural_rendering_resolution=neural_rendering_resolution, device=device)

    # Save the videos; Gradio receives their paths.
    video = os.path.join(save_dir, f'{cfg}_color.mp4')
    video_label = os.path.join(save_dir, f'{cfg}_label.mp4')
    imageio.mimsave(video, frames, fps=fps)
    imageio.mimsave(video_label, frames_label, fps=fps)
    fig_mesh = return_plot_go(mesh_trimesh)
    return fig_mesh, image_color, image_seg, video, video_label
# Page title and Markdown description shown above the Gradio interface.
title = "3D-aware Conditional Image Synthesis"
desc = f'''
[Arxiv: "3D-aware Conditional Image Synthesis".](https://arxiv.org/abs/2302.08509)
[Project Page.](https://www.cs.cmu.edu/~pix2pix3D/)
[For the official implementation.](https://github.com/dunbar12138/pix2pix3D)
### Future work, based on interest
- Adding new models for new types of objects
- New customization options

It is running on {device}.
Processing can take a long time, especially video generation: the run time depends on the number of frames, the mesh resolution and the device the demo is running on.
'''
# Input components, in the positional order of get_all's parameters:
# (cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames).
# Integer-valued sliders set step=1 explicitly. Without it Gradio picks a step
# from the slider range, which can be far coarser than 1 for the 0..2**16
# seed range or admit fractional values (confirm for the pinned Gradio version).
demo_inputs = [
    gr.Dropdown(choices=["seg2cat"], label="Choose Model", value="seg2cat"),
    gr.Image(type="filepath", shape=(512, 512), label="Mask"),
    gr.Slider(minimum=0, maximum=2, label='Truncation PSI', value=1),
    gr.Slider(minimum=32, maximum=512, step=1, label='Mesh Resolution', value=32),
    gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed', value=128),
    gr.Slider(minimum=10, maximum=120, step=1, label='FPS', value=30),
    gr.Slider(minimum=10, maximum=120, step=1, label='The Number of Frames', value=30),
]
# Output components, one per element of the tuple returned by get_all:
# (mesh figure, colour image, label map, colour video path, label video path).
demo_outputs = [
    gr.Plot(label="Generated Mesh"),
    gr.Image(
        type="pil",
        shape=(256, 256),
        label="Generated Image",
    ),
    gr.Image(
        type="pil",
        shape=(256, 256),
        label="Generated LabelMap",
    ),
    gr.Video(label="Generated Video "),
    gr.Video(label="Generated Label Video "),
]
# Example rows, one per bundled mask image; the columns follow demo_inputs
# (model, mask, truncation psi, mesh resolution, seed, fps, frames).
examples = [
    ["seg2cat", mask_file, 1, 32, 128, 30, 30]
    for mask_file in ("img.png", "img2.png", "img3.png")
]
# Wire the pipeline into a Gradio Interface; the example rows are cached
# (cache_examples=True), which runs get_all on each of them.
demo_app = gr.Interface(
    fn=get_all,
    inputs=demo_inputs,
    outputs=demo_outputs,
    examples=examples,
    cache_examples=True,
    title=title,
    description=desc,
    theme="huggingface",
)
# NOTE(review): later Gradio releases drop launch(enable_queue=...) in favour
# of demo_app.queue(); keep this in sync with the pinned Gradio version.
demo_app.launch(debug=True, enable_queue=True)