import copy
import io
import os
import pickle
import subprocess
import sys
from pathlib import Path
from typing import List, Optional, Tuple, Union

# Request the EGL (headless) OpenGL backend. This is set before the heavy
# imports, presumably so that any OpenGL-based renderer picks it up.
os.environ["PYOPENGL_PLATFORM"] = "egl"

# The pix2pix3D sources (dnnlib, legacy, camera_utils, training, torch_utils)
# are not an installable package. Clone the official repository once and put
# it on sys.path. Skipping the clone when the folder already exists avoids a
# failing `git clone` on every restart.
PIX2PIX3D_REPO = "https://github.com/dunbar12138/pix2pix3D.git"
if not os.path.isdir("pix2pix3D"):
    subprocess.run(["git", "clone", PIX2PIX3D_REPO], check=True)
sys.path.append("pix2pix3D")

import gradio as gr
import imageio
import mcubes
import numpy as np
import PIL.Image
import plotly.graph_objects as go
import torch
import trimesh
from huggingface_hub import hf_hub_download
from matplotlib import pyplot as plt
from tqdm import tqdm

import dnnlib
import legacy
from camera_utils import LookAtPoseSampler
from torch_utils import misc
from training.utils import color_mask, color_list
from legacy import *


def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
    """Sample the generator's density (sigma) on a dense cubic grid.

    The cube spans [-bound, bound]^3 with ``bound = box_warp / 2``. It is
    queried in chunks of at most ``block_resolution`` points per axis, which
    limits how many points go through ``nerf.sample_mixed`` at once.

    Args:
        nerf: generator exposing ``rendering_kwargs['box_warp']`` and
            ``sample_mixed(coords, directions, ws=..., noise_mode=...)``.
        styles: latent codes passed to ``sample_mixed`` as ``ws``. Their
            ``.device`` decides where the grid points are placed.
        resolution: number of samples per axis.
        block_resolution: maximum chunk length per axis.

    Returns:
        ``(sigma, bound)``: a float32 array of shape
        ``(resolution, resolution, resolution)`` and the half-extent of the
        sampled cube.
    """
    bound = nerf.rendering_kwargs['box_warp'] * 0.5
    axis_chunks = torch.linspace(-bound, bound, resolution).split(block_resolution)
    sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)

    # Only forward values are needed. no_grad stops autograd from keeping
    # activations alive for every chunk.
    with torch.no_grad():
        for xi, xs in enumerate(axis_chunks):
            for yi, ys in enumerate(axis_chunks):
                for zi, zs in enumerate(axis_chunks):
                    xx, yy, zz = torch.meshgrid(xs, ys, zs)
                    pts = torch.stack([xx, yy, zz], dim=-1).reshape(1, -1, 3).to(styles.device)
                    out = nerf.sample_mixed(pts, None, ws=styles, noise_mode='const')
                    x0 = xi * block_resolution
                    y0 = yi * block_resolution
                    z0 = zi * block_resolution
                    sigma_np[x0:x0 + len(xs), y0:y0 + len(ys), z0:z0 + len(zs)] = (
                        out['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
                    )
    return sigma_np, bound


def extract_geometry(nerf, styles, resolution, threshold):
    """Extract an iso-surface mesh of the generator's density field.

    Args:
        nerf: generator, as accepted by ``get_sigma_field_np``.
        styles: latent codes (``ws``) for the generator.
        resolution: samples per axis of the density grid.
        threshold: iso-level passed to ``mcubes.marching_cubes``.

    Returns:
        ``(vertices, faces)``. The vertices are float32 and lie in world
        coordinates inside [-bound, bound]^3.
    """
    sigma, bound = get_sigma_field_np(nerf, styles, resolution)
    vertices, faces = mcubes.marching_cubes(sigma, threshold)
    # Marching cubes yields vertices in grid-index units [0, resolution - 1].
    # Map them linearly back onto the sampled cube [-bound, bound]^3.
    b_min = np.array([-bound, -bound, -bound])
    b_max = np.array([bound, bound, bound])
    vertices = vertices / (resolution - 1.0) * (b_max - b_min)[None, :] + b_min[None, :]
    return vertices.astype('float32'), faces


def render_video(G, ws, intrinsics, num_frames=120, pitch_range=0.25, yaw_range=0.35,
                 neural_rendering_resolution=128, device='cuda'):
    """Render an orbit of RGB frames and colorized semantic frames.

    The camera sways in yaw (sine) and pitch (cosine) over one period of
    ``num_frames`` frames. It looks at the generator's average pivot from its
    average radius.

    Args:
        G: generator exposing ``rendering_kwargs`` and ``synthesis``.
        ws: latent codes for ``G.synthesis``.
        intrinsics: 3x3 camera intrinsics tensor (flattened to 9 values).
        num_frames: number of frames to render.
        pitch_range: amplitude of the pitch oscillation, in radians.
        yaw_range: amplitude of the yaw oscillation, in radians.
        neural_rendering_resolution: forwarded to ``G.synthesis``.
        device: torch device for the camera tensors.

    Returns:
        ``(frames, frames_label)``. ``frames`` holds uint8 H x W x C images.
        ``frames_label`` holds the ``color_mask`` rendering of each frame's
        argmax semantic map, as uint8.
    """
    frames, frames_label = [], []
    for frame_idx in tqdm(range(num_frames)):
        # 3.14 (not math.pi) is kept so camera paths match the original script.
        phase = 2 * 3.14 * frame_idx / num_frames
        cam2world_pose = LookAtPoseSampler.sample(
            3.14 / 2 + yaw_range * np.sin(phase),
            3.14 / 2 - 0.05 + pitch_range * np.cos(phase),
            torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
            radius=G.rendering_kwargs['avg_camera_radius'],
            device=device,
        )
        # Camera conditioning is the flattened 4x4 extrinsics followed by the
        # flattened 3x3 intrinsics (16 + 9 values per row).
        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        with torch.no_grad():
            out = G.synthesis(ws, pose, noise_mode='const',
                              neural_rendering_resolution=neural_rendering_resolution)
        # Image values are mapped from [-1, 1] to [0, 255]; CHW -> HWC.
        frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5)
                      .clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
        frames_label.append(color_mask(torch.argmax(out['semantic'], dim=1).cpu().numpy()[0])
                            .astype(np.uint8))
    return frames, frames_label


def return_plot_go(mesh_trimesh):
    """Build a Plotly ``Mesh3d`` figure from a trimesh mesh with per-vertex colors.

    Args:
        mesh_trimesh: ``trimesh.Trimesh`` whose ``visual.vertex_colors`` are set.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    vertices = np.asarray(mesh_trimesh.vertices)
    faces = np.asarray(mesh_trimesh.faces)
    x, y, z = vertices.T
    i, j, k = faces.T
    fig = go.Figure(go.Mesh3d(
        x=x, y=y, z=z, i=i, j=j, k=k,
        vertexcolor=np.asarray(mesh_trimesh.visual.vertex_colors),
        lighting=dict(ambient=0.5, diffuse=1, fresnel=4, specular=0.5, roughness=0.05,
                      facenormalsepsilon=0, vertexnormalsepsilon=0),
        lightposition=dict(x=100, y=100, z=1000),
    ))
    return fig
network_cat = hf_hub_download("SerdarHelli/pix2pix3d_seg2cat", filename="pix2pix3d_seg2cat.pkl", revision="main")
models = {"seg2cat": network_cat}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
outdir = "./"


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch storages onto the CPU.

    Pickled storages are rebuilt through ``torch.storage._load_from_bytes``.
    Intercepting that hook and calling ``torch.load(..., map_location='cpu')``
    lets GPU-saved networks load on a machine without CUDA.
    """

    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)


def load_network_pkl_cpu(f, force_fp16=False):
    """Load a network pickle using ``CPU_Unpickler``.

    This appears to mirror ``legacy.load_network_pkl`` and is used by the app
    when ``device == 'cpu'``.

    Args:
        f: binary file object holding the pickle.
        force_fp16: if True, rebuild G/D/G_ema with FP16 synthesis settings
            (``num_fp16_res=4``, ``conv_clamp=256``).

    Returns:
        dict with keys ``G``, ``D``, ``G_ema``, ``training_set_kwargs``,
        ``augment_pipe``.
    """
    data = CPU_Unpickler(f).load()

    # Legacy TensorFlow pickle => convert. `_TFNetworkStub` has a leading
    # underscore, so `from legacy import *` does not export it; qualify it.
    if (isinstance(data, tuple) and len(data) == 3
            and all(isinstance(net, legacy._TFNetworkStub) for net in data)):
        tf_G, tf_D, tf_Gs = data
        G = legacy.convert_tf_generator(tf_G)
        D = legacy.convert_tf_discriminator(tf_D)
        G_ema = legacy.convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
            fp16_kwargs.num_fp16_res = 4
            fp16_kwargs.conv_clamp = 256
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data


# Palette used to decode color-coded input masks: color_list[idx] is label idx.
color_list = [[255, 255, 255], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204],
              [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0],
              [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51],
              [0, 204, 0]]


def colormap2labelmap(color_img):
    """Convert a color-coded mask into a label-index map.

    A pixel whose first three channels equal ``color_list[idx]`` gets label
    ``idx``. Pixels that match no palette entry get 0. Only the first three
    channels are compared, so RGBA input is accepted.

    Args:
        color_img: array-like image of shape (H, W, C) with C >= 3.

    Returns:
        float64 array of shape (H, W).

    Raises:
        ValueError: if the image is not (H, W, C) with at least 3 channels.
    """
    color_img = np.asarray(color_img)
    if color_img.ndim != 3 or color_img.shape[-1] < 3:
        raise ValueError(f"expected an (H, W, C>=3) color mask, got shape {color_img.shape}")
    rgb = color_img[..., :3]
    im_base = np.zeros(color_img.shape[:2])
    for idx, color in enumerate(color_list):
        im_base[np.all(rgb == np.asarray(color), axis=-1)] = idx
    return im_base


def checklabelmap(img):
    """Relabel ``img`` in place so its distinct values become 0..K-1.

    The values are ranked in ascending order. Unlike sequential
    ``img[img == label] = idx`` replacement, a single inverse lookup cannot
    merge classes when a new index equals a not-yet-processed label (for
    example, with negative labels).

    NOTE(review): compacting ids changes their meaning when the mask omits
    some palette colors (e.g. labels {0, 3} become {0, 1}). Confirm this is
    what the model expects.

    Args:
        img: numpy array of labels, modified in place.

    Returns:
        ``img`` itself.
    """
    _, inverse = np.unique(img, return_inverse=True)
    img[...] = inverse.reshape(img.shape)
    return img


# Loaded generators keyed by model name. Re-unpickling the network for every
# Gradio request is slow, so each model is loaded once.
_GENERATORS = {}


def _load_generator(cfg):
    """Return the cached ``G_ema`` for model key ``cfg``, loading it on first use."""
    if cfg not in models:
        raise ValueError(f"Unknown model {cfg!r}; expected one of {sorted(models)}")
    if cfg not in _GENERATORS:
        loader = load_network_pkl_cpu if device == "cpu" else legacy.load_network_pkl
        with dnnlib.util.open_url(models[cfg]) as f:
            _GENERATORS[cfg] = loader(f)['G_ema'].eval().to(device)
    return _GENERATORS[cfg]


def get_all(cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames):
    """Run the full pipeline for one Gradio request.

    Generates a front-view image and label map, a semantically colored mesh,
    and two orbit videos (RGB and labels).

    Args:
        cfg: model key into ``models``.
        input: mask image, either a file path or an array-like image. Pixels
            are matched against ``color_list``.
        truncation_psi: forwarded to ``G.sample_mixed`` (see the note below).
        mesh_resolution: marching-cubes grid resolution per axis.
        random_seed: seed for the latent ``z``.
        fps: frame rate of the saved videos.
        num_frames: number of frames per video.

    Returns:
        ``(plotly figure, RGB image, colorized label image, color video path,
        label video path)``.

    Raises:
        ValueError: if ``cfg`` is unknown or is not a segmentation model.
    """
    # Only segmentation-conditioned models have camera and input handling
    # here. Other configs used to fall through and then fail with a NameError.
    if cfg not in ('seg2cat', 'seg2face'):
        raise ValueError(f"Unsupported model {cfg!r}: only segmentation models are handled")
    G = _load_generator(cfg)
    neural_rendering_resolution = 128
    # Gradio sliders may deliver floats, but range()/linspace need ints.
    mesh_resolution = int(mesh_resolution)
    num_frames = int(num_frames)

    # Frontal camera at the generator's average pivot and radius.
    forward_cam2world_pose = LookAtPoseSampler.sample(
        3.14 / 2, 3.14 / 2,
        torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
        radius=G.rendering_kwargs['avg_camera_radius'], device=device)
    focal_length = 4.2647
    # Intrinsics with the principal point at (0.5, 0.5), i.e. presumably in
    # normalized image coordinates.
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
    forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

    save_dir = Path(outdir)

    if isinstance(input, str):
        # A file path. Convert to RGB so palette/alpha PNGs still decode, and
        # close the file handle when done.
        with PIL.Image.open(input) as mask_img:
            input_label = np.asarray(mask_img.convert("RGB"))
    else:
        input_label = np.asarray(input)
    input_label = colormap2labelmap(input_label)
    input_label = checklabelmap(input_label)
    input_label = torch.from_numpy(input_label.astype(np.uint8)).unsqueeze(0).unsqueeze(0).to(device)
    input_pose = forward_pose.to(device)

    z = torch.from_numpy(np.random.RandomState(int(random_seed)).randn(1, G.z_dim).astype('float32')).to(device)
    with torch.no_grad():
        ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
        out = G.synthesis(ws, input_pose, noise_mode='const',
                          neural_rendering_resolution=neural_rendering_resolution)
        image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
        image_seg = color_mask(torch.argmax(out['semantic'][0], dim=0).cpu().numpy()).astype(np.uint8)
        mesh_trimesh = trimesh.Trimesh(*extract_geometry(G, ws, resolution=mesh_resolution, threshold=50.))

    # Color each mesh vertex by the argmax of the 6 semantic scores that
    # sample_mixed returns at that point (channels 32..37 of out['rgb']).
    verts_np = np.array(mesh_trimesh.vertices)
    num_verts = verts_np.shape[0]
    semantic_scores = torch.zeros((num_verts, 6), device=device)
    samples = torch.tensor(verts_np, device=device).unsqueeze(0).float()
    max_batch = 10000000
    with tqdm(total=num_verts) as pbar, torch.no_grad():
        for head in range(0, num_verts, max_batch):
            tail = min(head + max_batch, num_verts)
            torch.manual_seed(0)
            # NOTE(review): ws are already mapped, so truncation_psi may have
            # no effect in sample_mixed. Confirm whether it belongs in G.mapping.
            out = G.sample_mixed(samples[:, head:tail], None, ws,
                                 truncation_psi=truncation_psi, noise_mode='const')
            semantic_scores[head:tail, :] = out['rgb'][0, :, 32:32 + 6]
            pbar.update(tail - head)
    semantic_colors = torch.tensor(color_list, device=device)[torch.argmax(semantic_scores, dim=-1)]
    mesh_trimesh.visual.vertex_colors = semantic_colors.cpu().numpy().astype(np.uint8)

    frames, frames_label = render_video(G, ws, intrinsics, num_frames=num_frames, pitch_range=0.25,
                                        yaw_range=0.35,
                                        neural_rendering_resolution=neural_rendering_resolution,
                                        device=device)

    # Save the videos. Fixed per-model file names: a later request overwrites
    # the previous output.
    video = os.path.join(save_dir, f'{cfg}_color.mp4')
    video_label = os.path.join(save_dir, f'{cfg}_label.mp4')
    imageio.mimsave(video, frames, fps=fps)
    imageio.mimsave(video_label, frames_label, fps=fps)

    fig_mesh = return_plot_go(mesh_trimesh)
    return fig_mesh, image_color, image_seg, video, video_label


title = "3D-aware Conditional Image Synthesis"
desc = f'''
[Arxiv: "3D-aware Conditional Image Synthesis".](https://arxiv.org/abs/2302.08509)
[Project Page.](https://www.cs.cmu.edu/~pix2pix3D/)
[For the official implementation.](https://github.com/dunbar12138/pix2pix3D)

### Future Work based on interest
- Adding new models for new type objects
- New Customization

It is running on {device}

The process can take a long time, especially video generation. The processing time depends on the number of frames, the mesh resolution, and the current compute device.
'''

demo_inputs = [
    gr.Dropdown(choices=["seg2cat"], label="Choose Model", value="seg2cat"),
    gr.Image(type="filepath", shape=(512, 512), label="Mask"),
    gr.Slider(minimum=0, maximum=2, label='Truncation PSI', value=1),
    gr.Slider(minimum=32, maximum=512, label='Mesh Resolution', value=32),
    gr.Slider(minimum=0, maximum=2**16, label='Seed', value=128),
    gr.Slider(minimum=10, maximum=120, label='FPS', value=30),
    gr.Slider(minimum=10, maximum=120, label='The Number of Frames', value=30),
]
demo_outputs = [
    gr.Plot(label="Generated Mesh"),
    gr.Image(type="pil", shape=(256, 256), label="Generated Image"),
    gr.Image(type="pil", shape=(256, 256), label="Generated LabelMap"),
    gr.Video(label="Generated Video "),
    gr.Video(label="Generated Label Video "),
]
examples = [
    ["seg2cat", "img.png", 1, 32, 128, 30, 30],
    ["seg2cat", "img2.png", 1, 32, 128, 30, 30],
    ["seg2cat", "img3.png", 1, 32, 128, 30, 30],
]

demo_app = gr.Interface(
    fn=get_all,
    inputs=demo_inputs,
    outputs=demo_outputs,
    cache_examples=True,
    title=title,
    theme="huggingface",
    description=desc,
    examples=examples,
)
demo_app.launch(debug=True, enable_queue=True)