wondervictor committed on
Commit
108d3f4
·
verified ·
1 Parent(s): 0342c0e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +20 -11
model.py CHANGED
@@ -26,8 +26,9 @@ class Model:
26
  self.task_name = ""
27
  self.vq_model = self.load_vq()
28
  self.t5_model = self.load_t5()
29
- self.gpt_model_edge = self.load_gpt(condition_type='edge')
30
- self.gpt_model_depth = self.load_gpt(condition_type='depth')
 
31
  self.preprocessor = Preprocessor()
32
 
33
  def to(self, device):
@@ -45,7 +46,7 @@ class Model:
45
  return vq_model
46
 
47
  def load_gpt(self, condition_type='edge'):
48
- gpt_ckpt = models[condition_type]
49
  # precision = torch.bfloat16
50
  precision = torch.float32
51
  latent_size = 512 // 16
@@ -56,12 +57,19 @@ class Model:
56
  condition_type=condition_type,
57
  adapter_size='base',
58
  ).to(device='cpu', dtype=precision)
59
- model_weight = load_file(gpt_ckpt)
60
- gpt_model.load_state_dict(model_weight, strict=False)
61
- gpt_model.eval()
62
- print("gpt model is loaded")
63
  return gpt_model
64
 
 
 
 
 
 
 
 
65
  def load_t5(self):
66
  # precision = torch.bfloat16
67
  precision = torch.float32
@@ -92,7 +100,8 @@ class Model:
92
  preprocessor_name: str,
93
  ) -> list[PIL.Image.Image]:
94
  self.t5_model.model.to('cuda').to(torch.bfloat16)
95
- self.gpt_model_edge.to('cuda').to(torch.bfloat16)
 
96
  self.vq_model.to('cuda')
97
  if isinstance(image, np.ndarray):
98
  image = Image.fromarray(image)
@@ -114,10 +123,10 @@ class Model:
114
  condition_img = condition_img.resize((512,512))
115
  W, H = condition_img.size
116
 
117
- condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(1,1,1,1)
118
  condition_img = condition_img.to(self.device)
119
  condition_img = 2*(condition_img/255 - 0.5)
120
- prompts = [prompt] * 1
121
  caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
122
 
123
  print(f"processing left-padding...")
@@ -137,7 +146,7 @@ class Model:
137
  t1 = time.time()
138
  print(caption_embs.device)
139
  index_sample = generate(
140
- self.gpt_model_edge,
141
  c_indices,
142
  (H // 16) * (W // 16),
143
  c_emb_masks,
 
26
  self.task_name = ""
27
  self.vq_model = self.load_vq()
28
  self.t5_model = self.load_t5()
29
+ # self.gpt_model_edge = self.load_gpt(condition_type='edge')
30
+ # self.gpt_model_depth = self.load_gpt(condition_type='depth')
31
+ self.gpt_model = self.load_gpt()
32
  self.preprocessor = Preprocessor()
33
 
34
  def to(self, device):
 
46
  return vq_model
47
 
48
  def load_gpt(self, condition_type='edge'):
49
+ # gpt_ckpt = models[condition_type]
50
  # precision = torch.bfloat16
51
  precision = torch.float32
52
  latent_size = 512 // 16
 
57
  condition_type=condition_type,
58
  adapter_size='base',
59
  ).to(device='cpu', dtype=precision)
60
+ # model_weight = load_file(gpt_ckpt)
61
+ # gpt_model.load_state_dict(model_weight, strict=False)
62
+ # gpt_model.eval()
63
+ # print("gpt model is loaded")
64
  return gpt_model
65
 
66
+ def load_gpt_weight(self, condition_type='edge'):
67
+ gpt_ckpt = models[condition_type]
68
+ model_weight = load_file(gpt_ckpt)
69
+ self.gpt_model.load_state_dict(model_weight, strict=False)
70
+ self.gpt_model.eval()
71
+ # print("gpt model is loaded")
72
+
73
  def load_t5(self):
74
  # precision = torch.bfloat16
75
  precision = torch.float32
 
100
  preprocessor_name: str,
101
  ) -> list[PIL.Image.Image]:
102
  self.t5_model.model.to('cuda').to(torch.bfloat16)
103
+ self.load_gpt_weight('edge')
104
+ self.gpt_model.to('cuda').to(torch.bfloat16)
105
  self.vq_model.to('cuda')
106
  if isinstance(image, np.ndarray):
107
  image = Image.fromarray(image)
 
123
  condition_img = condition_img.resize((512,512))
124
  W, H = condition_img.size
125
 
126
+ condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(3,1,1,1)
127
  condition_img = condition_img.to(self.device)
128
  condition_img = 2*(condition_img/255 - 0.5)
129
+ prompts = [prompt] * 3
130
  caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
131
 
132
  print(f"processing left-padding...")
 
146
  t1 = time.time()
147
  print(caption_embs.device)
148
  index_sample = generate(
149
+ self.gpt_model,
150
  c_indices,
151
  (H // 16) * (W // 16),
152
  c_emb_masks,