# LDM / app.py
# File-viewer header residue captured with the source (not part of the program):
#   xrg | fix bug | 3a7d40d | raw | history blame | 16 kB
import os
import tyro
import imageio
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
import rembg
import gradio as gr
import kiui
from kiui.op import recenter
from kiui.cam import orbit_camera
from core.utils import get_rays, grid_distortion, orbit_camera_jitter
from core.options import AllConfigs, Options
from core.models import LTRFM_Mesh,LTRFM_NeRF
from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl
from mvdream.pipeline_mvdream import MVDreamPipeline
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from huggingface_hub import hf_hub_download
import spaces
# Per-channel normalization statistics (three values each, one per colour channel).
IMAGENET_DEFAULT_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DEFAULT_STD = (
    0.229,
    0.224,
    0.225,
)

# File names of the artifacts produced for the Gradio demo; `process` joins
# the OBJ names onto `<opt.workspace>/gradio`.
GRADIO_VIDEO_PATH = "gradio_output.mp4"
GRADIO_OBJ_PATH = "gradio_output_rgb.obj"
GRADIO_OBJ_ALBEDO_PATH = "gradio_output_albedo.obj"
GRADIO_OBJ_SHADING_PATH = "gradio_output_shading.obj"
# The demo uses a fixed configuration instead of parsing the command line
# (the CLI alternative would be `tyro.cli(AllConfigs)`).

# Fetch the pretrained checkpoint file from the Hugging Face Hub repo "rgxie/LDM".
ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM6v01.ckpt")

# Keyword overrides passed to `Options`; any field not listed keeps the
# default declared in `core.options`.
_option_overrides = {
    'input_size': 512,
    'down_channels': (32, 64, 128, 256, 512),
    'down_attention': (False, False, False, False, True),
    'up_channels': (512, 256, 128),
    'up_attention': (True, False, False, False),
    'volume_mode': 'TRF_NeRF',
    'splat_size': 64,
    'output_size': 62,  # crop patch
    'data_mode': 's5',
    'num_views': 8,
    'gradient_accumulation_steps': 1,  # 2
    'mixed_precision': 'bf16',
    'resume': ckpt_path,
}
opt = Options(**_option_overrides)
# model
# Pick the reconstruction network that matches the configured volume mode.
if opt.volume_mode == 'TRF_Mesh':
    model = LTRFM_Mesh(opt)
elif opt.volume_mode == 'TRF_NeRF':
    model = LTRFM_NeRF(opt)
else:
    # The previous fallback instantiated `LGM`, a name this file never
    # imports, so any other mode crashed with a NameError. Fail with an
    # explicit, descriptive error instead.
    raise ValueError(
        f"unsupported volume_mode {opt.volume_mode!r}; "
        "expected 'TRF_Mesh' or 'TRF_NeRF'"
    )

# resume pretrained checkpoint
if opt.resume is not None:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
    else:
        # Regular torch checkpoint: the weights live under the "model" key.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ckpt_dict = torch.load(opt.resume, map_location='cpu')
        ckpt = ckpt_dict["model"]

    # Copy matching tensors in place; report (but tolerate) anything that does
    # not line up with the current model definition.
    state_dict = model.state_dict()
    for k, v in ckpt.items():
        # Drop the 'module.' substring from checkpoint keys
        # (presumably a DataParallel/DDP wrapper prefix).
        k = k.replace('module.', '')
        if k not in state_dict:
            print(f'[WARN] unexpected param {k}: {v.shape}')
            continue
        if state_dict[k].shape == v.shape:
            state_dict[k].copy_(v)
        else:
            print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
    print('[INFO] load resume success!')

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('first device')
print(device)

# Inference only: half-precision weights, eval mode.
model = model.half().to(device)
model.eval()
# Projection matrix assembled from the configured vertical field of view and
# the near/far clip distances. Non-zero entries: [0,0], [1,1], [2,2], [3,2]
# and [2,3]; everything else stays 0.
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
_inv_tan = float(1 / tan_half_fov)
_depth_span = opt.zfar - opt.znear
proj_matrix = torch.tensor(
    [
        [_inv_tan, 0.0, 0.0, 0.0],
        [0.0, _inv_tan, 0.0, 0.0],
        [0.0, 0.0, float((opt.zfar + opt.znear) / _depth_span), 1.0],
        [0.0, 0.0, float(-(opt.zfar * opt.znear) / _depth_span), 0.0],
    ],
    dtype=torch.float32,
).to(device)
# Keyword arguments shared by every diffusion pipeline loaded below:
# half-precision weights and permission to run the repos' custom pipeline code.
_pipeline_kwargs = dict(
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # local_files_only=True,
)

# load dreams
# Text-conditioned multi-view pipeline (remote weights).
pipe_text = MVDreamPipeline.from_pretrained(
    'ashawkey/mvdream-sd2.1-diffusers',
    **_pipeline_kwargs,
).to(device)

# mvdream
# Image-conditioned multi-view pipeline (remote weights).
pipe_image = MVDreamPipeline.from_pretrained(
    "ashawkey/imagedream-ipmv-diffusers",
    **_pipeline_kwargs,
).to(device)

# Zero123++ pipeline, with its scheduler swapped for an Euler-ancestral one
# configured with 'trailing' timestep spacing.
print('Loading 123plus model ...')
pipe_image_plus = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="zero123plus",
    **_pipeline_kwargs,
)
pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipe_image_plus.scheduler.config, timestep_spacing='trailing'
)

# Replace the UNet weights with the custom white-background checkpoint:
# prefer a local copy, otherwise download it from the Hub.
unet_path = './pretrained/diffusion_pytorch_model.bin'
print('Loading custom white-background unet ...')
unet_ckpt_path = (
    unet_path
    if os.path.exists(unet_path)
    else hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
)
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipe_image_plus.unet.load_state_dict(state_dict, strict=True)
pipe_image_plus = pipe_image_plus.to(device)

# load rembg
# One background-removal session, reused for every request.
bg_remover = rembg.new_session()
# process function
def _arrange_four_views(views):
    """Reorder four views to (1, 2, 3, 0); return (side-by-side grid, stacked batch)."""
    ordered = [views[1], views[2], views[3], views[0]]
    return np.concatenate(ordered, axis=1), np.stack(ordered, axis=0)


def _build_source_cameras(four_views):
    """Build the source camera poses handed to the model as `source_camera`.

    Args:
        four_views: True for the four-view layout (azimuth 0/90/180/270 at
            elevation 0); False for the six-view layout (azimuth 30..330 in
            steps of 60, elevation alternating -20 / 30).

    Returns:
        Tensor of poses with a leading batch axis of size 1, expressed
        relative to the first camera.
    """
    if four_views:
        view_angles = [(0, azi) for azi in np.arange(0, 360, 90, dtype=np.int32)]
    else:
        view_angles = [
            (-20 if idx % 2 == 0 else 30, azi)
            for idx, azi in enumerate(np.arange(30, 360, 60, dtype=np.int32))
        ]

    cam_poses = torch.cat(
        [
            torch.from_numpy(orbit_camera(ele, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
            for ele, azi in tqdm.tqdm(view_angles)
        ],
        0,
    )

    # Rescale translations so the first camera sits exactly at cam_radius.
    radius = torch.norm(cam_poses[0, :3, 3])
    cam_poses[:, :3, 3] *= opt.cam_radius / radius

    # Re-express every pose relative to the first camera, which is mapped to
    # a fixed pose at distance cam_radius along +z.
    transform = torch.tensor(
        [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, opt.cam_radius], [0, 0, 0, 1]],
        dtype=torch.float32,
    ).to(device) @ torch.inverse(cam_poses[0])
    cam_poses = transform.unsqueeze(0) @ cam_poses
    return cam_poses.unsqueeze(0)


@spaces.GPU
def process(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None):
    """Generate multi-view images from text or an image, then extract OBJ meshes.

    Args:
        condition_input_image: Conditioning image (anything `np.array` accepts,
            e.g. a PIL image) or None to run the text-conditioned path.
        prompt: Text prompt (used by the text and 'mvdream' image paths).
        prompt_neg: Negative prompt for the MVDream pipelines.
        input_elevation: Elevation forwarded to the MVDream pipelines.
        input_num_steps: Number of diffusion inference steps.
        input_seed: Seed passed to `kiui.seed_everything`.
        mv_moedl_option: 'mvdream' selects the ImageDream pipeline for image
            input; any other value selects Zero123++.

    Returns:
        Tuple `(mv_image_grid, processed_image, rgb_obj_path, albedo_obj_path,
        shading_obj_path)`. `processed_image` is None on the text path.
    """
    # seed
    kiui.seed_everything(input_seed)

    output_dir = os.path.join(opt.workspace, "gradio")
    os.makedirs(output_dir, exist_ok=True)
    output_obj_rgb_path = os.path.join(output_dir, GRADIO_OBJ_PATH)
    output_obj_albedo_path = os.path.join(output_dir, GRADIO_OBJ_ALBEDO_PATH)
    output_obj_shading_path = os.path.join(output_dir, GRADIO_OBJ_SHADING_PATH)

    # Text input and ImageDream both yield four views; Zero123++ yields six.
    use_four_views = condition_input_image is None or mv_moedl_option == 'mvdream'

    # text-conditioned
    if condition_input_image is None:
        mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
        mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)

        # bg removal
        mv_image = []
        for view in mv_image_uint8[:4]:
            image = rembg.remove(view, session=bg_remover)  # [H, W, 4]
            image = image.astype(np.float32) / 255
            # Use the alpha channel as the foreground mask, matching the
            # image-conditioned branch; masking on the red channel would drop
            # foreground pixels whose red value is zero.
            image = recenter(image, image[..., -1] > 0, border_ratio=0.2)
            # to white bg
            image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
            mv_image.append(image)

        mv_image_grid, input_image = _arrange_four_views(mv_image)
        processed_image = None
    # image-conditioned (may also input text, but no text usually works too)
    else:
        condition_input_image = np.array(condition_input_image)  # uint8

        # bg removal
        carved_image = rembg.remove(condition_input_image, session=bg_remover)  # [H, W, 4]
        mask = carved_image[..., -1] > 0
        image = recenter(carved_image, mask, border_ratio=0.2)
        image = image.astype(np.float32) / 255.0
        # composite onto a white background
        processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])

        if mv_moedl_option == 'mvdream':
            mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)
            mv_image_grid, input_image = _arrange_four_views(mv_image)
        else:
            # Local imports: these packages are only needed on this path.
            from PIL import Image
            from einops import rearrange

            processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))
            mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0]
            mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0
            # channels-first; the original comment gives the shape as (3, 960, 640)
            mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float()
            # The pipeline output is a 3x2 tile sheet: build a display grid
            # and a batch of six separate views from it.
            mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy()
            input_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy()

    # Prepare network inputs: channels-first, resized for the main model and
    # a 224x224 copy for the ViT branch.
    input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device)
    input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
    images_input_vit = F.interpolate(input_image, size=(224, 224), mode='bilinear', align_corners=False)

    data = {
        'input_vit': images_input_vit.unsqueeze(0),
        'source_camera': _build_source_cameras(use_four_views),
    }
    input_image = input_image.unsqueeze(0)  # add batch axis

    with torch.no_grad():
        # The mesh variant runs the volume prediction in float32, the NeRF
        # variant in float16.
        autocast_dtype = torch.float32 if opt.volume_mode == 'TRF_Mesh' else torch.float16
        with torch.autocast(device_type='cuda', dtype=autocast_dtype):
            svd_volume = model.forward_svd_volume(input_image, data)

        # time-consuming
        export_texmap = False

        mesh_out = model.extract_mesh(svd_volume, use_texture_map=export_texmap)
        if export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            for i, tex in enumerate(tex_map):
                # The original code built this path from undefined `name` and
                # `seed` variables; derive it from the request seed instead.
                mesh_path = os.path.join(opt.workspace, 'gradio_output' + str(i) + '_' + str(input_seed) + '.obj')
                save_obj_with_mtl(
                    vertices.data.cpu().numpy(),
                    uvs.data.cpu().numpy(),
                    faces.data.cpu().numpy(),
                    mesh_tex_idx.data.cpu().numpy(),
                    tex.permute(1, 2, 0).data.cpu().numpy(),
                    mesh_path,
                )
        else:
            vertices, faces, vertex_colors = mesh_out
            # vertex_colors holds three colour sets: rgb, albedo, shading.
            save_obj(vertices, faces, vertex_colors[0], output_obj_rgb_path)
            save_obj(vertices, faces, vertex_colors[1], output_obj_albedo_path)
            save_obj(vertices, faces, vertex_colors[2], output_obj_shading_path)

    return mv_image_grid, processed_image, output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path
# gradio UI
_TITLE = '''LDM: Large Tensorial SDF Model for Textured Mesh Generation'''
_DESCRIPTION = '''
* Input can be text prompt, image.
* If you find the output unsatisfying, try using different seeds!
'''
# Example inputs offered in the two tabs.
_image_examples = [
    os.path.join("example", img_name) for img_name in sorted(os.listdir("example"))
]
_text_examples = [
    "a hamburger",
    "a furry red fox head",
    "a teddy bear",
    "a motorbike",
]
_default_negative_prompt = 'ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate'

block = gr.Blocks(title=_TITLE).queue()
with block:
    # header
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        # left column: inputs and generation settings
        with gr.Column(scale=1):
            with gr.Tab("Image-to-3D"):
                # input image
                with gr.Row():
                    condition_input_image = gr.Image(
                        label="Input Image",
                        image_mode="RGBA",
                        type="pil"
                    )
                    processed_image = gr.Image(
                        label="Processed Image",
                        image_mode="RGBA",
                        type="pil",
                        interactive=False
                    )
                # choice of multi-view diffusion backend
                with gr.Row():
                    mv_moedl_option = gr.Radio(
                        ["zero123plus", "mvdream"],
                        value="zero123plus",
                        label="Multi-view Diffusion",
                    )
                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=_image_examples,
                        inputs=[condition_input_image],
                        fn=lambda x: process(condition_input_image=x, prompt=''),
                        cache_examples=False,
                        examples_per_page=20,
                        label='Image-to-3D Examples'
                    )

            with gr.Tab("Text-to-3D"):
                # input prompt
                with gr.Row():
                    input_text = gr.Textbox(label="prompt")
                # negative prompt
                with gr.Row():
                    input_neg_text = gr.Textbox(label="negative prompt", value=_default_negative_prompt)
                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=_text_examples,
                        inputs=[input_text],
                        fn=lambda x: process(condition_input_image=None, prompt=x),
                        cache_examples=False,
                        label='Text-to-3D Examples'
                    )

            # elevation
            input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
            # inference steps
            input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
            # gen button
            button_gen = gr.Button("Generate")

        # right column: multi-view preview and the three exported meshes
        with gr.Column(scale=1):
            with gr.Row():
                # multi-view results
                mv_image_grid = gr.Image(interactive=False, show_label=False)
            with gr.Row():
                output_obj_rgb_path = gr.Model3D(
                    label="RGB Model (OBJ Format)",
                    interactive=False,
                )
            with gr.Row():
                output_obj_albedo_path = gr.Model3D(
                    label="Albedo Model (OBJ Format)",
                    interactive=False,
                )
            with gr.Row():
                output_obj_shading_path = gr.Model3D(
                    label="Shading Model (OBJ Format)",
                    interactive=False,
                )

    # Wire the button: inputs follow `process`'s positional parameter order,
    # outputs follow its return tuple.
    _click_inputs = [
        condition_input_image,
        input_text,
        input_neg_text,
        input_elevation,
        input_num_steps,
        input_seed,
        mv_moedl_option,
    ]
    _click_outputs = [
        mv_image_grid,
        processed_image,
        output_obj_rgb_path,
        output_obj_albedo_path,
        output_obj_shading_path,
    ]
    button_gen.click(process, inputs=_click_inputs, outputs=_click_outputs)

# Serve on all interfaces without creating a public share link.
block.launch(server_name="0.0.0.0", share=False)