Spaces:
Paused
Paused
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import open3d as o3d | |
import os | |
from PIL import Image | |
import tempfile | |
import torch | |
from transformers import GLPNImageProcessor, GLPNForDepthEstimation | |
# Hugging Face checkpoint used for monocular depth estimation.
_GLPN_CHECKPOINT = "vinvino02/glpn-nyu"
# Filled on first use so the processor and model are loaded once per process
# instead of on every prediction request.
_GLPN_CACHE = {}


def _load_glpn():
    """Return the cached ``(processor, model)`` pair, loading it on first use."""
    if "pair" not in _GLPN_CACHE:
        processor = GLPNImageProcessor.from_pretrained(_GLPN_CHECKPOINT)
        model = GLPNForDepthEstimation.from_pretrained(_GLPN_CHECKPOINT)
        _GLPN_CACHE["pair"] = (processor, model)
    return _GLPN_CACHE["pair"]


def predict_depth(image):
    """Estimate a depth map for ``image`` with the GLPN model.

    The image is resized so that its height is at most 480 and both sides are
    multiples of 32, run through the model, and then a 16-pixel border is
    removed from both the resized image and the predicted depth.

    Args:
        image: Input picture; must expose ``width``, ``height``, ``resize`` and
            ``crop`` (a PIL image in this application).

    Returns:
        A ``(image, depth)`` tuple: the resized and border-cropped image, and
        the matching 2-D NumPy array of predicted depth multiplied by 1000.
    """
    feature_extractor, model = _load_glpn()

    pad = 16
    multiple = 32
    # Smallest side that still leaves pixels once ``pad`` is cropped from each
    # edge. Without this clamp, small inputs round down to a zero-size resize
    # or an empty depth map.
    min_side = multiple + 2 * pad

    # Height: cap at 480, round down to a multiple of 32.
    new_height = min(image.height, 480)
    new_height -= new_height % multiple
    new_height = max(new_height, min_side)

    # Width: keep the aspect ratio, round to the nearest multiple of 32.
    new_width = int(new_height * image.width / image.height)
    remainder = new_width % multiple
    if remainder < multiple // 2:
        new_width -= remainder
    else:
        new_width += multiple - remainder
    new_width = max(new_width, min_side)

    image = image.resize((new_width, new_height))

    # Prepare the image for the model.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_depth = outputs.predicted_depth

    # NOTE(review): the 1000x factor presumably converts metres to
    # millimetres - confirm against the checkpoint's output units.
    output = predicted_depth.squeeze().cpu().numpy() * 1000.0

    # Remove the borders from both outputs so they stay aligned.
    output = output[pad:-pad, pad:-pad]
    image = image.crop((pad, pad, image.width - pad, image.height - pad))
    return image, output
# Directory (relative to the working directory) where generated meshes are stored.
_MESH_OUTPUT_DIR = "public"


def generate_mesh(image, depth_image, quality):
    """Reconstruct a triangle mesh from a color image and its depth map.

    The color/depth pair is turned into an RGB-D image, back-projected into a
    point cloud with a pinhole camera, cleaned of statistical outliers, and
    meshed with Poisson surface reconstruction. The mesh is saved as an
    ``.obj`` file inside the ``public`` directory.

    Args:
        image: Color image exposing ``size`` and convertible with
            ``np.array`` (a PIL image in this application).
        depth_image: Depth array accepted by ``o3d.geometry.Image``; it must
            have the same width and height as ``image``.
        quality: Value passed as the ``depth`` argument of the Poisson
            reconstruction.

    Returns:
        Path of the written ``.obj`` file, including its directory, so the
        caller can open the file that was actually created.

    Raises:
        RuntimeError: If Open3D reports that the mesh could not be written.
    """
    width, height = image.size
    image = np.array(image)

    # Create the RGB-D image, keeping the original colors.
    depth_o3d = o3d.geometry.Image(depth_image)
    image_o3d = o3d.geometry.Image(image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False)

    # Camera settings: fixed focal length of 500 on both axes, principal
    # point at the image centre.
    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic()
    camera_intrinsic.set_intrinsics(width, height, 500, 500, width / 2, height / 2)

    # Create the point cloud.
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, camera_intrinsic)

    # Outlier removal: only the indices of the kept points are needed.
    _, inlier_indices = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=20.0)
    pcd = pcd.select_by_index(inlier_indices)

    # Estimate normals and orient them along -Z.
    pcd.estimate_normals()
    pcd.orient_normals_to_align_with_direction(orientation_reference=(0., 0., -1.))

    # Surface reconstruction; the call returns a tuple and only its first
    # element, the mesh, is used.
    mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=quality, n_threads=1)[0]

    # Rotate the mesh by pi around the X and Y axes about the origin.
    rotation = mesh.get_rotation_matrix_from_xyz((np.pi, np.pi, 0))
    mesh.rotate(rotation, center=(0, 0, 0))

    # Save the mesh under a unique name. mkstemp reserves the name through a
    # public API; the descriptor is closed right away because only the path
    # is handed to Open3D.
    os.makedirs(_MESH_OUTPUT_DIR, exist_ok=True)
    fd, mesh_path = tempfile.mkstemp(suffix='.obj', dir=_MESH_OUTPUT_DIR)
    os.close(fd)
    if not o3d.io.write_triangle_mesh(mesh_path, mesh):
        os.remove(mesh_path)
        raise RuntimeError(f"Failed to write mesh to {mesh_path}")
    return mesh_path
def predict(image, quality):
    """Run the full image -> depth -> mesh pipeline for the GUI.

    Args:
        image: Input picture handed to ``predict_depth``.
        quality: Mesh quality chosen by the user; 5 is added to it before it
            is passed to ``generate_mesh``.

    Returns:
        A ``(depth_preview, mesh_path)`` tuple: the depth map rendered with
        the ``plasma`` colormap as a PIL image, and the value returned by
        ``generate_mesh``.
    """
    cropped, depth_map = predict_depth(image)

    # Normalise the depth to 0..255 relative to its maximum.
    peak = np.max(depth_map)
    depth_u8 = (depth_map * 255 / peak).astype('uint8')

    mesh_path = generate_mesh(cropped, depth_u8, quality + 5)

    # Colorize the 8-bit depth for display.
    plasma = plt.get_cmap('plasma')
    colored = (plasma(depth_u8) * 255).astype('uint8')
    return Image.fromarray(colored), mesh_path
if __name__ == '__main__':
    # Build and launch the Gradio GUI.
    page_title = 'Image2Mesh'
    page_description = (
        'Demo based on my <a href="https://towardsdatascience.com/generate-a-3d-mesh-from-an-image-with-python'
        '-12210c73e5cc">article</a>. This demo predicts the depth of an image and then generates the 3D mesh. '
        'Choosing a higher quality increases the time to generate the mesh. You can download the mesh by '
        'clicking the top-right button on the 3D viewer. '
    )

    # One example row per file in the examples directory, each paired with
    # the default quality of 3.
    # Example image source:
    # N. Silberman, D. Hoiem, P. Kohli, and Rob Fergus,
    # Indoor Segmentation and Support Inference from RGBD Images (2012)
    example_rows = [[f'examples/{name}', 3] for name in sorted(os.listdir('examples'))]

    # Input widgets: the picture and the mesh quality slider (1 to 5).
    image_input = gr.Image(type='pil', label='Input Image')
    quality_input = gr.Slider(1, 5, step=1, value=3, label='Mesh quality')

    # Output widgets: the colorized depth map and the 3D viewer.
    depth_output = gr.Image(label='Depth')
    mesh_output = gr.Model3D(label='3D Model', clear_color=[0.0, 0.0, 0.0, 0.0])

    demo = gr.Interface(
        fn=predict,
        inputs=[image_input, quality_input],
        outputs=[depth_output, mesh_output],
        examples=example_rows,
        allow_flagging='never',
        cache_examples=False,
        title=page_title,
        description=page_description,
    )
    demo.launch()