ghostsInTheMachine committed
Commit c12e34c
1 Parent(s): 7bb1989

Update app.py

Files changed (1)
  1. app.py +143 -33
app.py CHANGED
@@ -1,23 +1,72 @@
 import gradio as gr
 import torch
-import os
+import spaces
+import moviepy.editor as mp
+from PIL import Image
+import numpy as np
 import tempfile
-import shutil
 import time
+import os
+import shutil
 import ffmpeg
-import numpy as np
-from PIL import Image
 from concurrent.futures import ThreadPoolExecutor
-import moviepy.editor as mp
+from gradio.themes.base import Base
+from gradio.themes.utils import colors, fonts
 from infer import lotus # Import the depth model inference function
-import spaces
 
-# Set device to use the L40s GPU
+# Custom Theme Definition
+class WhiteTheme(Base):
+    def __init__(
+        self,
+        *,
+        primary_hue: colors.Color | str = colors.orange,
+        font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
+            fonts.GoogleFont("Inter"),
+            "ui-sans-serif",
+            "system-ui",
+            "sans-serif",
+        ),
+        font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
+            fonts.GoogleFont("Inter"),
+            "ui-monospace",
+            "system-ui",
+            "monospace",
+        )
+    ):
+        super().__init__(
+            primary_hue=primary_hue,
+            font=font,
+            font_mono=font_mono,
+        )
+
+        self.set(
+            background_fill_primary="*primary_50",
+            background_fill_secondary="white",
+            border_color_primary="*primary_300",
+            body_background_fill="white",
+            body_background_fill_dark="white",
+            block_background_fill="white",
+            block_background_fill_dark="white",
+            panel_background_fill="white",
+            panel_background_fill_dark="white",
+            body_text_color="black",
+            body_text_color_dark="black",
+            block_label_text_color="black",
+            block_label_text_color_dark="black",
+            block_border_color="white",
+            panel_border_color="white",
+            input_border_color="lightgray",
+            input_background_fill="white",
+            input_background_fill_dark="white",
+            shadow_drop="none"
+        )
+
+# Set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Add the preprocess_video function to limit video resolution and frame rate
-def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
-    """Preprocess the video to resize and reduce its frame rate."""
+def preprocess_video(video_path, target_fps=24, max_resolution=(640, 360)):
+    """Preprocess the video to reduce its resolution and frame rate."""
     video = mp.VideoFileClip(video_path)
 
     # Resize video if it's larger than the target resolution
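For orientation, the new WhiteTheme above can be tried on its own before it is wired into the app in a later hunk (theme=WhiteTheme()). The snippet below is a minimal sketch and not part of the commit; the trimmed-down theme class and the placeholder Blocks content are assumptions for illustration only.

import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts

class MinimalWhiteTheme(Base):
    # Hypothetical, trimmed-down variant of WhiteTheme for a quick visual check
    def __init__(self):
        super().__init__(
            primary_hue=colors.orange,
            font=(fonts.GoogleFont("Inter"), "ui-sans-serif", "sans-serif"),
        )
        self.set(
            body_background_fill="white",
            body_text_color="black",
            input_border_color="lightgray",
        )

with gr.Blocks(theme=MinimalWhiteTheme()) as theme_preview:
    gr.Markdown("Theme preview")
    gr.Textbox(label="Sample input")

if __name__ == "__main__":
    theme_preview.launch()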
@@ -29,8 +78,11 @@ def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
 
     return video
 
-def process_frame(frame, seed=0):
-    """Process a single frame through the depth model and return depth map."""
+def process_frame(frame, seed=0, start_time=None):
+    """
+    Process a single frame through the depth model.
+    Returns the discriminative depth map.
+    """
     try:
         # Convert frame to PIL Image
         image = Image.fromarray(frame)
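The body of preprocess_video between the hunks above is unchanged by this commit and is not shown; with the new (640, 360) cap it presumably scales the clip down and limits the frame rate. A rough sketch of that kind of moviepy logic under the same signature (an assumption, not the committed code):

import moviepy.editor as mp

def preprocess_video_sketch(video_path, target_fps=24, max_resolution=(640, 360)):
    """Hypothetical stand-in for the elided body: cap resolution and frame rate."""
    video = mp.VideoFileClip(video_path)

    # Downscale, preserving aspect ratio, if the clip exceeds the target resolution
    max_w, max_h = max_resolution
    if video.w > max_w or video.h > max_h:
        video = video.resize(min(max_w / video.w, max_h / video.h))

    # Cap the frame rate
    if target_fps and video.fps and video.fps > target_fps:
        video = video.set_fps(target_fps)

    return video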
@@ -39,7 +91,7 @@ def process_frame(frame, seed=0):
         with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
             image.save(tmp.name)
 
-        # Process through the depth model (lotus)
+        # Process through lotus model
         _, output_d = lotus(tmp.name, 'depth', seed, device)
 
         # Clean up temp file
@@ -54,10 +106,14 @@ def process_frame(frame, seed=0):
         return None
 
 @spaces.GPU
-def process_video(video_path, fps=0, seed=0, max_workers=32):
-    """Process video, batch frames, and use L40s GPU to generate depth maps."""
+def process_video(video_path, fps=0, seed=0, max_workers=2):
+    """
+    Process video to create depth map sequence and video.
+    Maintains original resolution and framerate if fps=0.
+    """
     temp_dir = None
     try:
+        # Initialize start_time here for use in process_frame
         start_time = time.time()
 
         # Preprocess the video
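Worth noting for the hunks that follow: process_video is a generator (it yields (preview, zip, video, status) tuples), and Gradio streams each yield to the outputs, which is how the live preview and status text can update during processing. A self-contained illustration of that pattern (a toy example, not the committed app):

import time
import gradio as gr

def count_up(n):
    # A generator event handler: every yield pushes a partial update to the outputs.
    for i in range(1, int(n) + 1):
        time.sleep(0.5)
        yield f"step {i}/{int(n)}"

with gr.Blocks() as tiny_demo:
    steps = gr.Number(value=5, label="Steps")
    status = gr.Textbox(label="Status")
    gr.Button("Run").click(fn=count_up, inputs=steps, outputs=status)

if __name__ == "__main__":
    tiny_demo.launch()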
@@ -77,13 +133,11 @@ def process_video(video_path, fps=0, seed=0, max_workers=32):
         frames_dir = os.path.join(temp_dir, "frames")
         os.makedirs(frames_dir, exist_ok=True)
 
-        # Process frames in larger batches (based on GPU VRAM)
-        batch_size = 50 # Increased batch size to fully utilize the GPU's capabilities
+        # Process frames in batches of 10
         processed_frames = []
-
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            for i in range(0, total_frames, batch_size):
-                futures = [executor.submit(process_frame, frames[j], seed) for j in range(i, min(i + batch_size, total_frames))]
+            for i in range(0, total_frames, 10): # Process 10 frames at a time
+                futures = [executor.submit(process_frame, frames[j], seed, start_time) for j in range(i, min(i + 10, total_frames))]
                 for j, future in enumerate(futures):
                     try:
                         result = future.result()
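The loop above submits work in fixed batches of 10 and reads the futures back in submission order, so frames come back in order even though they run in parallel. The general shape of that pattern, stripped of the model-specific parts (a sketch, not the committed code):

import time
from concurrent.futures import ThreadPoolExecutor

def run_in_batches(items, worker, batch_size=10, max_workers=2):
    """Submit a small batch, collect results in submission order, report progress."""
    start = time.time()
    done = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in range(0, len(items), batch_size):
            futures = [executor.submit(worker, item) for item in items[i:i + batch_size]]
            for future in futures:
                result = future.result()  # blocks until this item is finished
                done += 1
                yield result, f"{done}/{len(items)} done, {time.time() - start:.2f}s elapsed"

# Example: squares of 0..24, processed two at a time
for value, status in run_in_batches(list(range(25)), lambda x: x * x):
    print(status)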
@@ -95,10 +149,12 @@ def process_video(video_path, fps=0, seed=0, max_workers=32):
                             # Collect processed frame for preview
                             processed_frames.append(result)
 
-                            # Update preview (only showing every 10th frame to avoid clutter)
-                            if (i + j + 1) % 10 == 0:
-                                elapsed_time = time.time() - start_time
-                                yield processed_frames[-1], None, None, f"Processed {i+j+1}/{total_frames} frames... Elapsed: {elapsed_time:.2f}s"
+                            # Update preview
+                            elapsed_time = time.time() - start_time
+                            yield processed_frames[-1], None, None, f"Processing frame {i+j+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"
+
+                            if (i + j + 1) % 10 == 0:
+                                print(f"Processed {i + j + 1}/{total_frames} frames")
                     except Exception as e:
                         print(f"Error processing frame {i + j + 1}: {e}")
 
@@ -113,6 +169,7 @@ def process_video(video_path, fps=0, seed=0, max_workers=32):
         shutil.make_archive(zip_path[:-4], 'zip', frames_dir)
 
         # Create MP4 video
+        print("Creating MP4 video...")
         video_filename = f"depth_video_{int(time.time())}.mp4"
         video_path = os.path.join(output_dir, video_filename)
 
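The actual ffmpeg call that turns the saved frames into an MP4 sits outside these hunks, so its exact arguments are not visible here. For orientation only, assembling a numbered PNG sequence with ffmpeg-python generally looks like the sketch below; the frame naming pattern and encoder settings are assumptions, not the committed values.

import os
import ffmpeg

def frames_to_mp4_sketch(frames_dir, output_path, fps=24):
    """Hypothetical frames-to-video step: encode frame_000001.png, ... into an MP4."""
    (
        ffmpeg
        .input(os.path.join(frames_dir, "frame_%06d.png"), framerate=fps)
        .output(output_path, vcodec="libx264", pix_fmt="yuv420p", crf=18)
        .overwrite_output()
        .run(quiet=True)
    )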
 
@@ -153,7 +210,7 @@ def process_video(video_path, fps=0, seed=0, max_workers=32):
         except Exception as e:
             print(f"Error cleaning up temp directory: {e}")
 
-def process_wrapper(video, fps=0, seed=0, max_workers=32):
+def process_wrapper(video, fps=0, seed=0, max_workers=6):
     if video is None:
         raise gr.Error("Please upload a video.")
     try:
@@ -197,7 +254,7 @@ custom_css = """
 """
 
 # Gradio Interface
-with gr.Blocks(css=custom_css) as demo:
+with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
     gr.HTML('''
     <div class="title-container">
         <div id="title">Video Depth Estimation</div>
@@ -206,10 +263,36 @@ with gr.Blocks(css=custom_css) as demo:
 
     with gr.Row():
         with gr.Column():
-            video_input = gr.Video(label="Upload Video", interactive=True, show_label=True)
-            fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS")
-            seed_slider = gr.Slider(minimum=0, maximum=999999999, step=1, value=0, label="Seed")
-            max_workers_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Max Workers")
+            video_input = gr.Video(
+                label="Upload Video",
+                interactive=True,
+                show_label=True,
+                height=360,
+                width=640
+            )
+            with gr.Row():
+                fps_slider = gr.Slider(
+                    minimum=0,
+                    maximum=60,
+                    step=1,
+                    value=0,
+                    label="Output FPS (0 will inherit the original fps value)",
+                )
+                seed_slider = gr.Slider(
+                    minimum=0,
+                    maximum=999999999,
+                    step=1,
+                    value=0,
+                    label="Seed",
+                )
+                max_workers_slider = gr.Slider(
+                    minimum=1,
+                    maximum=32,
+                    step=1,
+                    value=6,
+                    label="Max Workers",
+                    info="Determines how many frames to process in parallel"
+                )
             btn = gr.Button("Process Video", elem_id="submit-button")
 
         with gr.Column():
@@ -218,12 +301,39 @@ with gr.Blocks(css=custom_css) as demo:
             output_video = gr.File(label="Download Video (MP4)")
             time_textbox = gr.Textbox(label="Status", interactive=False)
 
-    btn.click(fn=process_wrapper
+            gr.Markdown("""
+            ### Output Information
+            - High-quality MP4 video output
+            - Original resolution and framerate are maintained
+            - Frame sequence provided for maximum compatibility
+            """)
 
-              , inputs=[video_input, fps_slider, seed_slider, max_workers_slider],
-              outputs=[preview_image, output_frames_zip, output_video, time_textbox])
+    btn.click(
+        fn=process_wrapper,
+        inputs=[video_input, fps_slider, seed_slider, max_workers_slider],
+        outputs=[preview_image, output_frames_zip, output_video, time_textbox]
+    )
 
 demo.queue()
 
+api = gr.Interface(
+    fn=process_wrapper,
+    inputs=[
+        gr.Video(label="Upload Video"),
+        gr.Number(label="FPS", value=0),
+        gr.Number(label="Seed", value=0),
+        gr.Number(label="Max Workers", value=6)
+    ],
+    outputs=[
+        gr.Image(label="Preview"),
+        gr.File(label="Frame Sequence"),
+        gr.File(label="Video"),
+        gr.Textbox(label="Status")
+    ],
+    title="Video Depth Estimation API",
+    description="Generate depth maps from videos",
+    api_name="/process_video"
+)
+
 if __name__ == "__main__":
-    demo.launch(debug=True)
+    demo.launch(debug=True, show_error=True, share=False, server_name="0.0.0.0", server_port=7860)
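The new gr.Interface block advertises an endpoint name of /process_video. Assuming the Space actually exposes an endpoint under that name, a remote call with the gradio_client package would look roughly like the sketch below; the Space id and the input file are placeholders, and depending on the gradio_client version the video argument may need to be wrapped with handle_file().

from gradio_client import Client

# Placeholder Space id; replace with the real owner/space-name.
client = Client("your-username/video-depth-estimation")

preview, frames_zip, video_file, status = client.predict(
    "input.mp4",  # path to a local video
    0,            # fps (0 = keep the source frame rate)
    0,            # seed
    6,            # max workers
    api_name="/process_video",
)
print(status)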
 