"""Gradio demo for pix2pix3D: label map -> image, label map, mesh and videos.

The script clones the pix2pix3D sources at start-up, downloads a pretrained
generator from the Hugging Face Hub and exposes `get_all` through a
`gr.Interface`.
"""
import sys
import os

# The pix2pix3D modules (dnnlib, legacy, camera_utils, training, torch_utils)
# are imported from a fresh checkout, so the clone and the sys.path tweak have
# to happen before those imports.
os.system("git clone https://github.com/dunbar12138/pix2pix3D.git")
sys.path.append("pix2pix3D")

from typing import List, Optional, Tuple, Union
import dnnlib
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm
import legacy
from camera_utils import LookAtPoseSampler
from huggingface_hub import hf_hub_download
from matplotlib import pyplot as plt
from pathlib import Path
import gradio as gr
from training.utils import color_mask, color_list
import plotly.graph_objects as go
import imageio
import trimesh
import mcubes
import copy
import pickle
from torch_utils import misc
from legacy import *
import io

os.environ["PYOPENGL_PLATFORM"] = "egl"


def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
    """Evaluate the generator's sigma field on a regular 3D grid.

    The cube [-bound, bound]^3, with ``bound`` equal to half of
    ``nerf.rendering_kwargs['box_warp']``, is sampled at ``resolution`` points
    per axis. The grid is processed in sub-blocks of at most
    ``block_resolution`` points per axis so that each ``sample_mixed`` call
    stays small.

    Args:
        nerf: Generator exposing ``rendering_kwargs`` and ``sample_mixed``.
        styles: Latent codes forwarded as ``ws``; the sample points are moved
            to ``styles.device``.
        resolution: Number of grid points along each axis.
        block_resolution: Maximum number of grid points per axis in one block.

    Returns:
        Tuple ``(sigma_np, bound)`` where ``sigma_np`` is a float32 array of
        shape ``(resolution, resolution, resolution)``.
    """
    bound = nerf.rendering_kwargs['box_warp'] * 0.5
    axis_blocks = torch.linspace(-bound, bound, resolution).split(block_resolution)
    sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)

    # Only forward values are needed; disabling autograd avoids building a
    # graph for every block.
    with torch.no_grad():
        for xi, xs in enumerate(axis_blocks):
            x0 = xi * block_resolution
            for yi, ys in enumerate(axis_blocks):
                y0 = yi * block_resolution
                for zi, zs in enumerate(axis_blocks):
                    z0 = zi * block_resolution
                    xx, yy, zz = torch.meshgrid(xs, ys, zs)
                    pts = torch.stack([xx, yy, zz], dim=-1).unsqueeze(0).to(styles.device)
                    out = nerf.sample_mixed(pts.reshape(1, -1, 3), None, ws=styles, noise_mode='const')
                    block = out['sigma'].reshape(len(xs), len(ys), len(zs))
                    sigma_np[x0:x0 + len(xs),
                             y0:y0 + len(ys),
                             z0:z0 + len(zs)] = block.detach().cpu().numpy()
    return sigma_np, bound


def extract_geometry(nerf, styles, resolution, threshold):
    """Run marching cubes on the sigma field and return a mesh.

    Args:
        nerf: Generator passed through to `get_sigma_field_np`.
        styles: Latent codes passed through to `get_sigma_field_np`.
        resolution: Grid resolution per axis.
        threshold: Iso-level handed to ``mcubes.marching_cubes``.

    Returns:
        Tuple ``(vertices, faces)``; vertices are float32 and rescaled from
        grid indices to the [-bound, bound] cube.
    """
    u, bound = get_sigma_field_np(nerf, styles, resolution)
    vertices, faces = mcubes.marching_cubes(u, threshold)
    b_min_np = np.array([-bound, -bound, -bound])
    b_max_np = np.array([bound, bound, bound])

    # Map grid-index coordinates (0 .. resolution - 1) to world coordinates.
    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    return vertices.astype('float32'), faces


def render_video(G, ws, intrinsics, num_frames=120, pitch_range=0.25, yaw_range=0.35,
                 neural_rendering_resolution=128, device='cuda'):
    """Render ``num_frames`` views along a periodic yaw/pitch camera path.

    Args:
        G: Generator exposing ``rendering_kwargs`` and ``synthesis``.
        ws: Latent codes passed to ``G.synthesis``.
        intrinsics: Camera intrinsics tensor; reshaped to ``(-1, 9)``.
        num_frames: Number of frames to render.
        pitch_range: Amplitude of the pitch oscillation.
        yaw_range: Amplitude of the yaw oscillation.
        neural_rendering_resolution: Forwarded to ``G.synthesis``.
        device: Device used for the sampled camera poses.

    Returns:
        Tuple ``(frames, frames_label)`` of lists of uint8 arrays: colour
        frames and colourised label frames.
    """
    frames, frames_label = [], []
    for frame_idx in tqdm(range(num_frames)):
        # 3.14 is kept (rather than pi) so the camera path matches the
        # original script exactly.
        phase = 2 * 3.14 * frame_idx / num_frames
        cam2world_pose = LookAtPoseSampler.sample(
            3.14 / 2 + yaw_range * np.sin(phase),
            3.14 / 2 - 0.05 + pitch_range * np.cos(phase),
            torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
            radius=G.rendering_kwargs['avg_camera_radius'],
            device=device,
        )
        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        with torch.no_grad():
            out = G.synthesis(ws, pose, noise_mode='const',
                              neural_rendering_resolution=neural_rendering_resolution)
        # Map the image from [-1, 1] to [0, 255] and move channels last.
        frames.append(
            ((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
        )
        frames_label.append(
            color_mask(torch.argmax(out['semantic'], dim=1).cpu().numpy()[0]).astype(np.uint8)
        )
    return frames, frames_label


def return_plot_go(mesh_trimesh):
    """Build a Plotly ``Mesh3d`` figure from a mesh with vertex colours.

    Args:
        mesh_trimesh: Object exposing ``vertices``, ``faces`` and
            ``visual.vertex_colors`` (a ``trimesh.Trimesh`` at the call site).

    Returns:
        A ``plotly.graph_objects.Figure`` containing the mesh.
    """
    x, y, z = np.asarray(mesh_trimesh.vertices).T[:3]
    i, j, k = np.asarray(mesh_trimesh.faces).T[:3]
    fig = go.Figure(go.Mesh3d(
        x=x, y=y, z=z,
        i=i, j=j, k=k,
        vertexcolor=np.asarray(mesh_trimesh.visual.vertex_colors),
        lighting=dict(ambient=0.5, diffuse=1, fresnel=4, specular=0.5, roughness=0.05,
                      facenormalsepsilon=0, vertexnormalsepsilon=0),
        lightposition=dict(x=100, y=100, z=1000),
    ))
    return fig


network_cat = hf_hub_download("SerdarHelli/pix2pix3d_seg2cat", filename="pix2pix3d_seg2cat.pkl", revision="main")
# Maps a cfg name to the local path of its downloaded network pickle.
models = {"seg2cat": network_cat}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
outdir = "./"


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that forces torch storages onto the CPU.

    NOTE(security): like any pickle loader this executes code embedded in the
    pickle; only use it on trusted checkpoints.
    """

    def find_class(self, module, name):
        # Storages are restored through torch.storage._load_from_bytes;
        # substitute a loader that maps them to the CPU.
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)


def load_network_pkl_cpu(f, force_fp16=False):
    """Load a network pickle with all tensors mapped to the CPU.

    Args:
        f: Binary file object containing the pickle.
        force_fp16: When true, rebuild ``G``, ``D`` and ``G_ema`` with
            ``num_fp16_res = 4`` and ``conv_clamp = 256`` and copy the
            parameters over.

    Returns:
        Dict with at least the keys ``G``, ``D``, ``G_ema``,
        ``training_set_kwargs`` and ``augment_pipe``.
    """
    data = CPU_Unpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    # `_TFNetworkStub` is private, so `from legacy import *` does not bind it;
    # it has to be reached through the module.
    if isinstance(data, tuple) and len(data) == 3 and all(
            isinstance(net, legacy._TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
            fp16_kwargs.num_fp16_res = 4
            fp16_kwargs.conv_clamp = 256
            # Rebuild only when the settings actually changed.
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data


# RGB colour for each label index. This intentionally replaces the
# `color_list` imported from training.utils above.
color_list = [
    [255, 255, 255], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255],
    [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0],
    [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153],
    [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0],
]


def colormap2labelmap(color_img):
    """Convert a colour-coded label image into a 2D map of label indices.

    A pixel whose RGB value equals ``color_list[idx]`` receives ``idx``;
    pixels matching no entry stay 0. Channels beyond the first three (for
    example alpha) are ignored.

    Args:
        color_img: Array of shape ``(H, W, C)`` with ``C >= 3``.

    Returns:
        Float array of shape ``(H, W)`` holding label indices.

    Raises:
        ValueError: If ``color_img`` is not a 3D array with at least three
            channels.
    """
    color_img = np.asarray(color_img)
    if color_img.ndim != 3 or color_img.shape[2] < 3:
        raise ValueError(
            f"expected an (H, W, 3+) colour image, got shape {color_img.shape}"
        )
    rgb = color_img[:, :, :3]
    im_base = np.zeros((rgb.shape[0], rgb.shape[1]))
    for idx, color in enumerate(color_list):
        im_base[np.all(rgb == np.asarray(color), axis=-1)] = idx
    return im_base


def checklabelmap(img):
    """Renumber the labels of ``img`` in place to 0..K-1 and return it.

    The distinct values are replaced by their rank in sorted order. Because a
    rank never exceeds the value it replaces, earlier replacements cannot
    collide with values that are still to be processed.

    NOTE(review): when a class is absent from the input, later labels shift
    down; confirm this compaction is what the model expects.
    """
    labels = np.unique(img)
    for idx, label in enumerate(labels):
        img[img == label] = idx
    return img


def _load_generator(cfg):
    """Load the ``G_ema`` generator for ``cfg`` onto the global ``device``."""
    network = models[cfg]
    with dnnlib.util.open_url(network) as f:
        # On CPU-only hosts the storages must be remapped while unpickling.
        if device == "cpu":
            data = load_network_pkl_cpu(f)
        else:
            data = legacy.load_network_pkl(f)
    return data['G_ema'].eval().to(device)


def get_all(cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames):
    """Run the full demo pipeline for one colour-coded label map.

    Args:
        cfg: Model name; must be a key of ``models``.
        input: Path to an image file, or an image/array convertible with
            ``np.asarray``, whose colours follow ``color_list``.
        truncation_psi: Forwarded to ``G.sample_mixed`` when colouring the
            mesh vertices.
        mesh_resolution: Grid resolution per axis for mesh extraction.
        random_seed: Seed for the latent code ``z``.
        fps: Frame rate of the saved videos.
        num_frames: Number of frames rendered per video.

    Returns:
        Tuple ``(fig_mesh, image_color, image_seg, video, video_label)``: the
        Plotly mesh figure, the generated image and label map as uint8
        arrays, and the paths of the two saved videos.

    Raises:
        ValueError: If ``cfg`` is unknown or has no camera setup here.
    """
    if cfg not in models:
        raise ValueError(f'Invalid cfg: {cfg!r}')

    # UI sliders may deliver floats; these values are used as sizes/counts.
    mesh_resolution = int(mesh_resolution)
    num_frames = int(num_frames)

    G = _load_generator(cfg)

    if cfg == 'seg2cat' or cfg == 'seg2face':
        neural_rendering_resolution = 128
        # Initialize pose sampler.
        forward_cam2world_pose = LookAtPoseSampler.sample(
            3.14 / 2, 3.14 / 2,
            torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
            radius=G.rendering_kwargs['avg_camera_radius'],
            device=device,
        )
        focal_length = 4.2647  # shapenet has higher FOV
        intrinsics = torch.tensor(
            [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device
        )
        forward_pose = torch.cat(
            [forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1
        )
    elif cfg == 'edge2car':
        # No pose or intrinsics are defined for this cfg, so the code below
        # could not run; fail with a clear message instead of a NameError.
        raise ValueError("cfg 'edge2car' is not supported: no camera pose is configured for it")
    else:
        raise ValueError(f'Invalid cfg: {cfg!r}')

    if isinstance(input, str):
        # Normalise palette / RGBA / grayscale files to plain RGB so the
        # colour lookup sees three channels, and close the file handle.
        with PIL.Image.open(input) as img:
            input_label = np.asarray(img.convert("RGB"))
    else:
        input_label = np.asarray(input)

    input_label = colormap2labelmap(input_label)
    input_label = checklabelmap(input_label)
    input_label = np.asarray(input_label).astype(np.uint8)
    input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)
    input_pose = forward_pose.to(device)

    # Generate the front-view image and label map.
    z = torch.from_numpy(
        np.random.RandomState(int(random_seed)).randn(1, G.z_dim).astype('float32')
    ).to(device)
    with torch.no_grad():
        ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
        out = G.synthesis(ws, input_pose, noise_mode='const',
                          neural_rendering_resolution=neural_rendering_resolution)

    image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
    image_seg = color_mask(torch.argmax(out['semantic'][0], dim=0).cpu().numpy()).astype(np.uint8)

    # Extract the mesh and colour each vertex by its predicted label.
    mesh_trimesh = trimesh.Trimesh(*extract_geometry(G, ws, resolution=mesh_resolution, threshold=50.))
    verts_np = np.array(mesh_trimesh.vertices)
    semantic_colors = torch.zeros((verts_np.shape[0], 6), device=device)
    samples_color = torch.tensor(verts_np, device=device).unsqueeze(0).float()

    head = 0
    max_batch = 10000000
    with tqdm(total=verts_np.shape[0]) as pbar:
        with torch.no_grad():
            while head < verts_np.shape[0]:
                torch.manual_seed(0)
                out = G.sample_mixed(samples_color[:, head:head + max_batch], None, ws,
                                     truncation_psi=truncation_psi, noise_mode='const')
                # Feature channels 32..37 are read as the per-label scores.
                semantic_colors[head:head + max_batch, :] = out['rgb'][0, :, 32:32 + 6]
                head += max_batch
                pbar.update(max_batch)

    semantic_colors = torch.tensor(color_list, device=device)[torch.argmax(semantic_colors, dim=-1)]
    mesh_trimesh.visual.vertex_colors = semantic_colors.cpu().numpy().astype(np.uint8)

    frames, frames_label = render_video(
        G, ws, intrinsics,
        num_frames=num_frames, pitch_range=0.25, yaw_range=0.35,
        neural_rendering_resolution=neural_rendering_resolution, device=device,
    )

    # Save the videos.
    video = os.path.join(Path(outdir), f'{cfg}_color.mp4')
    video_label = os.path.join(Path(outdir), f'{cfg}_label.mp4')
    imageio.mimsave(video, frames, fps=fps)
    imageio.mimsave(video_label, frames_label, fps=fps)

    fig_mesh = return_plot_go(mesh_trimesh)
    return fig_mesh, image_color, image_seg, video, video_label


title = "3D-aware Conditional Image Synthesis"
desc = f'''
[Arxiv: "3D-aware Conditional Image Synthesis".](https://arxiv.org/abs/2302.08509)

[Project Page.](https://www.cs.cmu.edu/~pix2pix3D/)

[For the official implementation.](https://github.com/dunbar12138/pix2pix3D)

### Future Work based on interest
- Adding new models for new type objects
- New Customization

It is running on {device}

The process can take long time.Especially ,To generate videos and the time of process depends the number of frames,Mesh Resolution and current compiler device.
'''

demo_inputs = [
    gr.Dropdown(choices=["seg2cat"], label="Choose Model", value="seg2cat"),
    gr.Image(type="filepath", shape=(512, 512), label="Mask"),
    gr.Slider(minimum=0, maximum=2, label='Truncation PSI', value=1),
    gr.Slider(minimum=32, maximum=512, label='Mesh Resolution', value=32),
    gr.Slider(minimum=0, maximum=2**16, label='Seed', value=128),
    gr.Slider(minimum=10, maximum=120, label='FPS', value=30),
    gr.Slider(minimum=10, maximum=120, label='The Number of Frames', value=30),
]

demo_outputs = [
    gr.Plot(label="Generated Mesh"),
    gr.Image(type="pil", shape=(256, 256), label="Generated Image"),
    gr.Image(type="pil", shape=(256, 256), label="Generated LabelMap"),
    gr.Video(label="Generated Video "),
    gr.Video(label="Generated Label Video "),
]

examples = [
    ["seg2cat", "img.png", 1, 32, 128, 30, 30],
    ["seg2cat", "img2.png", 1, 32, 128, 30, 30],
    ["seg2cat", "img3.png", 1, 32, 128, 30, 30],
]

demo_app = gr.Interface(
    fn=get_all,
    inputs=demo_inputs,
    outputs=demo_outputs,
    cache_examples=True,
    title=title,
    theme="huggingface",
    description=desc,
    examples=examples,
)
demo_app.launch(debug=True, enable_queue=True)