Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -65,34 +65,34 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
|
|
65 |
noise_shape = [batch_size, channels, frames, h, w]
|
66 |
|
67 |
# text cond
|
68 |
-
|
69 |
-
|
70 |
-
# img cond
|
71 |
-
img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
|
72 |
-
img_tensor = (img_tensor / 255. - 0.5) * 2
|
73 |
-
|
74 |
-
image_tensor_resized = transform(img_tensor) #3,256,256
|
75 |
-
videos = image_tensor_resized.unsqueeze(0) # bchw
|
76 |
|
77 |
-
|
|
|
|
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
fs = torch.tensor([fs], dtype=torch.long, device=model.device)
|
87 |
-
cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
|
88 |
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
model = model.cpu()
|
97 |
return video_path
|
98 |
|
|
|
65 |
noise_shape = [batch_size, channels, frames, h, w]
|
66 |
|
67 |
# text cond
|
68 |
+
with torch.no_grad(), torch.cuda.amp.autocast():
|
69 |
+
text_emb = model.get_learned_conditioning([prompt])
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
+
# img cond
|
72 |
+
img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
|
73 |
+
img_tensor = (img_tensor / 255. - 0.5) * 2
|
74 |
|
75 |
+
image_tensor_resized = transform(img_tensor) #3,256,256
|
76 |
+
videos = image_tensor_resized.unsqueeze(0) # bchw
|
77 |
+
|
78 |
+
z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw
|
79 |
+
|
80 |
+
img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
|
|
|
|
|
|
|
81 |
|
82 |
+
cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
|
83 |
+
img_emb = model.image_proj_model(cond_images)
|
84 |
+
|
85 |
+
imtext_cond = torch.cat([text_emb, img_emb], dim=1)
|
86 |
+
|
87 |
+
fs = torch.tensor([fs], dtype=torch.long, device=model.device)
|
88 |
+
cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
|
89 |
+
|
90 |
+
## inference
|
91 |
batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
|
92 |
+
## b,samples,c,t,h,w
|
93 |
+
|
94 |
+
video_path = './output.mp4'
|
95 |
+
save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
|
96 |
model = model.cpu()
|
97 |
return video_path
|
98 |
|