"""Gradio demo for video depth estimation.

Runs the depth pipelines from infer.py over every frame of an uploaded video and
returns the resulting depth maps as a live preview, a ZIP of frames, and an MP4.
"""
import gradio as gr
import torch
import os
import tempfile
import shutil
import time
import ffmpeg
import numpy as np
from PIL import Image
import moviepy.editor as mp
from infer import lotus, load_models
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
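
# The two depth pipelines are loaded once at startup and shared by every request;
# `lotus` and `load_models` are provided by the accompanying infer.py.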
task_name = 'depth'
pipe_g, pipe_d = load_models(task_name, device)


def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
    """Preprocess the video: cap its resolution and optionally adjust its frame rate."""
    video = mp.VideoFileClip(video_path)

    # Downscale (preserving aspect ratio) so both dimensions fit within max_resolution.
    scale = min(max_resolution[0] / video.size[0], max_resolution[1] / video.size[1])
    if scale < 1:
        video = video.resize(scale)

    # target_fps == 0 keeps the clip's original frame rate.
    if target_fps > 0:
        video = video.set_fps(target_fps)

    return video


def process_frames_batch(frames_batch, seed=0):
    """Run depth estimation on a batch of frames and return the depth maps."""
    try:
        images_batch = [Image.fromarray(frame).convert('RGB') for frame in frames_batch]
        depth_maps = lotus(images_batch, 'depth', seed, device, pipe_g, pipe_d)
        return depth_maps
    except Exception as e:
        logger.error(f"Error processing batch: {e}")
        # Return placeholders so the caller keeps frame indices aligned.
        return [None] * len(frames_batch)
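

# process_video is a generator: every yield is a 4-tuple of
# (preview image, frames ZIP path, MP4 path, status text), matching the four
# outputs wired to btn.click in the UI below. Intermediate yields report progress.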
def process_video(video_path, fps=0, seed=0, batch_size=16):
    """Process video frames in batches and generate depth maps."""
    temp_dir = tempfile.mkdtemp()
    try:
        start_time = time.time()

        video = preprocess_video(video_path, target_fps=fps)

        # fps == 0 means "keep the original frame rate".
        if fps == 0:
            fps = video.fps

        # All frames are decoded into memory up front.
        frames = list(video.iter_frames(fps=fps))
        total_frames = len(frames)

        logger.info(f"Processing {total_frames} frames at {fps} FPS...")

        frames_dir = os.path.join(temp_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)

        for i in range(0, total_frames, batch_size):
            frames_batch = frames[i:i+batch_size]
            depth_maps = process_frames_batch(frames_batch, seed)

            for j, depth_map in enumerate(depth_maps):
                frame_index = i + j
                if depth_map is not None:
                    frame_path = os.path.join(frames_dir, f"frame_{frame_index:06d}.png")
                    depth_map.save(frame_path)

                    # Emit a live preview and progress update roughly every 10% of frames.
                    if frame_index % max(1, total_frames // 10) == 0:
                        elapsed_time = time.time() - start_time
                        progress = (frame_index / total_frames) * 100
                        yield depth_map, None, None, f"Processed {frame_index}/{total_frames} frames... ({progress:.2f}%) Elapsed: {elapsed_time:.2f}s"
                else:
                    logger.error(f"Error processing frame {frame_index}")

        logger.info("Creating output files...")

        # Bundle the depth frames into a ZIP archive for download
        # (make_archive appends ".zip", hence the stripped extension).
        zip_filename = f"depth_frames_{int(time.time())}.zip"
        zip_path = os.path.join(temp_dir, zip_filename)
        shutil.make_archive(zip_path[:-4], 'zip', frames_dir)

        # Encode the numbered PNG frames into an H.264 MP4.
        video_filename = f"depth_video_{int(time.time())}.mp4"
        output_video_path = os.path.join(temp_dir, video_filename)

        try:
            (
                ffmpeg
                .input(os.path.join(frames_dir, 'frame_%06d.png'), pattern_type='sequence', framerate=fps)
                .output(output_video_path, vcodec='libx264', pix_fmt='yuv420p', crf=17)
                .run(overwrite_output=True, quiet=True)
            )
            logger.info("MP4 video created successfully!")
        except ffmpeg.Error as e:
            logger.error(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
            output_video_path = None

        total_time = time.time() - start_time
        logger.info("Processing complete!")

        yield None, zip_path, output_video_path, f"Processing complete! Total time: {total_time:.2f} seconds"

    except Exception as e:
        logger.error(f"Error: {e}")
        yield None, None, None, f"Error processing video: {e}"


def process_wrapper(video, fps=0, seed=0, batch_size=16):
    """Validate the input and relay streaming updates from process_video to the UI."""
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        # Re-yield every intermediate update so the live preview and status
        # messages reach the interface as they are produced.
        for output in process_video(video, fps, seed, batch_size):
            yield output
    except Exception as e:
        raise gr.Error(f"Error processing video: {str(e)}")
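

# Animated pastel-gradient styling for the page title banner.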
custom_css = """
.title-container {
    text-align: center;
    padding: 10px 0;
}

#title {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    font-size: 36px;
    font-weight: bold;
    color: #000000;
    padding: 10px;
    border-radius: 10px;
    display: inline-block;
    background: linear-gradient(
        135deg,
        #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
        #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
    );
    background-size: 400% 400%;
    animation: gradient-animation 15s ease infinite;
}

@keyframes gradient-animation {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
"""


with gr.Blocks(css=custom_css) as demo:
    gr.HTML('''
    <div class="title-container">
        <div id="title">Video Depth Estimation</div>
    </div>
    ''')

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Upload Video", interactive=True)
            fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
            seed_slider = gr.Number(value=0, label="Seed")
            batch_size_slider = gr.Slider(minimum=1, maximum=64, step=1, value=16, label="Batch Size")
            btn = gr.Button("Process Video")

        with gr.Column():
            preview_image = gr.Image(label="Live Preview")
            output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
            output_video = gr.File(label="Download Video (MP4)")
            time_textbox = gr.Textbox(label="Status", interactive=False)

    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, seed_slider, batch_size_slider],
        outputs=[preview_image, output_frames_zip, output_video, time_textbox]
    )
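
# Queueing lets the generator-based handler stream its intermediate yields
# (live preview and status text) to the browser as processing progresses.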
demo.queue()

if __name__ == "__main__":
    demo.launch(debug=True)