"""Gradio demo for LDM (Large Tensorial SDF Model for Textured Mesh Generation).

Pipeline: (text prompt | input image) -> multi-view diffusion (MVDream,
ImageDream or Zero123++) -> LTRFM reconstruction -> three OBJ meshes
(RGB, albedo, shading).
"""
import os
import tyro
import imageio
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
import rembg
import gradio as gr

import kiui
from kiui.op import recenter
from kiui.cam import orbit_camera
from core.utils import get_rays, grid_distortion, orbit_camera_jitter

from core.options import AllConfigs, Options
from core.models import LTRFM_Mesh, LTRFM_NeRF
from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl
from mvdream.pipeline_mvdream import MVDreamPipeline
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from huggingface_hub import hf_hub_download

import spaces

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

GRADIO_VIDEO_PATH = 'gradio_output.mp4'
GRADIO_OBJ_PATH = 'gradio_output_rgb.obj'
GRADIO_OBJ_ALBEDO_PATH = 'gradio_output_albedo.obj'
GRADIO_OBJ_SHADING_PATH = 'gradio_output_shading.obj'

# Permutation applied to 4-view MVDream/ImageDream outputs before they are
# tiled into the preview grid and stacked as reconstruction input.
_MV_VIEW_ORDER = (1, 2, 3, 0)

#opt = tyro.cli(AllConfigs)

ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM6v01.ckpt")

opt = Options(
    input_size=512,
    down_channels=(32, 64, 128, 256, 512),
    down_attention=(False, False, False, False, True),
    up_channels=(512, 256, 128),
    up_attention=(True, False, False, False),
    volume_mode='TRF_NeRF',
    splat_size=64,
    output_size=62,  # crop patch
    data_mode='s5',
    num_views=8,
    gradient_accumulation_steps=1,  # 2
    mixed_precision='bf16',
    resume=ckpt_path,
)

# model
if opt.volume_mode == 'TRF_Mesh':
    model = LTRFM_Mesh(opt)
elif opt.volume_mode == 'TRF_NeRF':
    model = LTRFM_NeRF(opt)
else:
    # Only the two LTRFM variants are imported in this script.
    raise ValueError(f"Unsupported volume_mode: {opt.volume_mode!r}")

# resume pretrained checkpoint
if opt.resume is not None:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
    else:  # ckpt
        # NOTE(review): torch.load unpickles the file; only use checkpoints
        # from trusted sources.
        ckpt_dict = torch.load(opt.resume, map_location='cpu')
        ckpt = ckpt_dict["model"]

    # Copy matching tensors in place; strip DDP 'module.' prefixes and report
    # (but skip) shape mismatches and unknown keys.
    state_dict = model.state_dict()
    for k, v in ckpt.items():
        k = k.replace('module.', '')
        if k in state_dict:
            if state_dict[k].shape == v.shape:
                state_dict[k].copy_(v)
            else:
                print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
        else:
            print(f'[WARN] unexpected param {k}: {v.shape}')
    print(f'[INFO] load resume success!')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.half().to(device)
model.eval()

tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
proj_matrix = torch.zeros(4, 4, dtype=torch.float32).to(device)
proj_matrix[0, 0] = 1 / tan_half_fov
proj_matrix[1, 1] = 1 / tan_half_fov
proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
proj_matrix[2, 3] = 1

# load dreams
pipe_text = MVDreamPipeline.from_pretrained(
    'ashawkey/mvdream-sd2.1-diffusers',  # remote weights
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # local_files_only=True,
)
pipe_text = pipe_text.to(device)

# mvdream
pipe_image = MVDreamPipeline.from_pretrained(
    "ashawkey/imagedream-ipmv-diffusers",  # remote weights
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # local_files_only=True,
)
pipe_image = pipe_image.to(device)

print('Loading 123plus model ...')
pipe_image_plus = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="zero123plus",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    #local_files_only=True,
)
pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipe_image_plus.scheduler.config, timestep_spacing='trailing'
)

unet_path = './pretrained/diffusion_pytorch_model.bin'

print('Loading custom white-background unet ...')
if os.path.exists(unet_path):
    unet_ckpt_path = unet_path
else:
    unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipe_image_plus.unet.load_state_dict(state_dict, strict=True)
pipe_image_plus = pipe_image_plus.to(device)

# load rembg
bg_remover = rembg.new_session()


def _arrange_four_views(views):
    """Reorder four views by _MV_VIEW_ORDER; return (side-by-side grid, stacked array)."""
    ordered = [views[i] for i in _MV_VIEW_ORDER]
    return np.concatenate(ordered, axis=1), np.stack(ordered, axis=0)


@spaces.GPU
def generate_mv(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None):
    """Generate multi-view images from a prompt or a conditioning image.

    Args:
        condition_input_image: PIL/array input image, or None for text-to-3D.
        prompt / prompt_neg: positive / negative text prompts.
        input_elevation: elevation passed to the MVDream pipelines.
        input_num_steps: number of diffusion inference steps.
        input_seed: random seed.
        mv_moedl_option: 'mvdream' selects ImageDream for image input; any
            other value uses Zero123++.

    Returns:
        (mv_image_grid, processed_image, input_image): a preview grid, the
        background-removed input (None for text input), and the per-view
        stack consumed by ``generate_3d``.
    """
    kiui.seed_everything(input_seed)

    # generate_3d writes its meshes into this directory.
    os.makedirs(os.path.join(opt.workspace, "gradio"), exist_ok=True)

    # text-conditioned
    if condition_input_image is None:
        mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
        mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
        # bg removal
        mv_image = []
        for i in range(4):
            image = rembg.remove(mv_image_uint8[i], session=bg_remover)  # [H, W, 4]
            # to white bg
            image = image.astype(np.float32) / 255
            # Use the alpha channel as the foreground mask (same as the
            # image-conditioned branch below).
            image = recenter(image, image[..., -1] > 0, border_ratio=0.2)
            image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
            mv_image.append(image)

        mv_image_grid, input_image = _arrange_four_views(mv_image)
        processed_image = None
    # image-conditioned (may also input text, but no text usually works too)
    else:
        condition_input_image = np.array(condition_input_image)  # uint8
        # bg removal
        carved_image = rembg.remove(condition_input_image, session=bg_remover)  # [H, W, 4]
        mask = carved_image[..., -1] > 0
        image = recenter(carved_image, mask, border_ratio=0.2)
        image = image.astype(np.float32) / 255.0
        # composite onto a white background
        processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])

        if mv_moedl_option == 'mvdream':
            mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)
            mv_image_grid, input_image = _arrange_four_views(mv_image)
        else:
            from PIL import Image
            from einops import rearrange

            processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))
            mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0]
            mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0
            mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
            # Zero123++ returns a 3x2 tiled image: build a preview grid and
            # split it into 6 separate [h, w, c] views.
            mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy()
            mv_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy()
            input_image = mv_image

    return mv_image_grid, processed_image, input_image


@spaces.GPU
def generate_3d(input_image, condition_input_image, mv_moedl_option=None, input_seed=42):
    """Reconstruct meshes from the multi-view stack produced by ``generate_mv``.

    Camera poses: 4 views at elevation 0 and azimuth 0/90/180/270 for the
    MVDream paths; otherwise 6 views at azimuth 30..330 step 60 with
    elevation alternating -20 / 30.

    Returns:
        Paths of the RGB, albedo and shading OBJ files.
    """
    kiui.seed_everything(input_seed)

    output_obj_rgb_path = os.path.join(opt.workspace, "gradio", GRADIO_OBJ_PATH)
    output_obj_albedo_path = os.path.join(opt.workspace, "gradio", GRADIO_OBJ_ALBEDO_PATH)
    output_obj_shading_path = os.path.join(opt.workspace, "gradio", GRADIO_OBJ_SHADING_PATH)

    # [V, H, W, 3] float32 -> [V, 3, H, W]
    input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device)
    input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
    images_input_vit = F.interpolate(input_image, size=(224, 224), mode='bilinear', align_corners=False)

    data = {}
    input_image = input_image.unsqueeze(0)  # [1, V, 3, H, W]
    images_input_vit = images_input_vit.unsqueeze(0)
    data['input_vit'] = images_input_vit

    cam_poses = []
    if mv_moedl_option == 'mvdream' or condition_input_image is None:
        elevation = 0
        azimuth = np.arange(0, 360, 90, dtype=np.int32)
        for azi in tqdm.tqdm(azimuth):
            cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
            cam_poses.append(cam_pose)
    else:
        azimuth = np.arange(30, 360, 60, dtype=np.int32)
        for cnt, azi in enumerate(tqdm.tqdm(azimuth)):
            elevation = -20 if cnt % 2 == 0 else 30
            cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
            cam_poses.append(cam_pose)

    cam_poses = torch.cat(cam_poses, 0)
    # Rescale so the first camera sits at opt.cam_radius, then express all
    # poses relative to the first camera placed at (0, 0, cam_radius).
    radius = torch.norm(cam_poses[0, :3, 3])
    cam_poses[:, :3, 3] *= opt.cam_radius / radius
    transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32).to(device) @ torch.inverse(cam_poses[0])
    cam_poses = transform.unsqueeze(0) @ cam_poses

    cam_poses = cam_poses.unsqueeze(0)
    data['source_camera'] = cam_poses

    with torch.no_grad():
        if opt.volume_mode == 'TRF_Mesh':
            with torch.autocast(device_type='cuda', dtype=torch.float32):
                svd_volume = model.forward_svd_volume(input_image, data)
        else:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                svd_volume = model.forward_svd_volume(input_image, data)

        # time-consuming
        export_texmap = False

        mesh_out = model.extract_mesh(svd_volume, use_texture_map=export_texmap)

        if export_texmap:
            # Only reached if export_texmap is set to True; in that mode the
            # three paths returned below are not written.
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            for i, tex in enumerate(tex_map):
                mesh_path = os.path.join(opt.workspace, "gradio", f"gradio_output_{i}_{input_seed}.obj")
                save_obj_with_mtl(
                    vertices.data.cpu().numpy(),
                    uvs.data.cpu().numpy(),
                    faces.data.cpu().numpy(),
                    mesh_tex_idx.data.cpu().numpy(),
                    tex.permute(1, 2, 0).data.cpu().numpy(),
                    mesh_path,
                )
        else:
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors[0], output_obj_rgb_path)
            save_obj(vertices, faces, vertex_colors[1], output_obj_albedo_path)
            save_obj(vertices, faces, vertex_colors[2], output_obj_shading_path)

    return output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path


# gradio UI

_TITLE = '''LDM: Large Tensorial SDF Model for Textured Mesh Generation'''

_DESCRIPTION = '''
* Input can be text prompt, image.
* If you find the output unsatisfying, try using different seeds!
'''

block = gr.Blocks(title=_TITLE).queue()
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        with gr.Column(scale=1):
            with gr.Tab("Image-to-3D"):
                # input image
                with gr.Row():
                    condition_input_image = gr.Image(
                        label="Input Image",
                        image_mode="RGBA",
                        type="pil"
                    )

                    processed_image = gr.Image(
                        label="Processed Image",
                        image_mode="RGBA",
                        type="pil",
                        interactive=False
                    )

                with gr.Row():
                    mv_moedl_option = gr.Radio([
                        "zero123plus",
                        "mvdream"
                    ], value="zero123plus",
                        label="Multi-view Diffusion")

                with gr.Row(variant="panel"):
                    # Examples only populate the inputs (no fn, no caching).
                    gr.Examples(
                        examples=[
                            os.path.join("example", img_name) for img_name in sorted(os.listdir("example"))
                        ],
                        inputs=[condition_input_image],
                        cache_examples=False,
                        examples_per_page=20,
                        label='Image-to-3D Examples'
                    )

            with gr.Tab("Text-to-3D"):
                # input prompt
                with gr.Row():
                    input_text = gr.Textbox(label="prompt")
                # negative prompt
                with gr.Row():
                    input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')

                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=[
                            "a hamburger",
                            "a furry red fox head",
                            "a teddy bear",
                            "a motorbike",
                        ],
                        inputs=[input_text],
                        cache_examples=False,
                        label='Text-to-3D Examples'
                    )

            # elevation
            input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
            # inference steps
            input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
            # gen button
            button_gen = gr.Button("Generate")

        with gr.Column(scale=1):
            with gr.Row():
                # multi-view results
                mv_image_grid = gr.Image(interactive=False, show_label=False)
            with gr.Row():
                output_obj_rgb_path = gr.Model3D(
                    label="RGB Model (OBJ Format)",
                    interactive=False,
                )
            with gr.Row():
                output_obj_albedo_path = gr.Model3D(
                    label="Albedo Model (OBJ Format)",
                    interactive=False,
                )
            with gr.Row():
                output_obj_shading_path = gr.Model3D(
                    label="Shading Model (OBJ Format)",
                    interactive=False,
                )

    # Carries the per-view stack from generate_mv to generate_3d.
    input_image = gr.State()
    button_gen.click(
        fn=generate_mv,
        inputs=[condition_input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed, mv_moedl_option],
        outputs=[mv_image_grid, processed_image, input_image],
    ).success(
        fn=generate_3d,
        inputs=[input_image, condition_input_image, mv_moedl_option, input_seed],
        outputs=[output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path],
    )

block.launch(server_name="0.0.0.0", share=False)