Spaces:
Paused
Paused
guardiancc
commited on
Commit
•
f823cf1
1
Parent(s):
f3672f8
Update mimicmotion/pipelines/pipeline_mimicmotion.py
Browse files
mimicmotion/pipelines/pipeline_mimicmotion.py
CHANGED
@@ -556,21 +556,17 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
556 |
# expand the latents if we are doing classifier free guidance
|
557 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
558 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
559 |
-
|
560 |
# Concatenate image_latents over channels dimension
|
561 |
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
562 |
-
|
563 |
# predict the noise residual
|
564 |
noise_pred = torch.zeros_like(image_latents)
|
565 |
noise_pred_cnt = image_latents.new_zeros((num_frames,))
|
566 |
weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
|
567 |
weight = torch.minimum(weight, 2 - weight)
|
568 |
-
|
569 |
-
|
570 |
-
def process_index(idx):
|
571 |
-
nonlocal noise_pred, noise_pred_cnt
|
572 |
-
result = torch.zeros_like(image_latents[:1, idx]) # Placeholder for thread-safe accumulation
|
573 |
-
|
574 |
# classification-free inference
|
575 |
pose_latents = self.pose_net(image_pose[idx].to(device))
|
576 |
_noise_pred = self.unet(
|
@@ -582,8 +578,8 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
582 |
image_only_indicator=image_only_indicator,
|
583 |
return_dict=False,
|
584 |
)[0]
|
585 |
-
|
586 |
-
|
587 |
# normal inference
|
588 |
_noise_pred = self.unet(
|
589 |
latent_model_input[1:, idx],
|
@@ -594,34 +590,26 @@ class MimicMotionPipeline(DiffusionPipeline):
|
|
594 |
image_only_indicator=image_only_indicator,
|
595 |
return_dict=False,
|
596 |
)[0]
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
with concurrent.futures.ThreadPoolExecutor() as executor:
|
602 |
-
futures = [executor.submit(process_index, idx) for idx in indices]
|
603 |
-
for future in concurrent.futures.as_completed(futures):
|
604 |
-
_noise_pred, idx = future.result()
|
605 |
-
noise_pred[:, idx] += _noise_pred
|
606 |
-
noise_pred_cnt[idx] += weight
|
607 |
-
progress_bar.update()
|
608 |
-
|
609 |
noise_pred.div_(noise_pred_cnt[:, None, None, None])
|
610 |
-
|
611 |
# perform guidance
|
612 |
if self.do_classifier_free_guidance:
|
613 |
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
614 |
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
615 |
-
|
616 |
# compute the previous noisy sample x_t -> x_t-1
|
617 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
618 |
-
|
619 |
if callback_on_step_end is not None:
|
620 |
callback_kwargs = {}
|
621 |
for k in callback_on_step_end_tensor_inputs:
|
622 |
callback_kwargs[k] = locals()[k]
|
623 |
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
624 |
-
|
625 |
latents = callback_outputs.pop("latents", latents)
|
626 |
|
627 |
self.pose_net.cpu()
|
|
|
556 |
# expand the latents if we are doing classifier free guidance
|
557 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
558 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
559 |
+
|
560 |
# Concatenate image_latents over channels dimension
|
561 |
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
562 |
+
|
563 |
# predict the noise residual
|
564 |
noise_pred = torch.zeros_like(image_latents)
|
565 |
noise_pred_cnt = image_latents.new_zeros((num_frames,))
|
566 |
weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
|
567 |
weight = torch.minimum(weight, 2 - weight)
|
568 |
+
for idx in indices:
|
569 |
+
|
|
|
|
|
|
|
|
|
570 |
# classification-free inference
|
571 |
pose_latents = self.pose_net(image_pose[idx].to(device))
|
572 |
_noise_pred = self.unet(
|
|
|
578 |
image_only_indicator=image_only_indicator,
|
579 |
return_dict=False,
|
580 |
)[0]
|
581 |
+
noise_pred[:1, idx] += _noise_pred * weight[:, None, None, None]
|
582 |
+
|
583 |
# normal inference
|
584 |
_noise_pred = self.unet(
|
585 |
latent_model_input[1:, idx],
|
|
|
590 |
image_only_indicator=image_only_indicator,
|
591 |
return_dict=False,
|
592 |
)[0]
|
593 |
+
noise_pred[1:, idx] += _noise_pred * weight[:, None, None, None]
|
594 |
+
|
595 |
+
noise_pred_cnt[idx] += weight
|
596 |
+
progress_bar.update()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
597 |
noise_pred.div_(noise_pred_cnt[:, None, None, None])
|
598 |
+
|
599 |
# perform guidance
|
600 |
if self.do_classifier_free_guidance:
|
601 |
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
602 |
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
603 |
+
|
604 |
# compute the previous noisy sample x_t -> x_t-1
|
605 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
606 |
+
|
607 |
if callback_on_step_end is not None:
|
608 |
callback_kwargs = {}
|
609 |
for k in callback_on_step_end_tensor_inputs:
|
610 |
callback_kwargs[k] = locals()[k]
|
611 |
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
612 |
+
|
613 |
latents = callback_outputs.pop("latents", latents)
|
614 |
|
615 |
self.pose_net.cpu()
|