ironjr commited on
Commit
4d7f709
1 Parent(s): 4d4572e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -3
model.py CHANGED
@@ -437,6 +437,7 @@ class StreamMultiDiffusion(nn.Module):
437
  encoder. Shape: (B, 3, H, W).
438
  """
439
  latents = 1 / self.vae.config.scaling_factor * latents
 
440
  imgs = self.vae.decode(latents).sample
441
  imgs = (imgs / 2 + 0.5).clip_(0, 1)
442
  return imgs
@@ -1125,8 +1126,6 @@ class StreamMultiDiffusion(nn.Module):
1125
  t_list, # (B,)
1126
  encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1127
  return_dict=False,
1128
- # TODO: Add SDXL Support.
1129
- # added_cond_kwargs={'text_embeds': add_text_embeds, 'time_ids': add_time_ids},
1130
  )[0] # (B, 4, h, w)
1131
 
1132
  if self.bootstrap_steps[0] > 0:
@@ -1233,7 +1232,7 @@ class StreamMultiDiffusion(nn.Module):
1233
  if no_decode:
1234
  return latent
1235
 
1236
- imgs = self.decode_latents(latent.half()) # (1, 3, H, W)
1237
  img = T.ToPILImage()(imgs[0].cpu())
1238
  return img
1239
 
 
437
  encoder. Shape: (B, 3, H, W).
438
  """
439
  latents = 1 / self.vae.config.scaling_factor * latents
440
+ latents = latents.to(dtype=self.vae.dtype, device=self.vae.device)
441
  imgs = self.vae.decode(latents).sample
442
  imgs = (imgs / 2 + 0.5).clip_(0, 1)
443
  return imgs
 
1126
  t_list, # (B,)
1127
  encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1128
  return_dict=False,
 
 
1129
  )[0] # (B, 4, h, w)
1130
 
1131
  if self.bootstrap_steps[0] > 0:
 
1232
  if no_decode:
1233
  return latent
1234
 
1235
+ imgs = self.decode_latents(latent.half()).float() # (1, 3, H, W)
1236
  img = T.ToPILImage()(imgs[0].cpu())
1237
  return img
1238