benibraz commited on
Commit
0b28ec1
1 Parent(s): 660f849

load unet in bfloat16

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -45,7 +45,7 @@ def load_vae(vae_dir):
45
  vae = CausalVideoAutoencoder.from_config(vae_config)
46
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
47
  vae.load_state_dict(vae_state_dict)
48
- return vae.cuda().to(torch.bfloat16)
49
 
50
 
51
  def load_unet(unet_dir):
@@ -55,7 +55,7 @@ def load_unet(unet_dir):
55
  transformer = Transformer3DModel.from_config(transformer_config)
56
  unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
57
  transformer.load_state_dict(unet_state_dict, strict=True)
58
- return transformer.to(device)
59
 
60
 
61
  def load_scheduler(scheduler_dir):
 
45
  vae = CausalVideoAutoencoder.from_config(vae_config)
46
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
47
  vae.load_state_dict(vae_state_dict)
48
+ return vae.to(device=device, dtype=torch.bfloat16)
49
 
50
 
51
  def load_unet(unet_dir):
 
55
  transformer = Transformer3DModel.from_config(transformer_config)
56
  unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
57
  transformer.load_state_dict(unet_state_dict, strict=True)
58
+ return transformer.to(device=device, dtype=torch.bfloat16)
59
 
60
 
61
  def load_scheduler(scheduler_dir):