wondervictor commited on
Commit
d9a7d69
·
verified ·
1 Parent(s): 6efa488

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -4
model.py CHANGED
@@ -74,8 +74,8 @@ class Model:
74
 
75
  def load_gpt(self, condition_type='canny'):
76
  gpt_ckpt = models[condition_type]
77
- precision = torch.bfloat16
78
- # precision = torch.float32
79
  latent_size = 768 // 16
80
  gpt_model = GPT_models["GPT-XL"](
81
  block_size=latent_size**2,
@@ -91,8 +91,8 @@ class Model:
91
  return gpt_model
92
 
93
  def load_t5(self):
94
- precision = torch.bfloat16
95
- # precision = torch.float32
96
  t5_model = T5Embedder(
97
  device=self.device,
98
  local_cache=True,
 
74
 
75
  def load_gpt(self, condition_type='canny'):
76
  gpt_ckpt = models[condition_type]
77
+ # precision = torch.bfloat16
78
+ precision = torch.float32
79
  latent_size = 768 // 16
80
  gpt_model = GPT_models["GPT-XL"](
81
  block_size=latent_size**2,
 
91
  return gpt_model
92
 
93
  def load_t5(self):
94
+ # precision = torch.bfloat16
95
+ precision = torch.float32
96
  t5_model = T5Embedder(
97
  device=self.device,
98
  local_cache=True,