# Scraped Hugging Face Spaces page header (not part of the program): "Spaces: Build error"
import gradio as gr | |
import torch | |
import torchvision.transforms.functional as torchvision_F | |
import numpy as np | |
import os | |
import shutil | |
import importlib | |
import trimesh | |
import tempfile | |
import subprocess | |
import utils.options as options | |
import shlex | |
import time | |
import rembg | |
from utils.util import EasyDict as edict | |
from PIL import Image | |
from utils.eval_3D import get_dense_3D_grid, compute_level_grid, convert_to_explicit | |
def get_1d_bounds(arr):
    """Return the first and last flat indices at which `arr` is non-zero.

    Raises:
        IndexError: if `arr` contains no non-zero entry.
    """
    nonzero_idx = np.flatnonzero(arr)
    first, last = nonzero_idx[0], nonzero_idx[-1]
    return first, last
def get_bbox_from_mask(mask, thr):
    """Compute the bounding box of the region of `mask` that exceeds `thr`.

    Args:
        mask: array whose last two axes are treated as (rows, columns).
        thr: threshold; entries strictly greater than it count as foreground.

    Returns:
        (x0, y0, x1, y1): first/last foreground column and row indices,
        both ends being indices of foreground pixels.

    Raises:
        AssertionError: if no entry of `mask` exceeds `thr`.
    """
    foreground = (mask > thr).astype(np.float32)
    assert foreground.sum() > 0, "Empty mask!"
    # Summing over rows leaves a per-column profile (x extent); summing over
    # columns leaves a per-row profile (y extent).
    column_profile = foreground.sum(axis=-2)
    row_profile = foreground.sum(axis=-1)
    x0, x1 = get_1d_bounds(column_profile)
    y0, y1 = get_1d_bounds(row_profile)
    return x0, y0, x1, y1
def square_crop(image, bbox, crop_ratio=1.):
    """Crop a square window centred on `bbox` out of `image`.

    The window side is the longer bbox side enlarged by a fixed factor of 1.2
    and then multiplied by `crop_ratio`.

    Args:
        image: image accepted by `torchvision.transforms.functional.crop`.
        bbox: (x1, y1, x2, y2) box in pixel coordinates.
        crop_ratio: extra multiplier applied to the window side.

    Returns:
        The cropped image.
    """
    left, top, right, bottom = bbox
    box_h = bottom - top
    box_w = right - left
    center_y = (top + bottom) / 2
    center_x = (left + right) / 2
    # 1.2 leaves a margin around the object before the optional extra ratio.
    padded_side = max(box_h, box_w) * 1.2
    side = padded_side * crop_ratio
    half_side = side / 2
    return torchvision_F.crop(
        image,
        top=int(center_y - half_side),
        left=int(center_x - half_side),
        height=int(side),
        width=int(side),
    )
def preprocess_image(opt, image, bbox):
    """Crop, resize and tensorize an image that carries its mask as a 4th channel.

    Args:
        opt: options object; reads `opt.W`, `opt.H` and `opt.data.bgcolor`.
        image: image exposing `.size` and `.resize` (the caller in this file
            passes an RGBA PIL image whose alpha channel is the mask).
        bbox: (x1, y1, x2, y2) box forwarded to `square_crop`.

    Returns:
        (rgb, mask): the first three channels and the remaining channel(s) of
        the tensorized image; the mask is binarized at 0.5.
    """
    cropped = square_crop(image, bbox=bbox)
    already_target_size = cropped.size[0] == opt.W and cropped.size[1] == opt.H
    if not already_target_size:
        cropped = cropped.resize((opt.W, opt.H))
    tensor = torchvision_F.to_tensor(cropped)
    rgb = tensor[:3]
    mask = tensor[3:]
    bgcolor = opt.data.bgcolor
    if bgcolor is not None:
        # Blend with the (still soft) mask before it is binarized below.
        rgb = rgb * mask + bgcolor * (1 - mask)
    mask = (mask > 0.5).float()
    return rgb, mask
def get_image(opt, image_fname, mask_fname):
    """Load an image and its foreground mask and build the network inputs.

    Args:
        opt: options object forwarded to `preprocess_image`.
        image_fname: path of the input image (converted to RGB).
        mask_fname: path of the foreground mask (converted to 8-bit grayscale).

    Returns:
        (rgb_input_map, mask_input_map) as produced by `preprocess_image`.

    Raises:
        ValueError: if the mask and the image have different sizes.
        AssertionError: if the mask has no pixel above 127 (empty mask).
    """
    # `convert` returns an independent image, so the file handles can be
    # closed right away instead of lingering until garbage collection.
    with Image.open(image_fname) as image_file:
        image = image_file.convert("RGB")
    with Image.open(mask_fname) as mask_file:
        mask = mask_file.convert("L")
    if mask.size != image.size:
        # Image.merge would fail on this anyway; fail early with a clear message.
        raise ValueError(
            f"Mask size {mask.size} does not match image size {image.size}."
        )
    # Binarize: pixels strictly above 127 are foreground (1), the rest 0.
    mask_np = (np.asarray(mask) > 127).astype(np.uint8)
    # The soft (non-binarized) mask is attached as the alpha channel.
    image = Image.merge("RGBA", (*image.split(), mask))
    bbox = get_bbox_from_mask(mask_np, 0.5)
    rgb_input_map, mask_input_map = preprocess_image(opt, image, bbox=bbox)
    return rgb_input_map, mask_input_map
def get_intr(opt):
    """Build the fixed 3x3 camera intrinsics matrix for an `opt.W` x `opt.H` image.

    The focal length is 1.3875 times the image width/height and the principal
    point is the image centre.

    Returns:
        A float32 tensor of shape [3, 3].
    """
    focal = 1.3875
    fx, fy = focal * opt.W, focal * opt.H
    cx, cy = opt.W / 2, opt.H / 2
    rows = [
        [fx, 0, cx],
        [0, fy, cy],
        [0, 0, 1],
    ]
    return torch.tensor(rows).float()
def get_pixel_grid(H, W, device='cuda'):
    """Return homogeneous pixel coordinates (x, y, 1) for an H x W image.

    Args:
        H: number of rows.
        W: number of columns.
        device: device on which the grid is created.

    Returns:
        A float32 tensor of shape [H*W, 3], ordered row by row.
    """
    rows = torch.arange(H, dtype=torch.float32, device=device)
    cols = torch.arange(W, dtype=torch.float32, device=device)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    ones = torch.ones_like(grid_y)
    homogeneous = torch.stack((grid_x, grid_y, ones), dim=-1)
    return homogeneous.view(-1, 3)
def unproj_depth(depth, intr):
    '''
    Unproject a depth map into 3D points in camera coordinates.

    depth: [B, H, W]
    intr: [B, 3, 3]
    returns: [B, H, W, 3]
    '''
    batch_size, H, W = depth.shape
    num_pixels = H * W
    # [B, 3, 3] inverse intrinsics on the same device as the depth map
    K_inv = torch.linalg.inv(intr.to(depth.device)).float()
    # [B, H*W, 3] homogeneous pixel coordinates, one copy per batch element
    pixels = get_pixel_grid(H, W, depth.device).unsqueeze(0).repeat(batch_size, 1, 1)
    # [B, 3, H*W] rays through each pixel
    rays = K_inv @ pixels.permute(0, 2, 1).contiguous()
    # [B, H*W, 3] rays scaled by the per-pixel depth
    points = rays.permute(0, 2, 1).contiguous() * depth.view(batch_size, num_pixels, 1)
    return points.view(batch_size, H, W, 3)
def prepare_data(opt, image_path, mask_path):
    """Assemble the single-sample input batch for the model.

    Loads the image/mask pair, adds a leading batch dimension of size 1 to
    every tensor and moves everything to `opt.device`.

    Returns:
        An `edict` with `rgb_input_map`, `mask_input_map`, `intr`, `idx`
        (a long tensor holding 0) and `pose_gt` (False).
    """
    device = opt.device
    rgb, mask = get_image(opt, image_path, mask_path)
    intrinsics = get_intr(opt)

    var = edict()
    var.rgb_input_map = rgb.unsqueeze(0).to(device)
    var.mask_input_map = mask.unsqueeze(0).to(device)
    var.intr = intrinsics.unsqueeze(0).to(device)
    var.idx = torch.tensor([0]).to(device).long()
    var.pose_gt = False
    return var
def marching_cubes(opt, var, impl_network, visualize_attn=False):
    """Evaluate the implicit network on a dense grid and extract meshes.

    Reads `var.latent_depth`, `var.latent_semantic` and `var.rgb_input_map`;
    stores the extracted meshes in `var.mesh_pred` and, when the attention
    visualization is truthy, in `var.attn_vis`.

    Returns:
        The same `var`, updated in place.
    """
    # [B, N, N, N, 3] query points
    query_points = get_dense_3D_grid(opt, var)
    level_vox, attn_vis = compute_level_grid(
        opt,
        impl_network,
        var.latent_depth,
        var.latent_semantic,
        query_points,
        var.rgb_input_map,
        visualize_attn,
    )
    # NOTE(review): truthiness test kept as-is; confirm what type
    # compute_level_grid returns for attn_vis.
    if attn_vis:
        var.attn_vis = attn_vis
    # Split the batch into a list of length B, one grid per sample.
    per_sample_grids = list(level_vox.cpu().numpy())
    var.mesh_pred = convert_to_explicit(opt, per_sample_grids, isoval=0.5, to_pointcloud=False)
    return var
def infer_sample(opt, var, graph):
    """Run the model forward pass and mesh extraction for one prepared batch.

    Returns:
        The first entry of the predicted mesh list (`var.mesh_pred[0]`).
    """
    forwarded = graph.forward(opt, var, training=False, get_loss=False)
    meshed = marching_cubes(opt, forwarded, graph.impl_network, visualize_attn=True)
    first_mesh = meshed.mesh_pred[0]
    return first_mesh
def _download_checkpoint(ckpt_path):
    """Download the pretrained shape checkpoint to `ckpt_path` with wget.

    Raises:
        RuntimeError: if wget cannot be run, exits with an error, or leaves no
            file behind. Any partially written file is removed first.
    """
    url = "https://www.dropbox.com/scl/fi/hv3w9z59dqytievwviko4/shape.ckpt?rlkey=a2gut89kavrldmnt8b3df92oi&dl=0"
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        # wget -O does not create missing parent directories.
        os.makedirs(ckpt_dir, exist_ok=True)
    try:
        # subprocess.run blocks until wget exits; check=True surfaces failures.
        subprocess.run(["wget", "-q", "-O", ckpt_path, url], check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        # wget -O creates the target even on failure; do not leave a broken
        # checkpoint behind that a later call would try to load.
        if os.path.isfile(ckpt_path):
            os.remove(ckpt_path)
        raise RuntimeError(f"Failed to download checkpoint to {ckpt_path!r}.") from err
    if not os.path.isfile(ckpt_path):
        raise RuntimeError(f"Checkpoint {ckpt_path!r} is missing after download.")


def infer(input_image_path, input_mask_path):
    """Reconstruct a 3D mesh from an image and its foreground mask.

    Builds the model, downloads the checkpoint if it is not on disk, runs
    inference and exports the predicted mesh as a GLB file.

    Args:
        input_image_path: path of the input image.
        input_mask_path: path of the foreground mask.

    Returns:
        Path of a temporary `.glb` file holding the predicted mesh. The file
        is not deleted automatically.

    Raises:
        RuntimeError: if the checkpoint has to be downloaded and that fails.
    """
    opt_cmd = options.parse_arguments(["--yaml=options/shape.yaml", "--datadir=examples", "--eval.vox_res=128", "--ckpt=weights/shape.ckpt"])
    opt = options.set(opt_cmd=opt_cmd, safe_check=False)
    opt.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # build model
    # NOTE(review): the model is rebuilt and the checkpoint reloaded on every
    # call; consider caching if graph.forward is confirmed to be stateless.
    print("Building model...")
    opt.pretrain.depth = None
    opt.arch.depth.pretrained = None
    module = importlib.import_module("model.compute_graph.graph_shape")
    graph = module.Graph(opt).to(opt.device)

    # download checkpoint
    if not os.path.isfile(opt.ckpt):
        print("Downloading checkpoint...")
        _download_checkpoint(opt.ckpt)

    # load checkpoint
    print("Loading checkpoint...")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(opt.ckpt, map_location=torch.device(opt.device))
    graph.load_state_dict(checkpoint["graph"], strict=True)
    graph.eval()

    # load the data
    print("Loading data...")
    var = prepare_data(opt, input_image_path, input_mask_path)

    # create the save dir (any previous predictions are discarded)
    save_folder = os.path.join(opt.datadir, 'preds')
    if os.path.isdir(save_folder):
        shutil.rmtree(save_folder)
    os.makedirs(save_folder)
    opt.output_path = opt.datadir

    # inference the model and save the results
    print("Inferencing...")
    mesh_pred = infer_sample(opt, var, graph)
    # rotate the mesh upside down
    mesh_pred.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
    # Reserve a unique path, then close the handle before exporting so no
    # file descriptor is leaked per request.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as mesh_file:
        mesh_path = mesh_file.name
    mesh_pred.export(mesh_path, file_type="glb")
    return mesh_path
def infer_wrapper_mask(input_image_path, input_mask_path):
    """Gradio callback for the tab where the user supplies the mask.

    Returns:
        Path of the exported mesh file, as returned by `infer`.
    """
    mesh_path = infer(input_image_path, input_mask_path)
    return mesh_path
def infer_wrapper_nomask(input_image_path):
    """Gradio callback for the tab where the mask is estimated with rembg.

    Args:
        input_image_path: path of the input image.

    Returns:
        (mesh_path, mask_path): path of the exported mesh file and path of the
        temporary PNG holding the estimated mask. Neither file is deleted
        automatically.
    """
    # Close the image file as soon as the segmentation is done.
    with Image.open(input_image_path) as input_image:
        segmented = rembg.remove(input_image)
        # The last band of the segmented result is used as the mask
        # (presumably the alpha channel of an RGBA output; confirm with rembg).
        mask = segmented.split()[-1]
    # Reserve a unique path, then close the handle before saving so no file
    # descriptor is leaked per request.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as mask_file:
        mask_path = mask_file.name
    mask.save(mask_path)
    return infer(input_image_path, mask_path), mask_path
def assert_input_image(input_image):
    """Abort the Gradio event chain when no input image was provided.

    Raises:
        gr.Error: if `input_image` is None.
    """
    if input_image is not None:
        return
    raise gr.Error("No image selected or uploaded!")
def assert_mask_image(input_mask):
    """Abort the Gradio event chain when no foreground mask was provided.

    Raises:
        gr.Error: if `input_mask` is None.
    """
    if input_mask is not None:
        return
    raise gr.Error("No mask selected or uploaded! Please check the box if you do not have the mask.")
def demo_gradio():
    """Build the Gradio UI for the demo.

    The UI has two tabs: one where the user uploads both image and mask, and
    one where only the image is uploaded and the mask is estimated. Each tab
    has a "Reconstruct" button whose click first validates the inputs and then
    runs the matching inference wrapper.

    Returns:
        The constructed `gr.Blocks` object (not yet launched).
    """
    # Base names shared by the example images and their masks.
    example_names = [
        "armchair",
        "bolt",
        "bucket",
        "case",
        "dispenser",
        "hat",
        "teddy_bear",
        "tiger",
        "toy",
        "wedding_cake",
    ]

    with gr.Blocks(analytics_enabled=False) as demo_ui:

        # HEADERS
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ZeroShape: Regression-based Zero-shot Shape Reconstruction')
                # Raw string: the backslashes are Markdown escapes, not Python ones.
                gr.Markdown(r"[\[Arxiv\]](https://arxiv.org/pdf/2312.14198.pdf) | [\[Project\]](https://zixuanh.com/projects/zeroshape.html) | [\[GitHub\]](https://github.com/zxhuang1698/ZeroShape)")
                gr.Markdown("Please switch to the \"Estimated Mask\" tab if you do not have the foreground mask. The demo will try to estimate the mask for you.")

        # with mask
        with gr.Tab("Groundtruth Mask"):
            with gr.Row():
                input_image_tab1 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                mask_tab1 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                output_mesh_tab1 = gr.Model3D(label="Output Mesh")
            with gr.Row():
                submit_tab1 = gr.Button('Reconstruct', elem_id="recon_button_tab1", variant='primary')
            # examples
            with gr.Row():
                examples_tab1 = [
                    [f"examples/images/{name}.png", f"examples/masks/{name}.png"]
                    for name in example_names
                ]
                gr.Examples(
                    examples=examples_tab1,
                    inputs=[input_image_tab1, mask_tab1],
                    outputs=[output_mesh_tab1],
                    fn=infer_wrapper_mask,
                    cache_examples=False,
                )

        # without mask
        with gr.Tab("Estimated Mask"):
            with gr.Row():
                input_image_tab2 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                mask_tab2 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                output_mesh_tab2 = gr.Model3D(label="Output Mesh")
            with gr.Row():
                submit_tab2 = gr.Button('Reconstruct', elem_id="recon_button_tab2", variant='primary')
            # examples
            with gr.Row():
                examples_tab2 = [
                    [f"examples/images/{name}.png"] for name in example_names
                ]
                gr.Examples(
                    examples=examples_tab2,
                    inputs=[input_image_tab2],
                    outputs=[output_mesh_tab2, mask_tab2],
                    fn=infer_wrapper_nomask,
                    cache_examples=False,
                )

        # Tab 1: validate image, then mask, then reconstruct.
        submit_tab1.click(
            fn=assert_input_image,
            inputs=[input_image_tab1],
            queue=False
        ).success(
            fn=assert_mask_image,
            inputs=[mask_tab1],
            queue=False
        ).success(
            fn=infer_wrapper_mask,
            inputs=[input_image_tab1, mask_tab1],
            outputs=[output_mesh_tab1],
        )

        # Tab 2: validate image, then estimate the mask and reconstruct.
        submit_tab2.click(
            fn=assert_input_image,
            inputs=[input_image_tab2],
            queue=False
        ).success(
            fn=infer_wrapper_nomask,
            inputs=[input_image_tab2],
            outputs=[output_mesh_tab2, mask_tab2],
        )

    return demo_ui
if __name__ == "__main__":
    # Build the UI, cap the request queue at 10 pending jobs, then serve it.
    demo_ui = demo_gradio()
    demo_ui.queue(
        max_size=10,
    )
    demo_ui.launch()