Spaces:
Paused
Paused
Update app_gradio.py
Browse files- app_gradio.py +16 -5
app_gradio.py
CHANGED
@@ -67,12 +67,23 @@ def prepare_latents(pipe, x_aug):
|
|
67 |
|
68 |
@torch.no_grad()
|
69 |
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
|
70 |
-
|
71 |
-
input_img =
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
inv.set_timesteps(25)
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
def load_primary_models(pretrained_model_path):
|
78 |
return (
|
|
|
67 |
|
68 |
@torch.no_grad()
|
69 |
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
|
70 |
+
# Load and process the image
|
71 |
+
input_img = load_image(load_name, target_size=256).to(device, dtype=torch.float32) # Shape: (1, C, H, W)
|
72 |
+
input_img = input_img.unsqueeze(1).repeat(1, 5, 1, 1, 1) # Add time dimension and repeat for T=5
|
73 |
+
# Shape: (B=1, T=5, C=3, H=256, W=256)
|
74 |
+
|
75 |
+
# Convert image to latent space
|
76 |
+
latents = prepare_latents(pipe, input_img).to(dtype) # Shape: (B, latent_dim, T, H/8, W/8)
|
77 |
+
|
78 |
+
# Configure the inversion process
|
79 |
inv.set_timesteps(25)
|
80 |
+
|
81 |
+
# Perform inversion and extract final latent representation
|
82 |
+
id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1]
|
83 |
+
id_latents = id_latents.to(dtype) # Ensure correct dtype
|
84 |
+
id_latents = torch.mean(id_latents, dim=2, keepdim=True) # Shape: (B, latent_dim, 1, H/8, W/8)
|
85 |
+
|
86 |
+
return id_latents
|
87 |
|
88 |
def load_primary_models(pretrained_model_path):
|
89 |
return (
|