wondervictor commited on
Commit
7d5e8b3
·
verified ·
1 Parent(s): 122437d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +6 -6
model.py CHANGED
@@ -74,8 +74,8 @@ class Model:
74
 
75
  def load_gpt(self, condition_type='canny'):
76
  gpt_ckpt = models[condition_type]
77
- precision = torch.bfloat16
78
- # precision = torch.float32
79
  latent_size = 768 // 16
80
  gpt_model = GPT_models["GPT-XL"](
81
  block_size=latent_size**2,
@@ -93,8 +93,8 @@ class Model:
93
  return gpt_model
94
 
95
  def load_t5(self):
96
- precision = torch.bfloat16
97
- # precision = torch.float32
98
  t5_model = T5Embedder(
99
  device=self.device,
100
  local_cache=True,
@@ -124,8 +124,8 @@ class Model:
124
  W, H = image.size
125
  print(W, H)
126
  # self.gpt_model_depth.to('cpu')
127
- self.t5_model.model.to('cuda')
128
- self.gpt_model_canny.to('cuda')
129
  self.vq_model.to('cuda')
130
  # print("after cuda", self.gpt_model_canny.adapter.model.embeddings.patch_embeddings.projection.weight)
131
 
 
74
 
75
  def load_gpt(self, condition_type='canny'):
76
  gpt_ckpt = models[condition_type]
77
+ # precision = torch.bfloat16
78
+ precision = torch.float32
79
  latent_size = 768 // 16
80
  gpt_model = GPT_models["GPT-XL"](
81
  block_size=latent_size**2,
 
93
  return gpt_model
94
 
95
  def load_t5(self):
96
+ # precision = torch.bfloat16
97
+ precision = torch.float32
98
  t5_model = T5Embedder(
99
  device=self.device,
100
  local_cache=True,
 
124
  W, H = image.size
125
  print(W, H)
126
  # self.gpt_model_depth.to('cpu')
127
+ self.t5_model.model.to('cuda').to(torch.bfloat16)
128
+ self.gpt_model_canny.to('cuda').to(torch.bfloat16)
129
  self.vq_model.to('cuda')
130
  # print("after cuda", self.gpt_model_canny.adapter.model.embeddings.patch_embeddings.projection.weight)
131