ghostsInTheMachine commited on
Commit
ec60737
1 Parent(s): e4ce1e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces # Import the spaces module for ZeroGPU
2
  import gradio as gr
3
  import torch
4
  import os
@@ -9,8 +8,9 @@ import ffmpeg
9
  import numpy as np
10
  from PIL import Image
11
  import moviepy.editor as mp
12
- from infer import lotus, load_models
13
  import logging
 
14
 
15
  # Set up logging
16
  logging.basicConfig(level=logging.INFO)
@@ -19,9 +19,15 @@ logger = logging.getLogger(__name__)
19
  # Device will be set inside GPU-decorated functions
20
  device = 'cuda' # Use 'cuda' as placeholder
21
 
22
- # Load model names and dtype
23
  task_name = 'depth'
24
- model_g, model_d, dtype = load_models(task_name, device)
 
 
 
 
 
 
25
 
26
  # Preprocess the video to adjust frame rate
27
  def preprocess_video(video_path, target_fps=24):
@@ -58,7 +64,7 @@ def process_frame(frame, seed=0, target_size=(512, 512)):
58
  input_image = resize_and_pad(image, target_size)
59
 
60
  # Run inference
61
- depth_map = lotus(input_image, 'depth', seed, device, model_g, model_d, dtype)
62
 
63
  # Crop the output depth map back to original image size
64
  width, height = image.size
 
 
1
  import gradio as gr
2
  import torch
3
  import os
 
8
  import numpy as np
9
  from PIL import Image
10
  import moviepy.editor as mp
11
+ from infer import lotus, load_models, pipe_g, pipe_d # Import the global models
12
  import logging
13
+ import spaces # Import the spaces module for ZeroGPU
14
 
15
  # Set up logging
16
  logging.basicConfig(level=logging.INFO)
 
19
  # Device will be set inside GPU-decorated functions
20
  device = 'cuda' # Use 'cuda' as placeholder
21
 
22
+ # Load models once inside a GPU context
23
  task_name = 'depth'
24
+
25
+ @spaces.GPU
26
+ def initialize_models():
27
+ load_models(task_name, device)
28
+
29
+ # Call the function to load models
30
+ initialize_models()
31
 
32
  # Preprocess the video to adjust frame rate
33
  def preprocess_video(video_path, target_fps=24):
 
64
  input_image = resize_and_pad(image, target_size)
65
 
66
  # Run inference
67
+ depth_map = lotus(input_image, 'depth', seed, device)
68
 
69
  # Crop the output depth map back to original image size
70
  width, height = image.size