wondervictor committed on
Commit
40c9ea6
·
verified ·
1 Parent(s): 180819f

Update model_new.py

Browse files
Files changed (1) hide show
  1. model_new.py +20 -8
model_new.py CHANGED
@@ -180,21 +180,33 @@ class Model:
180
  top_k: int,
181
  top_p: int,
182
  seed: int,
 
 
183
  ) -> list[PIL.Image.Image]:
184
- image = resize_image_to_16_multiple(image, 'depth')
185
- W, H = image.size
186
- print(W, H)
187
  self.gpt_model_canny.to('cpu')
188
  self.t5_model.model.to(self.device)
189
  self.gpt_model_depth.to(self.device)
190
  self.get_control_depth.model.to(self.device)
191
  self.vq_model.to(self.device)
192
- image_tensor = torch.from_numpy(np.array(image)).to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- condition_img = 2 * (image_tensor / 255 - 0.5)
195
- print(condition_img.shape)
196
- condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2, 1, 1, 1)
197
-
198
  prompts = [prompt] * 2
199
  caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
200
 
 
180
  top_k: int,
181
  top_p: int,
182
  seed: int,
183
+ control_strength: float,
184
+ preprocessor_name: str
185
  ) -> list[PIL.Image.Image]:
 
 
 
186
  self.gpt_model_canny.to('cpu')
187
  self.t5_model.model.to(self.device)
188
  self.gpt_model_depth.to(self.device)
189
  self.get_control_depth.model.to(self.device)
190
  self.vq_model.to(self.device)
191
+ if isinstance(image, np.ndarray):
192
+ image = Image.fromarray(image)
193
+ origin_W, origin_H = image.size
194
+ # print(image)
195
+ if preprocessor_name == 'depth':
196
+ self.preprocessor.load("Depth")
197
+ condition_img = self.preprocessor(
198
+ image=image,
199
+ image_resolution=512,
200
+ detect_resolution=512,
201
+ )
202
+ elif preprocessor_name == 'No preprocess':
203
+ condition_img = image
204
+ condition_img = condition_img.resize((512,512))
205
+ W, H = condition_img.size
206
 
207
+ condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(2,1,1,1)
208
+ condition_img = condition_img.to(self.device)
209
+ condition_img = 2*(condition_img/255 - 0.5)
 
210
  prompts = [prompt] * 2
211
  caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
212