ghostsInTheMachine
committed on
Commit
•
e4ce1e7
1
Parent(s):
3508f0c
Update infer.py
Browse files
infer.py
CHANGED
@@ -10,7 +10,13 @@ import spaces # Import the spaces module for ZeroGPU
|
|
10 |
|
11 |
check_min_version('0.28.0.dev0')
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
def load_models(task_name, device):
|
|
|
14 |
if task_name == 'depth':
|
15 |
model_g = 'jingheya/lotus-depth-g-v1-0'
|
16 |
model_d = 'jingheya/lotus-depth-d-v1-1'
|
@@ -19,10 +25,20 @@ def load_models(task_name, device):
|
|
19 |
model_d = 'jingheya/lotus-normal-d-v1-0'
|
20 |
|
21 |
dtype = torch.float16
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
@spaces.GPU
|
26 |
def infer_pipe(pipe, image, task_name, seed, device):
|
27 |
if seed is None:
|
28 |
generator = None
|
@@ -67,22 +83,7 @@ def infer_pipe(pipe, image, task_name, seed, device):
|
|
67 |
|
68 |
return output_color
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
# Load models inside the GPU-decorated function
|
73 |
-
pipe_g = LotusGPipeline.from_pretrained(
|
74 |
-
model_g,
|
75 |
-
torch_dtype=dtype,
|
76 |
-
)
|
77 |
-
pipe_d = LotusDPipeline.from_pretrained(
|
78 |
-
model_d,
|
79 |
-
torch_dtype=dtype,
|
80 |
-
)
|
81 |
-
pipe_g.to(device)
|
82 |
-
pipe_d.to(device)
|
83 |
-
pipe_g.set_progress_bar_config(disable=True)
|
84 |
-
pipe_d.set_progress_bar_config(disable=True)
|
85 |
-
logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
|
86 |
-
|
87 |
output_d = infer_pipe(pipe_d, image, task_name, seed, device)
|
88 |
return output_d # Only returning depth outputs for this application
|
|
|
10 |
|
11 |
check_min_version('0.28.0.dev0')
|
12 |
|
13 |
+
# Global variables to store the models
|
14 |
+
pipe_g = None
|
15 |
+
pipe_d = None
|
16 |
+
|
17 |
+
@spaces.GPU
|
18 |
def load_models(task_name, device):
|
19 |
+
global pipe_g, pipe_d # Use global variables to store the models
|
20 |
if task_name == 'depth':
|
21 |
model_g = 'jingheya/lotus-depth-g-v1-0'
|
22 |
model_d = 'jingheya/lotus-depth-d-v1-1'
|
|
|
25 |
model_d = 'jingheya/lotus-normal-d-v1-0'
|
26 |
|
27 |
dtype = torch.float16
|
28 |
+
pipe_g = LotusGPipeline.from_pretrained(
|
29 |
+
model_g,
|
30 |
+
torch_dtype=dtype,
|
31 |
+
)
|
32 |
+
pipe_d = LotusDPipeline.from_pretrained(
|
33 |
+
model_d,
|
34 |
+
torch_dtype=dtype,
|
35 |
+
)
|
36 |
+
pipe_g.to(device)
|
37 |
+
pipe_d.to(device)
|
38 |
+
pipe_g.set_progress_bar_config(disable=True)
|
39 |
+
pipe_d.set_progress_bar_config(disable=True)
|
40 |
+
logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
|
41 |
|
|
|
42 |
def infer_pipe(pipe, image, task_name, seed, device):
|
43 |
if seed is None:
|
44 |
generator = None
|
|
|
83 |
|
84 |
return output_color
|
85 |
|
86 |
+
def lotus(image, task_name, seed, device):
|
87 |
+
global pipe_g, pipe_d # Access the global models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
output_d = infer_pipe(pipe_d, image, task_name, seed, device)
|
89 |
return output_d # Only returning depth outputs for this application
|