Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
@@ -437,6 +437,7 @@ class StreamMultiDiffusion(nn.Module):
|
|
437 |
encoder. Shape: (B, 3, H, W).
|
438 |
"""
|
439 |
latents = 1 / self.vae.config.scaling_factor * latents
|
|
|
440 |
imgs = self.vae.decode(latents).sample
|
441 |
imgs = (imgs / 2 + 0.5).clip_(0, 1)
|
442 |
return imgs
|
@@ -1125,8 +1126,6 @@ class StreamMultiDiffusion(nn.Module):
|
|
1125 |
t_list, # (B,)
|
1126 |
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
1127 |
return_dict=False,
|
1128 |
-
# TODO: Add SDXL Support.
|
1129 |
-
# added_cond_kwargs={'text_embeds': add_text_embeds, 'time_ids': add_time_ids},
|
1130 |
)[0] # (B, 4, h, w)
|
1131 |
|
1132 |
if self.bootstrap_steps[0] > 0:
|
@@ -1233,7 +1232,7 @@ class StreamMultiDiffusion(nn.Module):
|
|
1233 |
if no_decode:
|
1234 |
return latent
|
1235 |
|
1236 |
-
imgs = self.decode_latents(latent.half()) # (1, 3, H, W)
|
1237 |
img = T.ToPILImage()(imgs[0].cpu())
|
1238 |
return img
|
1239 |
|
|
|
437 |
encoder. Shape: (B, 3, H, W).
|
438 |
"""
|
439 |
latents = 1 / self.vae.config.scaling_factor * latents
|
440 |
+
latents = latents.to(dtype=self.vae.dtype, device=self.vae.device)
|
441 |
imgs = self.vae.decode(latents).sample
|
442 |
imgs = (imgs / 2 + 0.5).clip_(0, 1)
|
443 |
return imgs
|
|
|
1126 |
t_list, # (B,)
|
1127 |
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
1128 |
return_dict=False,
|
|
|
|
|
1129 |
)[0] # (B, 4, h, w)
|
1130 |
|
1131 |
if self.bootstrap_steps[0] > 0:
|
|
|
1232 |
if no_decode:
|
1233 |
return latent
|
1234 |
|
1235 |
+
imgs = self.decode_latents(latent.half()).float() # (1, 3, H, W)
|
1236 |
img = T.ToPILImage()(imgs[0].cpu())
|
1237 |
return img
|
1238 |
|