ghostsInTheMachine
committed on
Commit
•
c71b96e
1
Parent(s):
afe7cc3
Update infer.py
Browse files
infer.py
CHANGED
@@ -1,71 +1,18 @@
|
|
1 |
-
# from utils.args import parse_args
|
2 |
import logging
|
3 |
-
import os
|
4 |
-
import argparse
|
5 |
-
from pathlib import Path
|
6 |
-
from PIL import Image
|
7 |
-
|
8 |
-
import numpy as np
|
9 |
import torch
|
10 |
-
|
|
|
11 |
from diffusers.utils import check_min_version
|
12 |
-
|
13 |
from pipeline import LotusGPipeline, LotusDPipeline
|
14 |
from utils.image_utils import colorize_depth_map
|
15 |
-
from utils.seed_all import seed_all
|
16 |
-
|
17 |
from contextlib import nullcontext
|
18 |
-
import cv2
|
19 |
|
20 |
check_min_version('0.28.0.dev0')
|
21 |
|
22 |
-
def
|
23 |
-
if seed is None:
|
24 |
-
generator = None
|
25 |
-
else:
|
26 |
-
generator = torch.Generator(device=device).manual_seed(seed)
|
27 |
-
|
28 |
-
if torch.backends.mps.is_available():
|
29 |
-
autocast_ctx = nullcontext()
|
30 |
-
else:
|
31 |
-
autocast_ctx = torch.autocast(pipe.device.type)
|
32 |
-
with autocast_ctx:
|
33 |
-
|
34 |
-
test_image = Image.open(image_input).convert('RGB')
|
35 |
-
test_image = np.array(test_image).astype(np.float16)
|
36 |
-
test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
|
37 |
-
test_image = test_image / 127.5 - 1.0
|
38 |
-
test_image = test_image.to(device)
|
39 |
-
|
40 |
-
task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
|
41 |
-
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
|
42 |
-
|
43 |
-
# Run
|
44 |
-
pred = pipe(
|
45 |
-
rgb_in=test_image,
|
46 |
-
prompt='',
|
47 |
-
num_inference_steps=1,
|
48 |
-
generator=generator,
|
49 |
-
# guidance_scale=0,
|
50 |
-
output_type='np',
|
51 |
-
timesteps=[999],
|
52 |
-
task_emb=task_emb,
|
53 |
-
).images[0]
|
54 |
-
|
55 |
-
# Post-process the prediction
|
56 |
-
if task_name == 'depth':
|
57 |
-
output_npy = pred.mean(axis=-1)
|
58 |
-
output_color = colorize_depth_map(output_npy)
|
59 |
-
else:
|
60 |
-
output_npy = pred
|
61 |
-
output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
|
62 |
-
|
63 |
-
return output_color
|
64 |
-
|
65 |
-
def lotus_video(input_video, task_name, seed, device):
|
66 |
if task_name == 'depth':
|
67 |
model_g = 'jingheya/lotus-depth-g-v1-0'
|
68 |
-
model_d = 'jingheya/lotus-depth-d-v1-
|
69 |
else:
|
70 |
model_g = 'jingheya/lotus-normal-g-v1-0'
|
71 |
model_d = 'jingheya/lotus-normal-d-v1-0'
|
@@ -83,268 +30,57 @@ def lotus_video(input_video, task_name, seed, device):
|
|
83 |
pipe_d.to(device)
|
84 |
pipe_g.set_progress_bar_config(disable=True)
|
85 |
pipe_d.set_progress_bar_config(disable=True)
|
86 |
-
logging.info(f"Successfully
|
87 |
-
|
88 |
-
# load the video and split it into frames
|
89 |
-
cap = cv2.VideoCapture(input_video)
|
90 |
-
frames = []
|
91 |
-
while True:
|
92 |
-
ret, frame = cap.read()
|
93 |
-
if not ret:
|
94 |
-
break
|
95 |
-
frames.append(frame)
|
96 |
-
cap.release()
|
97 |
-
logging.info(f"There are {len(frames)} frames in the video.")
|
98 |
|
|
|
99 |
if seed is None:
|
100 |
generator = None
|
101 |
else:
|
102 |
generator = torch.Generator(device=device).manual_seed(seed)
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
output_g = []
|
108 |
-
output_d = []
|
109 |
-
for frame in frames:
|
110 |
-
if torch.backends.mps.is_available():
|
111 |
-
autocast_ctx = nullcontext()
|
112 |
-
else:
|
113 |
-
autocast_ctx = torch.autocast(pipe_g.device.type)
|
114 |
-
with autocast_ctx:
|
115 |
-
test_image = frame
|
116 |
-
test_image = np.array(test_image).astype(np.float16)
|
117 |
-
test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
|
118 |
-
test_image = test_image / 127.5 - 1.0
|
119 |
-
test_image = test_image.to(device)
|
120 |
-
|
121 |
-
# Run
|
122 |
-
pred_g = pipe_g(
|
123 |
-
rgb_in=test_image,
|
124 |
-
prompt='',
|
125 |
-
num_inference_steps=1,
|
126 |
-
generator=generator,
|
127 |
-
# guidance_scale=0,
|
128 |
-
output_type='np',
|
129 |
-
timesteps=[999],
|
130 |
-
task_emb=task_emb,
|
131 |
-
).images[0]
|
132 |
-
pred_d = pipe_d(
|
133 |
-
rgb_in=test_image,
|
134 |
-
prompt='',
|
135 |
-
num_inference_steps=1,
|
136 |
-
generator=generator,
|
137 |
-
# guidance_scale=0,
|
138 |
-
output_type='np',
|
139 |
-
timesteps=[999],
|
140 |
-
task_emb=task_emb,
|
141 |
-
).images[0]
|
142 |
-
|
143 |
-
# Post-process the prediction
|
144 |
-
if task_name == 'depth':
|
145 |
-
output_npy_g = pred_g.mean(axis=-1)
|
146 |
-
output_color_g = colorize_depth_map(output_npy_g)
|
147 |
-
output_npy_d = pred_d.mean(axis=-1)
|
148 |
-
output_color_d = colorize_depth_map(output_npy_d)
|
149 |
-
else:
|
150 |
-
output_npy_g = pred_g
|
151 |
-
output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
|
152 |
-
output_npy_d = pred_d
|
153 |
-
output_color_d = Image.fromarray((output_npy_d * 255).astype(np.uint8))
|
154 |
-
|
155 |
-
output_g.append(output_color_g)
|
156 |
-
output_d.append(output_color_d)
|
157 |
-
|
158 |
-
return output_g, output_d
|
159 |
-
|
160 |
-
def lotus(image_input, task_name, seed, device):
|
161 |
-
if task_name == 'depth':
|
162 |
-
model_g = 'jingheya/lotus-depth-g-v1-0'
|
163 |
-
model_d = 'jingheya/lotus-depth-d-v1-1'
|
164 |
-
else:
|
165 |
-
model_g = 'jingheya/lotus-normal-g-v1-0'
|
166 |
-
model_d = 'jingheya/lotus-normal-d-v1-0'
|
167 |
-
|
168 |
-
dtype = torch.float16
|
169 |
-
pipe_g = LotusGPipeline.from_pretrained(
|
170 |
-
model_g,
|
171 |
-
torch_dtype=dtype,
|
172 |
-
)
|
173 |
-
pipe_d = LotusDPipeline.from_pretrained(
|
174 |
-
model_d,
|
175 |
-
torch_dtype=dtype,
|
176 |
-
)
|
177 |
-
pipe_g.to(device)
|
178 |
-
pipe_d.to(device)
|
179 |
-
pipe_g.set_progress_bar_config(disable=True)
|
180 |
-
pipe_d.set_progress_bar_config(disable=True)
|
181 |
-
logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
|
182 |
-
output_g = infer_pipe(pipe_g, image_input, task_name, seed, device)
|
183 |
-
output_d = infer_pipe(pipe_d, image_input, task_name, seed, device)
|
184 |
-
return output_g, output_d
|
185 |
-
|
186 |
-
def parse_args():
|
187 |
-
'''Set the Args'''
|
188 |
-
parser = argparse.ArgumentParser(
|
189 |
-
description="Run Lotus..."
|
190 |
-
)
|
191 |
-
# model settings
|
192 |
-
parser.add_argument(
|
193 |
-
"--pretrained_model_name_or_path",
|
194 |
-
type=str,
|
195 |
-
default=None,
|
196 |
-
help="pretrained model path from hugging face or local dir",
|
197 |
-
)
|
198 |
-
parser.add_argument(
|
199 |
-
"--prediction_type",
|
200 |
-
type=str,
|
201 |
-
default="sample",
|
202 |
-
help="The used prediction_type. ",
|
203 |
-
)
|
204 |
-
parser.add_argument(
|
205 |
-
"--timestep",
|
206 |
-
type=int,
|
207 |
-
default=999,
|
208 |
-
)
|
209 |
-
parser.add_argument(
|
210 |
-
"--mode",
|
211 |
-
type=str,
|
212 |
-
default="regression", # "generation"
|
213 |
-
help="Whether to use the generation or regression pipeline."
|
214 |
-
)
|
215 |
-
parser.add_argument(
|
216 |
-
"--task_name",
|
217 |
-
type=str,
|
218 |
-
default="depth", # "normal"
|
219 |
-
)
|
220 |
-
parser.add_argument(
|
221 |
-
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
222 |
-
)
|
223 |
-
|
224 |
-
# inference settings
|
225 |
-
parser.add_argument("--seed", type=int, default=None, help="Random seed.")
|
226 |
-
parser.add_argument(
|
227 |
-
"--output_dir", type=str, required=True, help="Output directory."
|
228 |
-
)
|
229 |
-
parser.add_argument(
|
230 |
-
"--input_dir", type=str, required=True, help="Input directory."
|
231 |
-
)
|
232 |
-
parser.add_argument(
|
233 |
-
"--half_precision",
|
234 |
-
action="store_true",
|
235 |
-
help="Run with half-precision (16-bit float), might lead to suboptimal result.",
|
236 |
-
)
|
237 |
-
|
238 |
-
args = parser.parse_args()
|
239 |
-
|
240 |
-
return args
|
241 |
-
|
242 |
-
def main():
|
243 |
-
logging.basicConfig(level=logging.INFO)
|
244 |
-
logging.info(f"Run inference...")
|
245 |
-
|
246 |
-
args = parse_args()
|
247 |
-
|
248 |
-
# -------------------- Preparation --------------------
|
249 |
-
# Random seed
|
250 |
-
if args.seed is not None:
|
251 |
-
seed_all(args.seed)
|
252 |
-
|
253 |
-
# Output directories
|
254 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
255 |
-
logging.info(f"Output dir = {args.output_dir}")
|
256 |
-
|
257 |
-
output_dir_color = os.path.join(args.output_dir, f'{args.task_name}_vis')
|
258 |
-
output_dir_npy = os.path.join(args.output_dir, f'{args.task_name}')
|
259 |
-
if not os.path.exists(output_dir_color): os.makedirs(output_dir_color)
|
260 |
-
if not os.path.exists(output_dir_npy): os.makedirs(output_dir_npy)
|
261 |
-
|
262 |
-
# half_precision
|
263 |
-
if args.half_precision:
|
264 |
-
dtype = torch.float16
|
265 |
-
logging.info(f"Running with half precision ({dtype}).")
|
266 |
-
else:
|
267 |
-
dtype = torch.float16
|
268 |
-
|
269 |
-
# -------------------- Device --------------------
|
270 |
-
if torch.cuda.is_available():
|
271 |
-
device = torch.device("cuda")
|
272 |
-
else:
|
273 |
-
device = torch.device("cpu")
|
274 |
-
logging.warning("CUDA is not available. Running on CPU will be slow.")
|
275 |
-
logging.info(f"Device = {device}")
|
276 |
-
|
277 |
-
# -------------------- Data --------------------
|
278 |
-
root_dir = Path(args.input_dir)
|
279 |
-
test_images = list(root_dir.rglob('*.png')) + list(root_dir.rglob('*.jpg'))
|
280 |
-
test_images = sorted(test_images)
|
281 |
-
print('==> There are', len(test_images), 'images for validation.')
|
282 |
-
# -------------------- Model --------------------
|
283 |
-
|
284 |
-
if args.mode == 'generation':
|
285 |
-
pipeline = LotusGPipeline.from_pretrained(
|
286 |
-
args.pretrained_model_name_or_path,
|
287 |
-
torch_dtype=dtype,
|
288 |
-
)
|
289 |
-
elif args.mode == 'regression':
|
290 |
-
pipeline = LotusDPipeline.from_pretrained(
|
291 |
-
args.pretrained_model_name_or_path,
|
292 |
-
torch_dtype=dtype,
|
293 |
-
)
|
294 |
-
else:
|
295 |
-
raise ValueError(f'Invalid mode: {args.mode}')
|
296 |
-
logging.info(f"Successfully loading pipeline from {args.pretrained_model_name_or_path}.")
|
297 |
-
|
298 |
-
pipeline = pipeline.to(device)
|
299 |
-
pipeline.set_progress_bar_config(disable=True)
|
300 |
-
|
301 |
-
if args.enable_xformers_memory_efficient_attention:
|
302 |
-
pipeline.enable_xformers_memory_efficient_attention()
|
303 |
-
|
304 |
-
|
305 |
-
if args.seed is None:
|
306 |
-
generator = None
|
307 |
else:
|
308 |
-
|
309 |
-
|
310 |
-
# -------------------- Inference and saving --------------------
|
311 |
-
with torch.no_grad():
|
312 |
-
for i in tqdm(range(len(test_images))):
|
313 |
-
# Preprocess validation image
|
314 |
-
test_image = Image.open(test_images[i]).convert('RGB')
|
315 |
-
test_image = np.array(test_image).astype(np.float16)
|
316 |
-
test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
|
317 |
-
test_image = test_image / 127.5 - 1.0
|
318 |
-
test_image = test_image.to(device)
|
319 |
-
|
320 |
-
task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
|
321 |
-
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
|
322 |
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
|
|
339 |
output_color = colorize_depth_map(output_npy)
|
340 |
-
|
341 |
-
|
|
|
|
|
342 |
output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
|
|
|
343 |
|
344 |
-
|
345 |
-
np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)
|
346 |
-
|
347 |
-
print('==> Inference is done. \n==> Results saved to:', args.output_dir)
|
348 |
|
349 |
-
|
350 |
-
|
|
|
|
|
|
1 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
from diffusers.utils import check_min_version
|
|
|
6 |
from pipeline import LotusGPipeline, LotusDPipeline
|
7 |
from utils.image_utils import colorize_depth_map
|
|
|
|
|
8 |
from contextlib import nullcontext
|
|
|
9 |
|
10 |
check_min_version('0.28.0.dev0')
|
11 |
|
12 |
+
def load_models(task_name, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
if task_name == 'depth':
|
14 |
model_g = 'jingheya/lotus-depth-g-v1-0'
|
15 |
+
model_d = 'jingheya/lotus-depth-d-v1-1'
|
16 |
else:
|
17 |
model_g = 'jingheya/lotus-normal-g-v1-0'
|
18 |
model_d = 'jingheya/lotus-normal-d-v1-0'
|
|
|
30 |
pipe_d.to(device)
|
31 |
pipe_g.set_progress_bar_config(disable=True)
|
32 |
pipe_d.set_progress_bar_config(disable=True)
|
33 |
+
logging.info(f"Successfully loaded pipelines from {model_g} and {model_d}.")
|
34 |
+
return pipe_g, pipe_d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
def infer_pipe(pipe, images_batch, task_name, seed, device):
    """Run one pipeline over a batch of images and return one visualisation per image.

    Args:
        pipe: Pipeline object. It is called with ``rgb_in``, ``prompt``,
            ``num_inference_steps``, ``generator``, ``output_type``,
            ``timesteps`` and ``task_emb`` and must return an object with an
            ``images`` attribute; ``pipe.device.type`` selects the autocast
            device.
        images_batch: Iterable of images exposing ``convert('RGB')``
            (presumably PIL images -- TODO confirm against the caller). All
            images must share the same size because they are stacked into a
            single tensor.
        task_name: ``'depth'`` averages each prediction over its last axis and
            passes it to ``colorize_depth_map``; any other value converts the
            prediction to an 8-bit ``PIL.Image``.
        seed: Integer seed for a ``torch.Generator`` on ``device``, or ``None``
            to pass no generator to the pipeline.
        device: Device the input tensors and the generator are placed on.

    Returns:
        A list with one output per input image, in input order. An empty
        batch yields an empty list without invoking the pipeline.
    """
    # Materialise once so generators are supported and emptiness can be tested.
    images_batch = list(images_batch)
    if not images_batch:
        # torch.stack() raises on an empty list; an empty batch has no outputs.
        return []

    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    # Autocast is skipped whenever an MPS backend is available on this machine.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe.device.type)

    with autocast_ctx:
        # Convert each image to a float16 CHW tensor and stack into one batch.
        images = [np.array(img.convert('RGB')).astype(np.float16) for img in images_batch]
        test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
        # Map pixel values from [0, 255] to [-1, 1].
        test_images = test_images / 127.5 - 1.0
        test_images = test_images.to(device)

        # Fixed task embedding [sin(1), sin(0), cos(1), cos(0)], one row per image.
        task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
        task_emb = task_emb.repeat(len(test_images), 1)

        # Run inference: a single step at timestep 999.
        preds = pipe(
            rgb_in=test_images,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images

    # Post-process predictions.
    outputs = []
    for pred in preds:
        if task_name == 'depth':
            output_color = colorize_depth_map(pred.mean(axis=-1))
        else:
            output_color = Image.fromarray((pred * 255).astype(np.uint8))
        outputs.append(output_color)

    return outputs
|
|
|
|
|
|
|
83 |
|
84 |
+
def lotus(images_batch, task_name, seed, device, pipe_g, pipe_d):
    """Run the ``pipe_d`` pipeline on a batch of images and return its outputs.

    Args:
        images_batch: Images forwarded unchanged to ``infer_pipe``.
        task_name: Task selector forwarded to ``infer_pipe``.
        seed: Seed forwarded to ``infer_pipe`` (``None`` for no generator).
        device: Device forwarded to ``infer_pipe``.
        pipe_g: Accepted as part of the signature but not used here.
        pipe_d: Pipeline that is actually executed.

    Returns:
        Whatever ``infer_pipe`` returns for ``pipe_d``.
    """
    # Only the pipe_d results are needed for this application; pipe_g is not run.
    return infer_pipe(pipe_d, images_batch, task_name, seed, device)
|