fffiloni commited on
Commit
b6a6144
1 Parent(s): 1246a25

Update app_gradio.py

Browse files
Files changed (1) hide show
  1. app_gradio.py +16 -5
app_gradio.py CHANGED
@@ -67,12 +67,23 @@ def prepare_latents(pipe, x_aug):
67
 
68
  @torch.no_grad()
69
  def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
70
- input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
71
- input_img = torch.cat(input_img, dim=1)
72
- latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
 
 
 
 
 
 
73
  inv.set_timesteps(25)
74
- id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
75
- return torch.mean(id_latents, dim=2, keepdim=True)
 
 
 
 
 
76
 
77
  def load_primary_models(pretrained_model_path):
78
  return (
 
67
 
68
  @torch.no_grad()
69
  def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
70
+ # Load and process the image
71
+ input_img = load_image(load_name, target_size=256).to(device, dtype=torch.float32) # Shape: (1, C, H, W)
72
+ input_img = input_img.unsqueeze(1).repeat(1, 5, 1, 1, 1) # Add time dimension and repeat for T=5
73
+ # Shape: (B=1, T=5, C=3, H=256, W=256)
74
+
75
+ # Convert image to latent space
76
+ latents = prepare_latents(pipe, input_img).to(dtype) # Shape: (B, latent_dim, T, H/8, W/8)
77
+
78
+ # Configure the inversion process
79
  inv.set_timesteps(25)
80
+
81
+ # Perform inversion and extract final latent representation
82
+ id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1]
83
+ id_latents = id_latents.to(dtype) # Ensure correct dtype
84
+ id_latents = torch.mean(id_latents, dim=2, keepdim=True) # Shape: (B, latent_dim, 1, H/8, W/8)
85
+
86
+ return id_latents
87
 
88
  def load_primary_models(pretrained_model_path):
89
  return (