ghostsInTheMachine
commited on
Commit
•
ec60737
1
Parent(s):
e4ce1e7
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import spaces # Import the spaces module for ZeroGPU
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
import os
|
@@ -9,8 +8,9 @@ import ffmpeg
|
|
9 |
import numpy as np
|
10 |
from PIL import Image
|
11 |
import moviepy.editor as mp
|
12 |
-
from infer import lotus, load_models
|
13 |
import logging
|
|
|
14 |
|
15 |
# Set up logging
|
16 |
logging.basicConfig(level=logging.INFO)
|
@@ -19,9 +19,15 @@ logger = logging.getLogger(__name__)
|
|
19 |
# Device will be set inside GPU-decorated functions
|
20 |
device = 'cuda' # Use 'cuda' as placeholder
|
21 |
|
22 |
-
# Load
|
23 |
task_name = 'depth'
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
# Preprocess the video to adjust frame rate
|
27 |
def preprocess_video(video_path, target_fps=24):
|
@@ -58,7 +64,7 @@ def process_frame(frame, seed=0, target_size=(512, 512)):
|
|
58 |
input_image = resize_and_pad(image, target_size)
|
59 |
|
60 |
# Run inference
|
61 |
-
depth_map = lotus(input_image, 'depth', seed, device
|
62 |
|
63 |
# Crop the output depth map back to original image size
|
64 |
width, height = image.size
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import os
|
|
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
10 |
import moviepy.editor as mp
|
11 |
+
from infer import lotus, load_models, pipe_g, pipe_d # Import the global models
|
12 |
import logging
|
13 |
+
import spaces # Import the spaces module for ZeroGPU
|
14 |
|
15 |
# Set up logging
|
16 |
logging.basicConfig(level=logging.INFO)
|
|
|
19 |
# Device will be set inside GPU-decorated functions
|
20 |
device = 'cuda' # Use 'cuda' as placeholder
|
21 |
|
22 |
+
# Load models once inside a GPU context
|
23 |
task_name = 'depth'
|
24 |
+
|
25 |
+
@spaces.GPU
|
26 |
+
def initialize_models():
|
27 |
+
load_models(task_name, device)
|
28 |
+
|
29 |
+
# Call the function to load models
|
30 |
+
initialize_models()
|
31 |
|
32 |
# Preprocess the video to adjust frame rate
|
33 |
def preprocess_video(video_path, target_fps=24):
|
|
|
64 |
input_image = resize_and_pad(image, target_size)
|
65 |
|
66 |
# Run inference
|
67 |
+
depth_map = lotus(input_image, 'depth', seed, device)
|
68 |
|
69 |
# Crop the output depth map back to original image size
|
70 |
width, height = image.size
|