|
from gradio_imageslider import ImageSlider |
|
import functools |
|
import os |
|
import tempfile |
|
import diffusers |
|
import gradio as gr |
|
import imageio as imageio |
|
import numpy as np |
|
import spaces |
|
import torch as torch |
|
from PIL import Image |
|
from tqdm import tqdm |
|
from pathlib import Path |
|
import gradio |
|
from gradio.utils import get_cache_folder |
|
from infer import lotus, lotus_video |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
def infer(path_input, seed=0): |
|
name_base, name_ext = os.path.splitext(os.path.basename(path_input)) |
|
output_g, output_d = lotus(path_input, 'depth', seed, device) |
|
if not os.path.exists("files/output"): |
|
os.makedirs("files/output") |
|
g_save_path = os.path.join("files/output", f"{name_base}_g{name_ext}") |
|
d_save_path = os.path.join("files/output", f"{name_base}_d{name_ext}") |
|
output_g.save(g_save_path) |
|
output_d.save(d_save_path) |
|
return [path_input, g_save_path], [path_input, d_save_path] |
|
|
|
def infer_video(path_input, seed=0): |
|
frames_g, frames_d = lotus_video(path_input, 'depth', seed, device) |
|
if not os.path.exists("files/output"): |
|
os.makedirs("files/output") |
|
name_base, _ = os.path.splitext(os.path.basename(path_input)) |
|
g_save_path = os.path.join("files/output", f"{name_base}_g.mp4") |
|
d_save_path = os.path.join("files/output", f"{name_base}_d.mp4") |
|
imageio.mimsave(g_save_path, frames_g) |
|
imageio.mimsave(d_save_path, frames_d) |
|
return [g_save_path, d_save_path] |
|
|
|
def run_demo_server(): |
|
gradio_theme = gr.themes.Default() |
|
|
|
with gr.Blocks( |
|
theme=gradio_theme, |
|
title="LOTUS (Depth)", |
|
css=""" |
|
#download { |
|
height: 118px; |
|
} |
|
.slider .inner { |
|
width: 5px; |
|
background: #FFF; |
|
} |
|
.viewport { |
|
aspect-ratio: 4/3; |
|
} |
|
.tabs button.selected { |
|
font-size: 20px !important; |
|
color: crimson !important; |
|
} |
|
h1 { |
|
text-align: center; |
|
display: block; |
|
} |
|
h2 { |
|
text-align: center; |
|
display: block; |
|
} |
|
h3 { |
|
text-align: center; |
|
display: block; |
|
} |
|
.md_feedback li { |
|
margin-bottom: 0px !important; |
|
} |
|
""", |
|
head=""" |
|
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script> |
|
<script> |
|
window.dataLayer = window.dataLayer || []; |
|
function gtag() {dataLayer.push(arguments);} |
|
gtag('js', new Date()); |
|
gtag('config', 'G-1FWSVCGZTG'); |
|
</script> |
|
""", |
|
) as demo: |
|
gr.Markdown( |
|
""" |
|
# LOTUS: Diffusion-based Visual Foundation Model for High-quality Dense Prediction |
|
<p align="center"> |
|
<a title="Page" href="https://lotus3d.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://img.shields.io/badge/Project-Website-pink?logo=googlechrome&logoColor=white"> |
|
</a> |
|
<a title="arXiv" href="https://arxiv.org/abs/2409.18124" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=white"> |
|
</a> |
|
<a title="Github" href="https://github.com/EnVision-Research/Lotus" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://img.shields.io/github/stars/EnVision-Research/Lotus?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> |
|
</a> |
|
<a title="Social" href="https://x.com/haodongli00/status/1839524569058582884" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social"> |
|
</a> |
|
""" |
|
) |
|
with gr.Tabs(elem_classes=["tabs"]): |
|
with gr.Tab("IMAGE"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
image_input = gr.Image( |
|
label="Input Image", |
|
type="filepath", |
|
) |
|
seed = gr.Number( |
|
label="Seed (only for Generative mode)", |
|
minimum=0, |
|
maximum=999999999, |
|
) |
|
with gr.Row(): |
|
image_submit_btn = gr.Button( |
|
value="Predict Depth!", variant="primary" |
|
) |
|
image_reset_btn = gr.Button(value="Reset") |
|
with gr.Column(): |
|
image_output_g = ImageSlider( |
|
label="Output (Generative)", |
|
type="filepath", |
|
interactive=False, |
|
elem_classes="slider", |
|
position=0.25, |
|
) |
|
with gr.Row(): |
|
image_output_d = ImageSlider( |
|
label="Output (Discriminative)", |
|
type="filepath", |
|
interactive=False, |
|
elem_classes="slider", |
|
position=0.25, |
|
) |
|
|
|
gr.Examples( |
|
fn=infer, |
|
examples=sorted([ |
|
os.path.join("files", "images", name) |
|
for name in os.listdir(os.path.join("files", "images")) |
|
]), |
|
inputs=[image_input], |
|
outputs=[image_output_g, image_output_d], |
|
cache_examples=True, |
|
) |
|
|
|
with gr.Tab("VIDEO"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
input_video = gr.Video( |
|
label="Input Video", |
|
autoplay=True, |
|
loop=True, |
|
) |
|
seed = gr.Number( |
|
label="Seed (only for Generative mode)", |
|
minimum=0, |
|
maximum=999999999, |
|
) |
|
with gr.Row(): |
|
video_submit_btn = gr.Button( |
|
value="Compute Depth!", variant="primary" |
|
) |
|
video_reset_btn = gr.Button(value="Reset") |
|
with gr.Column(): |
|
video_output_g = gr.Video( |
|
label="Output (Generative)", |
|
interactive=False, |
|
autoplay=True, |
|
loop=True, |
|
show_share_button=True, |
|
) |
|
with gr.Row(): |
|
video_output_d = gr.Video( |
|
label="Output (Discriminative)", |
|
interactive=False, |
|
autoplay=True, |
|
loop=True, |
|
show_share_button=True, |
|
) |
|
|
|
gr.Examples( |
|
fn=infer_video, |
|
examples=sorted([ |
|
os.path.join("files", "videos", name) |
|
for name in os.listdir(os.path.join("files", "videos")) |
|
]), |
|
inputs=[input_video], |
|
outputs=[video_output_g, video_output_d], |
|
cache_examples=True, |
|
) |
|
|
|
|
|
image_submit_btn.click( |
|
fn=infer, |
|
inputs=[image_input, seed], |
|
outputs=[image_output_g, image_output_d], |
|
concurrency_limit=1, |
|
) |
|
image_reset_btn.click( |
|
fn=lambda: ( |
|
None, |
|
None, |
|
None, |
|
), |
|
inputs=[], |
|
outputs=[image_output_g, image_output_d], |
|
queue=False, |
|
) |
|
|
|
|
|
video_submit_btn.click( |
|
fn=infer_video, |
|
inputs=[input_video, seed], |
|
outputs=[video_output_g, video_output_d], |
|
queue=True, |
|
) |
|
|
|
|
|
demo.queue( |
|
api_open=False, |
|
).launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
) |
|
|
|
def main(): |
|
os.system("pip freeze") |
|
run_demo_server() |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|