Spaces:
Paused
Paused
Update app_gradio.py
Browse files- app_gradio.py +8 -16
app_gradio.py
CHANGED
@@ -67,23 +67,15 @@ def prepare_latents(pipe, x_aug):
|
|
67 |
|
68 |
@torch.no_grad()
|
69 |
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
|
70 |
-
|
71 |
-
input_img =
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
# Convert image to latent space
|
76 |
-
latents = prepare_latents(pipe, input_img).to(dtype) # Shape: (B, latent_dim, T, H/8, W/8)
|
77 |
-
|
78 |
-
# Configure the inversion process
|
79 |
inv.set_timesteps(25)
|
80 |
-
|
81 |
-
#
|
82 |
-
|
83 |
-
id_latents = id_latents.to(dtype) # Ensure correct dtype
|
84 |
-
id_latents = torch.mean(id_latents, dim=2, keepdim=True) # Shape: (B, latent_dim, 1, H/8, W/8)
|
85 |
-
|
86 |
-
return id_latents
|
87 |
|
88 |
def load_primary_models(pretrained_model_path):
|
89 |
return (
|
|
|
67 |
|
68 |
@torch.no_grad()
|
69 |
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
|
70 |
+
input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
|
71 |
+
input_img = torch.cat(input_img, dim=1)
|
72 |
+
torch.cuda.synchronize() # Ensure image tensor preparation is complete
|
73 |
+
latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
|
74 |
+
torch.cuda.synchronize() # Wait for latents to finish encoding
|
|
|
|
|
|
|
|
|
75 |
inv.set_timesteps(25)
|
76 |
+
id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
|
77 |
+
torch.cuda.synchronize() # Ensure DDIM inversion is complete
|
78 |
+
return torch.mean(id_latents, dim=2, keepdim=True)
|
|
|
|
|
|
|
|
|
79 |
|
80 |
def load_primary_models(pretrained_model_path):
|
81 |
return (
|