ghostsInTheMachine committed on
Commit
17eeb1a
1 Parent(s): 48a87fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -86
app.py CHANGED
@@ -2,114 +2,227 @@ import gradio as gr
2
  import torch
3
  import os
4
  import tempfile
5
- import imageio
6
- import numpy as np
7
  import shutil
 
 
 
8
  from PIL import Image
9
  from concurrent.futures import ThreadPoolExecutor
10
- import ffmpeg
11
- from infer import lotus, lotus_video # Import the depth model inference function
12
 
13
- # Set device
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- def process_frame(path_input, seed):
17
- """
18
- Process a single frame through the depth model.
19
- Returns the original and depth-processed images.
20
- """
21
- name_base, name_ext = os.path.splitext(os.path.basename(path_input))
22
 
23
- # Process the frame with the model
24
- output_g, output_d = lotus(path_input, 'depth', seed, device)
 
25
 
26
- # Save generated and depth maps to temporary paths
27
- g_save_path = os.path.join(tempfile.gettempdir(), f"{name_base}_g{name_ext}")
28
- d_save_path = os.path.join(tempfile.gettempdir(), f"{name_base}_d{name_ext}")
29
 
30
- output_g.save(g_save_path)
31
- output_d.save(d_save_path)
32
-
33
- return [path_input, g_save_path], [path_input, d_save_path]
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- def process_video_live(path_input, seed):
37
- """
38
- Process video frame-by-frame, showing each processed frame live in the preview and compile the final video.
39
- """
40
- temp_dir = tempfile.mkdtemp()
41
-
42
- # Extract video frames
43
- video = imageio.get_reader(path_input)
44
- fps = video.get_meta_data()['fps']
45
- frames = [frame for frame in video]
46
- total_frames = len(frames)
47
-
48
- print(f"Processing {total_frames} frames at {fps} FPS...")
49
-
50
- processed_frames_g = []
51
- processed_frames_d = []
52
-
53
- for i, frame in enumerate(frames):
54
- frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
55
- Image.fromarray(frame).save(frame_path)
56
 
57
- # Process the frame using the lotus model
58
- output_g_paths, output_d_paths = process_frame(frame_path, seed)
59
 
60
- # Append processed frames for final video compilation
61
- processed_frames_g.append(imageio.imread(output_g_paths[1]))
62
- processed_frames_d.append(imageio.imread(output_d_paths[1]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- # Update the live preview
65
- yield output_g_paths[1], output_d_paths[1], f"Processing frame {i+1}/{total_frames}..."
66
-
67
- # Compile final videos
68
- g_video_path = os.path.join(temp_dir, "output_g.mp4")
69
- d_video_path = os.path.join(temp_dir, "output_d.mp4")
70
-
71
- imageio.mimsave(g_video_path, processed_frames_g, fps=fps)
72
- imageio.mimsave(d_video_path, processed_frames_d, fps=fps)
73
-
74
- # Clean up temporary directory
75
- if os.path.exists(temp_dir):
76
  try:
77
- shutil.rmtree(temp_dir)
78
- except Exception as e:
79
- print(f"Error cleaning up temp directory: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- return g_video_path, d_video_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Gradio Interface
85
- with gr.Blocks() as demo:
86
- gr.Markdown("# Video Depth Estimation: Live Frame Processing and Video Compilation")
 
 
 
 
87
 
88
  with gr.Row():
89
  with gr.Column():
90
- video_input = gr.Video(
91
- label="Upload Video",
92
- interactive=True,
93
- show_label=True
94
- )
95
- seed_input = gr.Number(
96
- label="Seed",
97
- value=0,
98
- interactive=True
99
- )
100
- process_btn = gr.Button("Process Video")
101
 
102
  with gr.Column():
103
- live_preview_g = gr.Image(label="Live Preview (Generative)", show_label=True)
104
- live_preview_d = gr.Image(label="Live Preview (Discriminative)", show_label=True)
105
- status_text = gr.Textbox(label="Status", interactive=False)
106
- final_g_video = gr.Video(label="Final Generative Video")
107
- final_d_video = gr.Video(label="Final Discriminative Video")
108
-
109
- process_btn.click(
110
- fn=process_video_live,
111
- inputs=[video_input, seed_input],
112
- outputs=[live_preview_g, live_preview_d, status_text, final_g_video, final_d_video]
113
- )
114
 
115
- demo.launch(debug=True)
 
 
2
  import torch
3
  import os
4
  import tempfile
 
 
5
  import shutil
6
+ import time
7
+ import ffmpeg
8
+ import numpy as np
9
  from PIL import Image
10
  from concurrent.futures import ThreadPoolExecutor
11
+ import moviepy.editor as mp
12
+ from infer import lotus # Import the depth model inference function
13
 
14
+ # Use CUDA when available (e.g. an L40S GPU); otherwise fall back to CPU
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+ # Limit video resolution and frame rate before running inference
18
def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
    """Open a video and cap its resolution and frame rate.

    Args:
        video_path: Path of the video file to open.
        target_fps: Maximum frame rate of the returned clip. Clips that are
            already at or below this rate keep their own rate, so no
            duplicate frames are created.
        max_resolution: ``(width, height)`` bounding box. Larger clips are
            scaled down to fit inside it with their aspect ratio preserved.

    Returns:
        The moviepy clip, resized and/or re-timed as needed. The caller is
        responsible for closing it.
    """
    video = mp.VideoFileClip(video_path)

    width, height = video.size
    max_width, max_height = max_resolution
    if width > max_width or height > max_height:
        # Use one scale factor for both axes; resizing straight to
        # max_resolution would stretch any clip with a different aspect ratio.
        scale = min(max_width / width, max_height / height)
        # Round down to even numbers: this file later encodes with
        # libx264 / yuv420p, which needs even frame dimensions.
        new_width = max(2, int(width * scale) // 2 * 2)
        new_height = max(2, int(height * scale) // 2 * 2)
        video = video.resize(newsize=(new_width, new_height))

    # Only ever lower the frame rate; raising it would just duplicate frames
    # that each have to go through the depth model.
    if video.fps is None or video.fps > target_fps:
        video = video.set_fps(target_fps)

    return video
 
 
 
30
 
31
def process_frame(frame, seed=0):
    """Run one video frame through the depth model.

    Args:
        frame: Frame as an array accepted by ``PIL.Image.fromarray``.
        seed: Seed forwarded to ``lotus``.

    Returns:
        The depth output converted to a numpy array, or ``None`` if anything
        goes wrong (the error is printed, never raised).
    """
    tmp_path = None
    try:
        # Convert frame to PIL Image
        image = Image.fromarray(frame)

        # lotus is called with a file path, so the frame goes through a
        # temporary PNG. Reserve the name, close the handle, then write:
        # saving to a still-open NamedTemporaryFile fails on some platforms.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            tmp_path = tmp.name
        image.save(tmp_path)

        # Process through the depth model (lotus); only the second output
        # is used here.
        _, output_d = lotus(tmp_path, 'depth', seed, device)

        # Convert depth output to numpy array
        return np.array(output_d)

    except Exception as e:
        print(f"Error processing frame: {e}")
        return None
    finally:
        # Always remove the temp file, including when the model call raised;
        # otherwise every failed frame leaves a PNG behind in the temp dir.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as e:
                print(f"Error removing temp file {tmp_path}: {e}")
54
 
55
# NOTE(review): no `import spaces` is visible among this file's imports; the
# decorator below raises NameError at import time unless the Hugging Face
# `spaces` package is imported at the top of the file.
@spaces.GPU
def process_video(video_path, fps=0, seed=0, max_workers=32):
    """Run depth estimation over a video, streaming progress as it goes.

    This is a generator. While frames are processed it yields
    ``(preview_frame, None, None, status)`` each time a successfully
    processed frame's 1-based index is a multiple of 10. When finished it
    yields ``(None, zip_path, output_video_path, status)``. On failure it
    yields ``(None, None, None, error_message)`` instead of raising.

    Args:
        video_path: Path of the input video. Results are written to an
            ``output`` directory next to it.
        fps: Frame rate used to sample frames and to encode the result;
            ``0`` means "use the frame rate of the preprocessed clip".
        seed: Seed forwarded to the depth model for every frame.
        max_workers: Thread-pool size; also passed to ffmpeg as its thread
            count.

    Yields:
        4-tuples ``(preview, zip_path, video_path, status_text)`` as
        described above. ``video_path`` is ``None`` in the final tuple when
        the MP4 could not be created.
    """
    temp_dir = None
    video = None
    try:
        start_time = time.time()

        # The thread pool and ffmpeg's thread option expect a positive
        # integer; coerce defensively in case a float is passed in.
        max_workers = max(1, int(max_workers))

        # Preprocess the video
        video = preprocess_video(video_path)

        # Fall back to the clip's own frame rate if none was requested
        if fps == 0:
            fps = video.fps

        frames = list(video.iter_frames(fps=fps))
        total_frames = len(frames)

        print(f"Processing {total_frames} frames at {fps} FPS...")

        # Create temporary directory for frame sequence
        temp_dir = tempfile.mkdtemp()
        frames_dir = os.path.join(temp_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)

        # Frames are submitted to the pool in batches of this size
        batch_size = 50
        # Output files are numbered by this counter, not by the source frame
        # index: ffmpeg's image-sequence input stops at the first missing
        # number, so a single failed frame must not leave a gap.
        saved_frames = 0

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for batch_start in range(0, total_frames, batch_size):
                batch_end = min(batch_start + batch_size, total_frames)
                futures = [
                    executor.submit(process_frame, frames[k], seed)
                    for k in range(batch_start, batch_end)
                ]
                for offset, future in enumerate(futures):
                    frame_number = batch_start + offset + 1
                    try:
                        result = future.result()
                        if result is None:
                            # process_frame already reported the error
                            continue
                        frame_path = os.path.join(
                            frames_dir, f"frame_{saved_frames:06d}.png"
                        )
                        Image.fromarray(result).save(frame_path)
                    except Exception as e:
                        print(f"Error processing frame {frame_number}: {e}")
                        continue
                    saved_frames += 1

                    # Update preview (only every 10th frame to avoid clutter).
                    # Only the current result is needed, so processed frames
                    # are not accumulated in memory.
                    if frame_number % 10 == 0:
                        elapsed_time = time.time() - start_time
                        yield result, None, None, f"Processed {frame_number}/{total_frames} frames... Elapsed: {elapsed_time:.2f}s"

        print("Creating output files...")
        # Create output directory
        output_dir = os.path.join(os.path.dirname(video_path), "output")
        os.makedirs(output_dir, exist_ok=True)
        timestamp = int(time.time())

        # Create ZIP of frame sequence (make_archive appends ".zip" itself)
        zip_base = os.path.join(output_dir, f"depth_frames_{timestamp}")
        zip_path = zip_base + ".zip"
        shutil.make_archive(zip_base, 'zip', frames_dir)

        # Create MP4 video. A separate name is used so the `video_path`
        # argument is not overwritten.
        output_video_path = os.path.join(output_dir, f"depth_video_{timestamp}.mp4")

        if saved_frames == 0:
            print("No frames were processed successfully; skipping video creation.")
            output_video_path = None
        else:
            try:
                # FFmpeg settings for high-quality MP4
                stream = ffmpeg.input(
                    os.path.join(frames_dir, 'frame_%06d.png'),
                    pattern_type='sequence',
                    framerate=fps
                )

                stream = ffmpeg.output(
                    stream,
                    output_video_path,
                    vcodec='libx264',
                    pix_fmt='yuv420p',
                    crf=17,  # High quality
                    threads=max_workers
                )

                ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
                print("MP4 video created successfully!")

            except ffmpeg.Error as e:
                print(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
                output_video_path = None

        print("Processing complete!")
        yield None, zip_path, output_video_path, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"

    except Exception as e:
        print(f"Error: {e}")
        yield None, None, None, f"Error processing video: {e}"
    finally:
        # Close the clip so its underlying reader is released
        if video is not None:
            try:
                video.close()
            except Exception as e:
                print(f"Error closing video: {e}")
        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Error cleaning up temp directory: {e}")
154
+
155
def process_wrapper(video, fps=0, seed=0, max_workers=32):
    """Gradio event handler: validate the input and stream `process_video`.

    This is a generator; every tuple yielded by ``process_video`` is passed
    through unchanged so the UI can update while processing runs.

    Args:
        video: Path of the uploaded video, or ``None`` if nothing was uploaded.
        fps: Forwarded to ``process_video``.
        seed: Forwarded to ``process_video``.
        max_workers: Forwarded to ``process_video``.

    Raises:
        gr.Error: If no video was supplied, or if ``process_video`` raises.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        # `yield from` forwards each update without keeping old ones alive
        # (the previous list held every preview frame) and propagates an
        # early close to process_video so its cleanup runs promptly.
        yield from process_video(video, fps, seed, max_workers)
    except Exception as e:
        raise gr.Error(f"Error processing video: {str(e)}") from e
166
 
167
+ # Custom CSS for styling
168
+ custom_css = """
169
+ .title-container {
170
+ text-align: center;
171
+ padding: 10px 0;
172
+ }
173
+
174
+ #title {
175
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
176
+ font-size: 36px;
177
+ font-weight: bold;
178
+ color: #000000;
179
+ padding: 10px;
180
+ border-radius: 10px;
181
+ display: inline-block;
182
+ background: linear-gradient(
183
+ 135deg,
184
+ #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
185
+ #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
186
+ );
187
+ background-size: 400% 400%;
188
+ animation: gradient-animation 15s ease infinite;
189
+ }
190
+
191
+ @keyframes gradient-animation {
192
+ 0% { background-position: 0% 50%; }
193
+ 50% { background-position: 100% 50%; }
194
+ 100% { background-position: 0% 50%; }
195
+ }
196
+ """
197
 
198
  # Gradio Interface
199
with gr.Blocks(css=custom_css) as demo:
    # Page title; styled by the `.title-container` / `#title` rules in custom_css.
    gr.HTML('''
    <div class="title-container">
        <div id="title">Video Depth Estimation</div>
    </div>
    ''')

    with gr.Row():
        # Left column: the input video and the processing parameters.
        with gr.Column():
            video_input = gr.Video(
                label="Upload Video",
                interactive=True,
                show_label=True,
            )
            # An FPS of 0 makes process_video use the clip's own frame rate.
            fps_slider = gr.Slider(
                label="Output FPS",
                minimum=0,
                maximum=60,
                step=1,
                value=0,
            )
            seed_slider = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=999999999,
                step=1,
                value=0,
            )
            max_workers_slider = gr.Slider(
                label="Max Workers",
                minimum=1,
                maximum=32,
                step=1,
                value=32,
            )
            btn = gr.Button("Process Video", elem_id="submit-button")

        # Right column: live preview, downloadable results and status line.
        with gr.Column():
            preview_image = gr.Image(label="Live Preview", show_label=True)
            output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
            output_video = gr.File(label="Download Video (MP4)")
            time_textbox = gr.Textbox(label="Status", interactive=False)

    # The order of `outputs` matches the 4-tuples yielded by process_wrapper.
    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, seed_slider, max_workers_slider],
        outputs=[preview_image, output_frames_zip, output_video, time_textbox],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch(debug=True)