Spaces:
Runtime error
Runtime error
# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# -------------------------------------------------------------------------- | |
# If you find this code useful, we kindly ask you to cite our paper in your work. | |
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation | |
# More information about the method can be found at https://marigoldmonodepth.github.io | |
# -------------------------------------------------------------------------- | |
import functools | |
import os | |
import gradio as gr | |
import numpy as np | |
import torch as torch | |
from PIL import Image | |
import spaces | |
import diffusers | |
from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline | |
from stablenormal.pipeline_stablenormal import StableNormalPipeline | |
from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler | |
from data_utils import HWC3, resize_image | |
import sys | |
import cv2 | |
sys.path.append('./geowizard') | |
from models.geowizard_pipeline import DepthNormalEstimationPipeline | |
class Geowizard(object):
    """GeoWizard normal estimator (``lemonaddie/Geowizard``, loaded in float16).

    Also the base class of ``Marigold`` and ``StableNormal``, which reuse its
    delegating helpers (``cuda``, ``cpu``, ``float``, ``to``, ``eval``,
    ``train``) and ``__repr__`` while loading their own ``self.model``.
    """

    def __init__(self):
        self.model = DepthNormalEstimationPipeline.from_pretrained(
            "lemonaddie/Geowizard", torch_dtype=torch.float16
        )

    # Thin delegation to the wrapped pipeline; each returns ``self`` for chaining.
    def cuda(self):
        self.model.cuda()
        return self

    def cpu(self):
        self.model.cpu()
        return self

    def float(self):
        self.model.float()
        return self

    def to(self, device):
        self.model.to(device)
        return self

    def eval(self):
        self.model.eval()
        return self

    def train(self):
        self.model.train()
        return self

    @staticmethod
    def _normals_to_uint8(normal):
        """Map normal components from [-1, 1] to [0, 255] and cast to uint8.

        Values are clipped first: a float outside [0, 255] has no well-defined
        uint8 conversion (it typically wraps around), which would turn a
        slightly out-of-range component into a wildly wrong colour.
        """
        return np.clip((normal + 1) / 2 * 255, 0, 255).astype(np.uint8)

    def __call__(self, img, image_resolution=768, *, denoising_steps=10,
                 ensemble_size=1, domain="indoor"):
        """Estimate a normal map for one image.

        Args:
            img: image array in BGR channel order (converted to RGB here).
            image_resolution: forwarded to the pipeline as ``processing_res``.
            denoising_steps: forwarded to the pipeline (default 10).
            ensemble_size: forwarded to the pipeline (default 1).
            domain: forwarded to the pipeline (default ``"indoor"``).

        Returns:
            ``pipe_out.normal_np`` mapped from [-1, 1] to uint8 [0, 255].
        """
        pipe_out = self.model(
            Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)),
            denoising_steps=denoising_steps,
            ensemble_size=ensemble_size,
            processing_res=image_resolution,
            match_input_res=True,
            domain=domain,
            color_map="Spectral",
            show_progress_bar=False,
        )
        return self._normals_to_uint8(pipe_out.normal_np)

    def __repr__(self):
        return f"model: \n{self.model}"
class Marigold(Geowizard):
    """Marigold normal estimator (``prs-eth/marigold-normals-v0-1``) via diffusers.

    Inherits the delegating helpers (``cuda``, ``to``, ...) and ``__repr__``
    from ``Geowizard``. ``__init__`` intentionally does not call
    ``super().__init__`` so the GeoWizard weights are not loaded.
    """

    def __init__(self):
        self.model = diffusers.MarigoldNormalsPipeline.from_pretrained(
            "prs-eth/marigold-normals-v0-1", torch_dtype=torch.float16
        )

    def __call__(self, img, image_resolution=768):
        """Estimate a normal map for one BGR image and return it as uint8.

        ``image_resolution`` is accepted for interface parity with the other
        wrappers but is NOT forwarded; the pipeline uses its own default.
        """
        pipe_out = self.model(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
        # NOTE(review): presumably a numpy batch of normals in [-1, 1]; take the
        # first (only) image — confirm against the diffusers output type.
        pred_normal = pipe_out.prediction[0]
        # Negate the X component, presumably to match the other models'
        # normal-map convention — confirm.
        pred_normal[..., 0] = -pred_normal[..., 0]
        # Clip before the cast: out-of-range floats do not convert to uint8 reliably.
        return np.clip((pred_normal + 1) / 2 * 255, 0, 255).astype(np.uint8)
class StableNormal(Geowizard):
    """Two-stage StableNormal estimator (YOSO-normal + StableNormal pipelines).

    Inherits the delegating helpers and ``__repr__`` from ``Geowizard``.
    ``__init__`` intentionally does not call ``super().__init__`` so the
    GeoWizard weights are not loaded.
    """

    def __init__(self):
        x_start_pipeline = YOSONormalsPipeline.from_pretrained(
            'Stable-X/yoso-normal-v0-2', trust_remote_code=True,
            variant="fp16", torch_dtype=torch.float16,
        )
        self.model = StableNormalPipeline.from_pretrained(
            'Stable-X/stable-normal-v0-1', trust_remote_code=True,
            variant="fp16", torch_dtype=torch.float16,
            scheduler=HEURI_DDIMScheduler(
                prediction_type='sample',
                beta_start=0.00085, beta_end=0.0120,
                beta_schedule="scaled_linear",
            ),
        )
        # two stage concat: the YOSO pipeline is attached as ``x_start_pipeline``
        # (presumably providing the initial estimate the second stage refines).
        self.model.x_start_pipeline = x_start_pipeline
        self.model.x_start_pipeline.to('cuda', torch.float16)
        self.model.prior.to('cuda', torch.float16)

    def __call__(self, img, image_resolution=768):
        """Estimate a normal map for one BGR image and return it as uint8.

        ``image_resolution`` is accepted for interface parity but not used.
        """
        pipe_out = self.model(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
        pred_normal = pipe_out.prediction[0]
        # Clip before the cast: out-of-range floats do not convert to uint8 reliably.
        return np.clip((pred_normal + 1) / 2 * 255, 0, 255).astype(np.uint8)

    def to(self, device):
        """Move the pipeline to ``device`` in float16 and return ``self``.

        Returning ``self`` matches the other wrappers' ``to`` (it used to
        return ``None``).

        NOTE(review): ``x_start_pipeline`` is attached after ``from_pretrained``
        and both it and ``prior`` are moved to CUDA explicitly in ``__init__``;
        whether ``self.model.to`` also moves them depends on
        StableNormalPipeline — confirm.
        """
        self.model.to(device, torch.float16)
        return self
class DSINE(object):
    """DSINE normal estimator loaded through ``torch.hub`` (``hugoycj/DSINE-hub``)."""

    def __init__(self):
        self.model = torch.hub.load(
            "hugoycj/DSINE-hub", "DSINE",
            local_file_path='./models/dsine.pt', trust_repo=True,
        )

    # Thin delegation to the wrapped model; each returns ``self`` for chaining.
    def cuda(self):
        self.model.cuda()
        return self

    def cpu(self):
        # Added for parity with the other wrappers' ``cpu()``.
        self.model.cpu()
        return self

    def float(self):
        self.model.float()
        return self

    def to(self, device):
        self.model.to(device)
        return self

    def eval(self):
        self.model.eval()
        return self

    def train(self):
        self.model.train()
        return self

    def __call__(self, img, image_resolution=768):
        """Estimate a normal map with ``infer_cv2`` and return it as uint8.

        Args:
            img: image array handed straight to ``infer_cv2`` (presumably BGR,
                as read by OpenCV — confirm against DSINE-hub).
            image_resolution: accepted for interface parity; not used.
        """
        with torch.no_grad():
            pred_normal = self.model.infer_cv2(img)[0]  # (3, H, W)
        # ``detach`` guards ``.numpy()`` against a grad-tracking tensor.
        pred_normal = pred_normal.detach().cpu().numpy().transpose(1, 2, 0)  # (H, W, 3)
        # Clip before the cast: out-of-range floats do not convert to uint8 reliably.
        pred_normal = np.clip((pred_normal + 1) / 2 * 255, 0, 255)
        return pred_normal.astype(np.uint8)

    def __repr__(self):
        return f"model: \n{self.model}"
def process(
    pipe_list,
    path_input,
):
    """Run every normal estimator on one image, yielding results as each finishes.

    Args:
        pipe_list: model wrappers, each called as ``pipe(img, 768)`` and expected
            to return a uint8 normal map. Each one is moved to CUDA for its run
            and back to CPU afterwards (best effort).
        path_input: path of the input image, or ``None`` when no image was given.

    Yields:
        A list with one entry per output slot (at least 4): a ``PIL.Image`` for
        each model that has finished so far, ``None`` for the rest.

    Raises:
        ValueError: if ``path_input`` cannot be decoded by ``cv2.imread``.
    """
    import logging

    num_slots = 4  # number of result panels in the UI
    results = []

    def _padded():
        # Fill the not-yet-computed slots with None so every yield has the same length.
        return results + [None] * (num_slots - len(results))

    def _move(pipe, device):
        # Best effort: a wrapper may lack ``.to`` or the device may be unavailable.
        # That must not abort the comparison, but it should not be silent either.
        try:
            pipe.to(device)
        except Exception as err:  # deliberately broad, see comment above
            logging.getLogger(__name__).warning(
                "could not move %s to %s: %s", type(pipe).__name__, device, err
            )

    if path_input is None or not pipe_list:
        yield _padded()
        return

    # The input is identical for every model, so load and resize it only once.
    raw = cv2.imread(path_input)
    if raw is None:
        raise ValueError(f"Could not read image: {path_input!r}")
    raw_input_image = HWC3(raw)
    ori_H, ori_W, _ = raw_input_image.shape
    img = resize_image(raw_input_image, 768)

    for pipe in pipe_list:
        _move(pipe, 'cuda')
        try:
            # Hand each model its own copy so an in-place edit cannot leak into the next.
            pipe_out = pipe(img.copy(), 768)
        finally:
            # Free the GPU for the next model even if this one raised.
            _move(pipe, 'cpu')
        pred_normal = cv2.resize(pipe_out, (ori_W, ori_H))
        results.append(Image.fromarray(pred_normal))
        yield _padded()
def run_demo_server(pipe):
    """Build the Gradio comparison UI and launch it on 0.0.0.0:7860.

    Args:
        pipe: list of model wrappers bound into ``process``; their results fill
            the DSINE, Marigold, GeoWizard and StableNormal panels in list order.
    """
    # Each call requests a GPU through the ``spaces`` package, with duration=120 s.
    process_pipe = spaces.GPU(functools.partial(process, pipe), duration=120)
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="Normal Estimation Comparison",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                    height=256,
                )
            with gr.Column():
                submit_btn = gr.Button(value="Compute normal", variant="primary")
                clear_btn = gr.Button(value="Clear")

        with gr.Row():
            with gr.Column():
                DSINE_output_slider = gr.Image(
                    label="DSINE",
                    type="filepath",
                )
            with gr.Column():
                marigold_output_slider = gr.Image(
                    label="Marigold",
                    type="filepath",
                )
        with gr.Row():
            with gr.Column():
                geowizard_output_slider = gr.Image(
                    label="Geowizard",
                    type="filepath",
                )
            with gr.Column():
                Ours_slider = gr.Image(
                    label="StableNormal",
                    type="filepath",
                )

        # Panel order must match the order of ``pipe``.
        outputs = [
            DSINE_output_slider,
            marigold_output_slider,
            geowizard_output_slider,
            Ours_slider,
        ]

        submit_btn.click(
            fn=process_pipe,
            inputs=input_image,
            outputs=outputs,
            concurrency_limit=1,
        )

        gr.Examples(
            fn=process_pipe,
            examples=sorted([
                os.path.join("files", "images", name)
                for name in os.listdir(os.path.join("files", "images"))
            ]),
            inputs=input_image,
            outputs=outputs,
            cache_examples=False,
        )

        clear_targets = [submit_btn, input_image, *outputs]

        def clear_fn():
            """Re-enable the submit button and clear the input and all results.

            Returns exactly one value per entry of ``clear_targets``, in order.
            The previous version returned 9 values for 6 targets and sent a
            Button update to the input Image.
            """
            return [
                gr.Button(interactive=True),
                gr.Image(value=None, interactive=True),
                *([None] * len(outputs)),
            ]

        clear_btn.click(
            fn=clear_fn,
            inputs=[],
            outputs=clear_targets,
        )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
def main():
    """Load all four normal estimators and launch the comparison demo."""
    # Order must match the result panels in run_demo_server:
    # DSINE, Marigold, GeoWizard, StableNormal.
    pipes = [DSINE(), Marigold(), Geowizard(), StableNormal()]
    run_demo_server(pipes)


if __name__ == "__main__":
    main()