convert to CPU
- annotator/midas/__init__.py +2 -2
- app.py +109 -110
- ckpt/cldm_v15.yaml +2 -0
- requirements.txt +1 -0
- stablevideo/atlas_data.py +1 -1
- stablevideo/atlas_utils.py +1 -1
annotator/midas/__init__.py
CHANGED
@@ -8,13 +8,13 @@ from .api import MiDaSInference
 
 class MidasDetector:
     def __init__(self):
-        self.model = MiDaSInference(model_type="dpt_hybrid")
+        self.model = MiDaSInference(model_type="dpt_hybrid")
 
     def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
         assert input_image.ndim == 3
         image_depth = input_image
         with torch.no_grad():
-            image_depth = torch.from_numpy(image_depth).float()
+            image_depth = torch.from_numpy(image_depth).float()
             image_depth = image_depth / 127.5 - 1.0
             image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
             depth = self.model(image_depth)[0]
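The two edited lines above render identically because the diff view cuts off the end of each removed line; given the commit title, the dropped tail is presumably a `.cuda()` call. A minimal sketch, assuming the ControlNet-style annotator interface, of running the detector entirely on the CPU after this change:

# Hypothetical usage sketch; the dummy input and its shape are illustrative only.
import numpy as np
import torch
from annotator.midas import MidasDetector

apply_midas = MidasDetector()                    # MiDaSInference now stays on the CPU
image = np.zeros((384, 384, 3), dtype=np.uint8)  # HWC uint8 image, as the ndim == 3 assert expects
with torch.no_grad():
    result = apply_midas(image)                  # depth (and normal) maps from the annotator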
app.py
CHANGED
@@ -48,7 +48,7 @@ class StableVideo:
     ):
         self.apply_canny = CannyDetector()
         canny_model = create_model(base_cfg).cpu()
-        canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='
+        canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='cpu'), strict=False)
         self.canny_ddim_sampler = DDIMSampler(canny_model)
         self.canny_model = canny_model
 
@@ -59,7 +59,7 @@ class StableVideo:
     ):
         self.apply_midas = MidasDetector()
         depth_model = create_model(base_cfg).cpu()
-        depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='
+        depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='cpu'), strict=False)
         self.depth_ddim_sampler = DDIMSampler(depth_model)
         self.depth_model = depth_model
 
@@ -101,7 +101,7 @@ class StableVideo:
 
         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
 
-        control = torch.from_numpy(detected_map.copy()).float()
+        control = torch.from_numpy(detected_map.copy()).float() / 255.0
         control = torch.stack([control for _ in range(1)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
@@ -128,7 +128,7 @@ class StableVideo:
 
     @torch.no_grad()
     def edit_background(self, *args, **kwargs):
-        self.depth_model = self.depth_model
+        self.depth_model = self.depth_model
 
         input_image = self.b_atlas_origin
         self.depth_edit(input_image, *args, **kwargs)
@@ -155,7 +155,7 @@ class StableVideo:
                                  if_net=False,
                                  num_samples=1):
 
-        self.canny_model = self.canny_model
+        self.canny_model = self.canny_model
 
         keyframes = [int(x) for x in keyframes.split(",")]
         if self.data is None:
@@ -186,7 +186,7 @@ class StableVideo:
             # get canny control
            detected_map = self.apply_canny(img, low_threshold, high_threshold)
            detected_map = HWC3(detected_map)
-            control = torch.from_numpy(detected_map.copy()).float()
+            control = torch.from_numpy(detected_map.copy()).float() / 255.0
            control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone()
 
            cond = {"c_concat": [control], "c_crossattn": c_crossattn}
@@ -195,7 +195,7 @@ class StableVideo:
 
            # if not the key frame, calculate the mapping from last atlas
            if i == 0:
-                latent = torch.randn((1, 4, H // 8, W // 8))
+                latent = torch.randn((1, 4, H // 8, W // 8))
                samples, _ = self.canny_ddim_sampler.sample(ddim_steps, num_samples,
                                                            shape, cond, verbose=False, eta=eta,
                                                            unconditional_guidance_scale=scale,
@@ -209,7 +209,7 @@ class StableVideo:
                mapped_img = mapped_img.resize((W, H))
                mapped_img = np.array(mapped_img).astype(np.float32) / 255.0
                mapped_img = mapped_img[None].transpose(0, 3, 1, 2)
-                mapped_img = torch.from_numpy(mapped_img)
+                mapped_img = torch.from_numpy(mapped_img)
                mapped_img = 2. * mapped_img - 1.
                latent = self.canny_model.get_first_stage_encoding(self.canny_model.encode_first_stage(mapped_img))
 
@@ -232,7 +232,7 @@ class StableVideo:
            result = alpha * result
 
            # buffer for training
-            result_copy = result.clone()
+            result_copy = result.clone()
            result_copy.requires_grad = True
            result_list.append(result_copy)
 
@@ -259,7 +259,7 @@ class StableVideo:
        # aggregate net #
        #####################################
        lr, n_epoch = 1e-3, 500
-        agg_net = AGGNet()
+        agg_net = AGGNet()
        loss_fn = nn.L1Loss()
        optimizer = optim.SGD(agg_net.parameters(), lr=lr, momentum=0.9)
        for _ in range(n_epoch):
@@ -291,12 +291,12 @@ class StableVideo:
     def render(self, f_atlas, b_atlas):
        # foreground
        if f_atlas == None:
-            f_atlas = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
+            f_atlas = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
        else:
            f_atlas, mask = f_atlas["image"], f_atlas["mask"]
-            f_atlas_origin = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
-            f_atlas = transforms.ToTensor()(f_atlas).unsqueeze(0)
-            mask = transforms.ToTensor()(mask).unsqueeze(0)
+            f_atlas_origin = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
+            f_atlas = transforms.ToTensor()(f_atlas).unsqueeze(0)
+            mask = transforms.ToTensor()(mask).unsqueeze(0)
            if f_atlas.shape != mask.shape:
                print("Warning: truncating mask to atlas shape {}".format(f_atlas.shape))
                mask = mask[:f_atlas.shape[0], :f_atlas.shape[1], :f_atlas.shape[2], :f_atlas.shape[3]]
@@ -326,7 +326,7 @@ class StableVideo:
         if b_atlas == None:
            b_atlas = self.b_atlas_origin
 
-        b_atlas = transforms.ToTensor()(b_atlas).unsqueeze(0)
+        b_atlas = transforms.ToTensor()(b_atlas).unsqueeze(0)
        background_edit = F.grid_sample(
            b_atlas, self.data.scaled_background_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
        ).clamp(min=0.0, max=1.0)
@@ -349,99 +349,98 @@ class StableVideo:
         return save_name
 
 if __name__ == '__main__':
-    (previous __main__ setup, old lines 352-414: not legible in the rendered diff)
-    output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False)
-
-    # edit param
-    f_adv_edit_param = [adv_keyframes,
-                        adv_atlas_resolution,
-                        f_prompt,
-                        adv_a_prompt,
-                        adv_n_prompt,
-                        adv_image_resolution,
-                        adv_low_threshold,
-                        adv_high_threshold,
-                        adv_ddim_steps,
-                        adv_s,
-                        adv_scale,
-                        adv_seed,
-                        adv_eta,
-                        adv_if_net]
-    b_edit_param = [b_prompt,
-                    b_a_prompt,
-                    b_n_prompt,
-                    b_image_resolution,
-                    b_detect_resolution,
-                    b_ddim_steps,
-                    b_scale,
-                    b_seed,
-                    b_eta]
-    # action
-    load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
-    f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
-    b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
-    run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])
-
+    stablevideo = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
+                              canny_model_cfg="ckpt/control_sd15_canny.pth",
+                              depth_model_cfg="ckpt/control_sd15_depth.pth",
+                              save_memory=True)
+    stablevideo.load_canny_model()
+    stablevideo.load_depth_model()
+
+    block = gr.Blocks().queue()
+    with block:
+        with gr.Row():
+            gr.Markdown("## StableVideo")
+        with gr.Row():
+            with gr.Column():
+                original_video = gr.Video(label="Original Video", interactive=False)
+                with gr.Row():
+                    foreground_atlas = gr.Image(label="Foreground Atlas", type="pil")
+                    background_atlas = gr.Image(label="Background Atlas", type="pil")
+                gr.Markdown("### Step 1. select one example video and click **Load Video** buttom and wait for 10 sec.")
+                avail_video = [f.name for f in os.scandir("data") if f.is_dir()]
+                video_name = gr.Radio(choices=avail_video,
+                                      label="Select Example Videos",
+                                      value="car-turn")
+                load_video_button = gr.Button("Load Video")
+                gr.Markdown("### Step 2. write text prompt and advanced options for background and foreground.")
+                with gr.Row():
+                    f_prompt = gr.Textbox(label="Foreground Prompt", value="a picture of an orange suv")
+                    b_prompt = gr.Textbox(label="Background Prompt", value="winter scene, snowy scene, beautiful snow")
+                with gr.Row():
+                    with gr.Accordion("Advanced Foreground Options", open=False):
+                        adv_keyframes = gr.Textbox(label="keyframe", value="20, 40, 60")
+                        adv_atlas_resolution = gr.Slider(label="Atlas Resolution", minimum=1000, maximum=3000, value=2000, step=100)
+                        adv_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
+                        adv_low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
+                        adv_high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
+                        adv_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                        adv_s = gr.Slider(label="Noise Scale", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
+                        adv_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=15.0, value=9.0, step=0.1)
+                        adv_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                        adv_eta = gr.Number(label="eta (DDIM)", value=0.0)
+                        adv_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, no background')
+                        adv_n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+                        adv_if_net = gr.gradio.Checkbox(label="if use agg net", value=False)
+
+                    with gr.Accordion("Background Options", open=False):
+                        b_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
+                        b_detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=512, step=1)
+                        b_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                        b_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                        b_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                        b_eta = gr.Number(label="eta (DDIM)", value=0.0)
+                        b_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                        b_n_prompt = gr.Textbox(label="Negative Prompt",
+                                                value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+                gr.Markdown("### Step 3. edit each one and render.")
+                with gr.Row():
+                    f_advance_run_button = gr.Button("Advanced Edit Foreground (slower, better)")
+                    b_run_button = gr.Button("Edit Background")
+                    run_button = gr.Button("Render")
+            with gr.Column():
+                output_video = gr.Video(label="Output Video", interactive=False)
+                # output_foreground_atlas = gr.Image(label="Output Foreground Atlas", type="pil", interactive=False)
+                output_foreground_atlas = gr.ImageMask(label="Editable Output Foreground Atlas", type="pil", tool="sketch", interactive=True)
+                output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False)
+
+        # edit param
+        f_adv_edit_param = [adv_keyframes,
+                            adv_atlas_resolution,
+                            f_prompt,
+                            adv_a_prompt,
+                            adv_n_prompt,
+                            adv_image_resolution,
+                            adv_low_threshold,
+                            adv_high_threshold,
+                            adv_ddim_steps,
+                            adv_s,
+                            adv_scale,
+                            adv_seed,
+                            adv_eta,
+                            adv_if_net]
+        b_edit_param = [b_prompt,
+                        b_a_prompt,
+                        b_n_prompt,
+                        b_image_resolution,
+                        b_detect_resolution,
+                        b_ddim_steps,
+                        b_scale,
+                        b_seed,
+                        b_eta]
+        # action
+        load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
+        f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
+        b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
+        run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])
+
+    block.launch()
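Two patterns recur in the app.py changes above: the ControlNet weights are loaded onto the CPU (location='cpu') with strict=False, and the control maps are rescaled from uint8 [0, 255] to float [0, 1] at the point where they become tensors. A minimal sketch of both, assuming create_model and load_state_dict behave as in ControlNet's cldm.model:

import einops
import numpy as np
import torch
from cldm.model import create_model, load_state_dict

# Build the model on the CPU and map every checkpoint tensor to the CPU as well.
model = create_model("ckpt/cldm_v15.yaml").cpu()
state = load_state_dict("ckpt/control_sd15_canny.pth", location='cpu')
model.load_state_dict(state, strict=False)   # tolerate keys that differ from the CPU config

# Scale an HWC uint8 control map into a 1xCxHxW float tensor in [0, 1].
detected_map = np.zeros((512, 512, 3), dtype=np.uint8)   # placeholder Canny map
control = torch.from_numpy(detected_map.copy()).float() / 255.0
control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone()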
ckpt/cldm_v15.yaml
CHANGED
@@ -77,3 +77,5 @@ model:
 
     cond_stage_config:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      params:
+        device: "cpu"
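The two added YAML lines hand a device argument to the text encoder when the config is instantiated. A rough sketch of the standard latent-diffusion target/params mechanism that makes this work; the helper below is a simplified stand-in for ldm.util.instantiate_from_config:

import importlib

def instantiate_from_config(config):
    # "target" names the class to build; "params" becomes its keyword arguments.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))

cond_stage_config = {
    "target": "ldm.modules.encoders.modules.FrozenCLIPEmbedder",
    "params": {"device": "cpu"},   # the two lines added to cldm_v15.yaml
}
# text_encoder = instantiate_from_config(cond_stage_config)  # would place the CLIP encoder on the CPU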
requirements.txt
CHANGED
@@ -120,3 +120,4 @@ wcwidth==0.2.6
 websockets==11.0.3
 Werkzeug==2.3.7
 yarl==1.9.2
+xformers
stablevideo/atlas_data.py
CHANGED
@@ -30,7 +30,7 @@ class AtlasData():
         maximum_number_of_frames = json_dict["maximum_number_of_frames"]
 
         config = {
-            "device": "
+            "device": "cpu",
             "checkpoint_path": f"data/{video_name}/checkpoint.ckpt",
             "resx": json_dict["resx"],
             "resy": json_dict["resy"],
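The commit hard-codes "device": "cpu" in the AtlasData config. For reference, a common alternative (not what this commit does) is to pick the device at runtime, which keeps GPU support when one is present:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
config = {
    "device": device,
    # ... the remaining AtlasData entries stay unchanged ...
}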
stablevideo/atlas_utils.py
CHANGED
@@ -72,7 +72,7 @@ def load_neural_atlases_models(config):
         skip_layers=[],
     ).to(config["device"])
 
-    checkpoint = torch.load(config["checkpoint_path"])
+    checkpoint = torch.load(config["checkpoint_path"], map_location=torch.device('cpu'))
     foreground_mapping.load_state_dict(checkpoint["model_F_mapping1_state_dict"])
     background_mapping.load_state_dict(checkpoint["model_F_mapping2_state_dict"])
     foreground_atlas_model.load_state_dict(checkpoint["F_atlas_state_dict"])
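map_location is what lets a checkpoint whose tensors were saved from GPU memory load on a CPU-only Space. A minimal sketch, using the data/{video_name}/checkpoint.ckpt path template from atlas_data.py and the example video name from app.py:

import torch

# Remap every stored tensor to the CPU while unpickling the checkpoint.
checkpoint = torch.load("data/car-turn/checkpoint.ckpt", map_location=torch.device('cpu'))

# The contained state dicts can then be loaded into CPU-resident modules, e.g.:
# foreground_mapping.load_state_dict(checkpoint["model_F_mapping1_state_dict"])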