fffiloni committed on
Commit
55659e5
1 Parent(s): b6a6144

Update app_gradio.py

Browse files
Files changed (1) hide show
  1. app_gradio.py +8 -16
app_gradio.py CHANGED
@@ -67,23 +67,15 @@ def prepare_latents(pipe, x_aug):
67
 
68
  @torch.no_grad()
69
  def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
70
- # Load and process the image
71
- input_img = load_image(load_name, target_size=256).to(device, dtype=torch.float32) # Shape: (1, C, H, W)
72
- input_img = input_img.unsqueeze(1).repeat(1, 5, 1, 1, 1) # Add time dimension and repeat for T=5
73
- # Shape: (B=1, T=5, C=3, H=256, W=256)
74
-
75
- # Convert image to latent space
76
- latents = prepare_latents(pipe, input_img).to(dtype) # Shape: (B, latent_dim, T, H/8, W/8)
77
-
78
- # Configure the inversion process
79
  inv.set_timesteps(25)
80
-
81
- # Perform inversion and extract final latent representation
82
- id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1]
83
- id_latents = id_latents.to(dtype) # Ensure correct dtype
84
- id_latents = torch.mean(id_latents, dim=2, keepdim=True) # Shape: (B, latent_dim, 1, H/8, W/8)
85
-
86
- return id_latents
87
 
88
  def load_primary_models(pretrained_model_path):
89
  return (
 
67
 
68
@torch.no_grad()
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16,
           num_frames=5, num_inv_steps=25):
    """Invert a single image into the diffusion model's latent space.

    The image is loaded, replicated along a new time axis to form a short
    pseudo-video, encoded to latents, run through DDIM inversion, and the
    inverted latents are averaged over time down to a single frame.

    Args:
        pipe: diffusion pipeline consumed by ``prepare_latents`` and
            ``dd_inversion`` (defined elsewhere in this file).
        inv: inverse scheduler; its timesteps are configured here.
        load_name: path/identifier passed straight to ``load_image``.
        device: device the image tensor is moved to (default ``"cuda"``).
        dtype: dtype for the latents (default ``torch.bfloat16``).
        num_frames: how many times the image is repeated along the time
            dimension (default 5, matching the original behavior).
        num_inv_steps: number of DDIM inversion steps (default 25,
            matching the original behavior).

    Returns:
        Tensor holding the time-averaged inverted latents — presumably
        shape (B, C_latent, 1, H/8, W/8); TODO confirm against
        ``prepare_latents``.
    """
    # Build a pseudo-video by repeating the frame along a new time axis.
    frame = load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)
    input_img = torch.cat([frame] * num_frames, dim=1)

    # BUG FIX: the latents were cast to a hard-coded torch.bfloat16,
    # silently ignoring the ``dtype`` argument; honor ``dtype`` instead
    # (the pre-refactor version did exactly this).
    latents = prepare_latents(pipe, input_img).to(dtype)

    inv.set_timesteps(num_inv_steps)
    id_latents = dd_inversion(
        pipe, inv, video_latent=latents, num_inv_steps=num_inv_steps, prompt=""
    )[-1].to(dtype)

    # NOTE(review): the original synchronized after every stage; ops queued
    # on the same CUDA stream are already ordered, and an unguarded
    # torch.cuda.synchronize() raises on CPU-only hosts, so a single
    # guarded sync is kept before returning.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Collapse the time dimension (dim=2) to a single representative frame.
    return torch.mean(id_latents, dim=2, keepdim=True)
 
 
 
 
79
 
80
  def load_primary_models(pretrained_model_path):
81
  return (