depthanyvideo committed on
Commit 0297809
1 Parent(s): 47ac829
Files changed (2)
  1. app.py +176 -143
  2. dav/utils/img_utils.py +27 -20
app.py CHANGED
@@ -1,10 +1,11 @@
- import gradio as gr
- import logging
  import os
  import random
  import tempfile
  import time
- import spaces
  from easydict import EasyDict
  import numpy as np
  import torch
@@ -24,11 +25,11 @@ def seed_all(seed: int = 0):
      torch.cuda.manual_seed_all(seed)


- # Initialize logging
- logging.basicConfig(level=logging.INFO)


- # Load models once to avoid reloading on every inference
  def load_models(model_base, device):
      vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder="vae")
      scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
@@ -50,146 +51,178 @@ def load_models(model_base, device):
      return pipe


- # Load models at startup
- MODEL_BASE = "hhyangcs/depth-any-video"
- DEVICE_TYPE = "cuda"
- DEVICE = torch.device(DEVICE_TYPE)
- pipe = load_models(MODEL_BASE, DEVICE)


  @spaces.GPU(duration=140)
- def depth_any_video(
-     file,
-     denoise_steps=3,
-     num_frames=32,
-     decode_chunk_size=16,
-     num_interp_frames=16,
-     num_overlap_frames=6,
-     max_resolution=1024,
  ):
-     """
-     Perform depth estimation on the uploaded video/image.
-     """
-     with open(file, "rb") as _file:
-         with tempfile.TemporaryDirectory() as tmp_dir:
-             # Save the uploaded file
-             input_path = os.path.join(tmp_dir, file.name)
-             with open(input_path, "wb") as f:
-                 f.write(_file.read())
-
-             # Set up output directory
-             output_dir = os.path.join(tmp_dir, "output")
-             os.makedirs(output_dir, exist_ok=True)
-
-             # Prepare configuration
-             cfg = EasyDict(
-                 {
-                     "model_base": MODEL_BASE,
-                     "data_path": input_path,
-                     "output_dir": output_dir,
-                     "denoise_steps": denoise_steps,
-                     "num_frames": num_frames,
-                     "decode_chunk_size": decode_chunk_size,
-                     "num_interp_frames": num_interp_frames,
-                     "num_overlap_frames": num_overlap_frames,
-                     "max_resolution": max_resolution,
-                     "seed": 666,
-                 }
-             )
-
-             seed_all(cfg.seed)
-
-             file_name = os.path.splitext(os.path.basename(cfg.data_path))[0]
-             is_video = cfg.data_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))
-
-             if is_video:
-                 num_interp_frames = cfg.num_interp_frames
-                 num_overlap_frames = cfg.num_overlap_frames
-                 num_frames = cfg.num_frames
-                 assert num_frames % 2 == 0, "num_frames should be even."
-                 assert (
-                     2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
-                 ), "Invalid frame overlap."
-                 max_frames = (num_interp_frames + 2 - num_overlap_frames) * (
-                     num_frames // 2
-                 )
-                 image, fps = img_utils.read_video(cfg.data_path, max_frames=max_frames)
-             else:
-                 image = img_utils.read_image(cfg.data_path)
-
-             image = img_utils.imresize_max(image, cfg.max_resolution)
-             image = img_utils.imcrop_multi(image)
-             image_tensor = np.ascontiguousarray(
-                 [_img.transpose(2, 0, 1) / 255.0 for _img in image]
-             )
-             image_tensor = torch.from_numpy(image_tensor).to(DEVICE)
-
-             with torch.no_grad(), torch.autocast(
-                 device_type=DEVICE_TYPE, dtype=torch.float16
-             ):
-                 pipe_out = pipe(
-                     image_tensor,
-                     num_frames=cfg.num_frames,
-                     num_overlap_frames=cfg.num_overlap_frames,
-                     num_interp_frames=cfg.num_interp_frames,
-                     decode_chunk_size=cfg.decode_chunk_size,
-                     num_inference_steps=cfg.denoise_steps,
-                 )
-
-             disparity = pipe_out.disparity
-             disparity_colored = pipe_out.disparity_colored
-             image = pipe_out.image
-             # (N, H, 2 * W, 3)
-             merged = np.concatenate(
-                 [
-                     image,
-                     disparity_colored,
-                 ],
-                 axis=2,
-             )
-
-             if is_video:
-                 output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.mp4")
-                 img_utils.write_video(
-                     output_path,
-                     merged,
-                     fps,
-                 )
-                 return output_path
-             else:
-                 output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.png")
-                 img_utils.write_image(
-                     output_path,
-                     merged[0],
-                 )
-                 return output_path
-
-
- # Define Gradio interface
- title = "Depth Any Video with Scalable Synthetic Data"
- description = """
- Upload a video or image to perform depth estimation using the Depth Any Video model.
- Adjust the parameters as needed to control the inference process.
- """
-
- iface = gr.Interface(
-     fn=depth_any_video,
-     inputs=[
-         gr.File(label="Upload Video/Image"),
-         gr.Slider(1, 10, step=1, value=3, label="Denoise Steps"),
-         gr.Slider(16, 64, step=1, value=32, label="Number of Frames"),
-         gr.Slider(8, 32, step=1, value=16, label="Decode Chunk Size"),
-         gr.Slider(8, 32, step=1, value=16, label="Number of Interpolation Frames"),
-         gr.Slider(2, 10, step=1, value=6, label="Number of Overlap Frames"),
-         gr.Slider(512, 2048, step=32, value=1024, label="Maximum Resolution"),
-     ],
-     outputs=gr.Video(label="Depth Enhanced Video/Image"),
-     title=title,
-     description=description,
-     examples=[["demos/arch_2.jpg"], ["demos/wooly_mammoth.mp4"]],
-     allow_flagging="never",
-     analytics_enabled=False,
- )

  if __name__ == "__main__":
-     iface.launch(share=True)
 
 
+ import gc
  import os
+ import spaces
+ import gradio as gr
  import random
  import tempfile
  import time
+
  from easydict import EasyDict
  import numpy as np
  import torch
 
      torch.cuda.manual_seed_all(seed)


+ examples = [
+     ["demos/wooly_mammoth.mp4", 3, 32, 16, 16, 6, 960],
+ ]


  def load_models(model_base, device):
      vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder="vae")
      scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
 
      return pipe


+ model_base = "hhyangcs/depth-any-video"
+ device_type = "cuda"
+ device = torch.device(device_type)
+ pipe = load_models(model_base, device)


  @spaces.GPU(duration=140)
+ def infer_depth(
+     file: str,
+     denoise_steps: int = 3,
+     num_frames: int = 32,
+     decode_chunk_size: int = 16,
+     num_interp_frames: int = 16,
+     num_overlap_frames: int = 6,
+     max_resolution: int = 1024,
+     seed: int = 66,
+     output_dir: str = "./outputs",
  ):
+     seed_all(seed)
+
+     max_frames = (num_interp_frames + 2 - num_overlap_frames) * (num_frames // 2)
+     image, fps = img_utils.read_video(file, max_frames=max_frames)
+
+     image = img_utils.imresize_max(image, max_resolution)
+     image = img_utils.imcrop_multi(image)
+     image_tensor = np.ascontiguousarray(
+         [_img.transpose(2, 0, 1) / 255.0 for _img in image]
+     )
+     image_tensor = torch.from_numpy(image_tensor).to(device)
+     print(f"==> video name: {file}, frames shape: {image_tensor.shape}")
+
+     with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.float16):
+         pipe_out = pipe(
+             image_tensor,
+             num_frames=num_frames,
+             num_overlap_frames=num_overlap_frames,
+             num_interp_frames=num_interp_frames,
+             decode_chunk_size=decode_chunk_size,
+             num_inference_steps=denoise_steps,
+         )
+
+     disparity = pipe_out.disparity
+     disparity_colored = pipe_out.disparity_colored
+     image = pipe_out.image
+     # (N, H, 2 * W, 3)
+     merged = np.concatenate(
+         [
+             image,
+             disparity_colored,
+         ],
+         axis=2,
+     )
+
+     file_name = os.path.splitext(os.path.basename(file))[0]
+     os.makedirs(output_dir, exist_ok=True)
+     output_path = os.path.join(output_dir, f"{file_name}_depth.mp4")
+     img_utils.write_video(
+         output_path,
+         merged,
+         fps,
+     )
+
+     # clear the cache for the next video
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     return output_path
+
+
+ def construct_demo():
+     with gr.Blocks(analytics_enabled=False) as depthanyvideo_iface:
+
+         with gr.Row(equal_height=True):
+             with gr.Column(scale=1):
+                 input_video = gr.Video(label="Input Video")
+
+             with gr.Column(scale=1):
+                 with gr.Row(equal_height=True):
+                     output_video = gr.Video(
+                         label="Output Video Depth",
+                         interactive=False,
+                         autoplay=True,
+                         loop=True,
+                         show_share_button=True,
+                         scale=1,
+                     )
+
+         with gr.Row(equal_height=True):
+             with gr.Column(scale=1):
+                 with gr.Row(equal_height=False):
+                     with gr.Accordion("Advanced Settings", open=False):
+                         denoise_steps = gr.Slider(
+                             label="Denoise Steps",
+                             minimum=1,
+                             maximum=10,
+                             value=3,
+                             step=1,
+                         )
+                         num_frames = gr.Slider(
+                             label="Number of Key Frames",
+                             minimum=16,
+                             maximum=32,
+                             value=24,
+                             step=2,
+                         )
+                         decode_chunk_size = gr.Slider(
+                             label="Decode Chunk Size",
+                             minimum=8,
+                             maximum=32,
+                             value=16,
+                             step=1,
+                         )
+                         num_interp_frames = gr.Slider(
+                             label="Number of Interpolation Frames",
+                             minimum=8,
+                             maximum=32,
+                             value=16,
+                             step=1,
+                         )
+                         num_overlap_frames = gr.Slider(
+                             label="Number of Overlap Frames",
+                             minimum=2,
+                             maximum=10,
+                             value=6,
+                             step=1,
+                         )
+                         max_resolution = gr.Slider(
+                             label="Maximum Resolution",
+                             minimum=512,
+                             maximum=2048,
+                             value=1024,
+                             step=32,
+                         )
+                     generate_btn = gr.Button("Generate")
+             with gr.Column(scale=2):
+                 pass
+
+         gr.Examples(
+             examples=examples,
+             inputs=[
+                 input_video,
+                 denoise_steps,
+                 num_frames,
+                 decode_chunk_size,
+                 num_interp_frames,
+                 num_overlap_frames,
+                 max_resolution,
+             ],
+             outputs=output_video,
+             fn=infer_depth,
+             cache_examples="lazy",
+         )
+
+         generate_btn.click(
+             fn=infer_depth,
+             inputs=[
+                 input_video,
+                 denoise_steps,
+                 num_frames,
+                 decode_chunk_size,
+                 num_interp_frames,
+                 num_overlap_frames,
+                 max_resolution,
+             ],
+             outputs=output_video,
+         )
+
+     return depthanyvideo_iface
+
+
+ demo = construct_demo()

  if __name__ == "__main__":
+     demo.queue()
+     demo.launch(share=True)
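
Note on the new defaults: infer_depth caps how much of the clip is read via max_frames = (num_interp_frames + 2 - num_overlap_frames) * (num_frames // 2). A quick standalone check with the slider defaults above (illustrative only; the commented call at the end assumes the demos/wooly_mammoth.mp4 example bundled with the Space):

num_interp_frames, num_overlap_frames, num_frames = 16, 6, 32
max_frames = (num_interp_frames + 2 - num_overlap_frames) * (num_frames // 2)
print(max_frames)  # (16 + 2 - 6) * (32 // 2) = 192 frames read at most

# Hypothetical local smoke test, bypassing the Gradio UI:
# output_path = infer_depth("demos/wooly_mammoth.mp4")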
dav/utils/img_utils.py CHANGED
@@ -85,26 +85,33 @@ def read_image(image_path):


  def write_video(video_path, frames, fps):
-     tmp_dir = os.path.join(os.path.dirname(video_path), "tmp")
-     os.makedirs(tmp_dir, exist_ok=True)
-     for i, frame in enumerate(frames):
-         write_image(os.path.join(tmp_dir, f"{i:06d}.png"), frame)
-     # it will cause visual compression artifacts
-     ffmpeg_command = [
-         "ffmpeg",
-         "-f",
-         "image2",
-         "-framerate",
-         f"{fps}",
-         "-i",
-         os.path.join(tmp_dir, "%06d.png"),
-         "-b:v",
-         "5626k",
-         "-y",
-         video_path,
-     ]
-     os.system(" ".join(ffmpeg_command))
-     os.system(f"rm -rf {tmp_dir}")


  def write_image(image_path, frame):
 


  def write_video(video_path, frames, fps):
+     # tmp_dir = os.path.join(os.path.dirname(video_path), "tmp")
+     # os.makedirs(tmp_dir, exist_ok=True)
+     # for i, frame in enumerate(frames):
+     #     write_image(os.path.join(tmp_dir, f"{i:06d}.png"), frame)
+     # # it will cause visual compression artifacts
+     # ffmpeg_command = [
+     #     "ffmpeg",
+     #     "-f",
+     #     "image2",
+     #     "-framerate",
+     #     f"{fps}",
+     #     "-i",
+     #     os.path.join(tmp_dir, "%06d.png"),
+     #     "-b:v",
+     #     "5626k",
+     #     "-y",
+     #     video_path,
+     # ]
+     # os.system(" ".join(ffmpeg_command))
+     # os.system(f"rm -rf {tmp_dir}")
+     h, w = frames[0].shape[:2]
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+     out = cv2.VideoWriter(video_path, fourcc, fps, (w, h))
+     for frame in frames:
+         frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+         out.write(frame)
+     out.release()


  def write_image(image_path, frame):
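
For reference, a minimal usage sketch of the rewritten write_video (illustrative, not from the repo; it assumes the module is importable as dav.utils.img_utils and that frames is a sequence of equal-sized uint8 RGB arrays — cv2.VideoWriter expects BGR, which the RGB-to-BGR conversion above handles):

import numpy as np
from dav.utils import img_utils

# 30 black RGB frames, 64x64, written out as an mp4 at 24 fps
frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(30)]
img_utils.write_video("demo_depth.mp4", frames, fps=24)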