AntonVoronov commited on
Commit
94cd78d
1 Parent(s): ebf782e

make default init in torch.bfloat16

Browse files
Files changed (1) hide show
  1. models/pipeline.py +7 -5
models/pipeline.py CHANGED
@@ -13,11 +13,13 @@ class SwittiPipeline:
13
  text_encoder_path = "openai/clip-vit-large-patch14"
14
  text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
15
 
16
- def __init__(self, switti, vae, text_encoder, text_encoder_2, device):
17
- self.switti = switti
18
- self.vae = vae
19
- self.text_encoder = text_encoder
20
- self.text_encoder_2 = text_encoder_2
 
 
21
 
22
  self.switti.eval()
23
  self.vae.eval()
 
13
  text_encoder_path = "openai/clip-vit-large-patch14"
14
  text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
15
 
16
+ def __init__(self, switti, vae, text_encoder, text_encoder_2, device,
17
+ dtype=torch.bfloat16,
18
+ ):
19
+ self.switti = switti.to(dtype)
20
+ self.vae = vae.to(dtype)
21
+ self.text_encoder = text_encoder.to(dtype)
22
+ self.text_encoder_2 = text_encoder_2.to(dtype)
23
 
24
  self.switti.eval()
25
  self.vae.eval()