wondervictor committed
Commit 9ed9b88 · verified · 1 parent: 113349b

Update model.py

Files changed (1)
  1. model.py +18 -6
model.py CHANGED
@@ -13,6 +13,10 @@ import time
 from autoregressive.models.generate import generate
 from condition.midas.depth import MidasDetector
 
+from controlnet_aux import (
+    MidasDetector,
+)
+
 models = {
     "canny": "checkpoints/canny_MR.safetensors",
     "depth": "checkpoints/depth_MR.safetensors",
@@ -48,7 +52,8 @@ class Model:
         self.gpt_model_canny = self.load_gpt(condition_type='canny')
         self.gpt_model_depth = self.load_gpt(condition_type='depth')
         self.get_control_canny = CannyDetector()
-        self.get_control_depth = MidasDetector('cuda')
+        # self.get_control_depth = MidasDetector('cuda')
+        self.get_control_depth = MidasDetector.from_pretrained("lllyasviel/Annotators")
 
     def to(self, device):
         self.gpt_model_canny.to('cuda')
@@ -196,11 +201,18 @@ class Model:
         # self.get_control_depth.model.to(self.device)
         # self.vq_model.to(self.device)
         image_tensor = torch.from_numpy(np.array(image)).to(self.device)
-        condition_img = torch.from_numpy(
-            self.get_control_depth(image_tensor)).unsqueeze(0)
-        condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
-        condition_img = condition_img.to(self.device)
-        condition_img = 2 * (condition_img / 255 - 0.5)
+        # condition_img = torch.from_numpy(
+        #     self.get_control_depth(image_tensor)).unsqueeze(0)
+        # condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
+        # condition_img = condition_img.to(self.device)
+        # condition_img = 2 * (condition_img / 255 - 0.5)
+
+        control_image = self.get_control_depth(
+            image=image,
+            image_resolution=512,
+            detect_resolution=512,
+        )
+
         prompts = [prompt] * 2
         caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
 
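The net effect of the commit: depth preprocessing moves from the repo's condition.midas.depth.MidasDetector (CUDA tensor in, NumPy depth map out, manually normalized) to the controlnet_aux MidasDetector loaded from the lllyasviel/Annotators weights. Note that the added import shadows the MidasDetector still imported from condition.midas.depth, so self.get_control_depth now resolves to the controlnet_aux implementation. Below is a minimal standalone sketch of the new path, assuming controlnet-aux is installed; "example.png" and the output filename are placeholders, not part of the commit.

from PIL import Image
from controlnet_aux import MidasDetector

# Downloads the MiDaS annotator weights from the lllyasviel/Annotators
# hub repo on first use, then caches them locally.
midas = MidasDetector.from_pretrained("lllyasviel/Annotators")

image = Image.open("example.png").convert("RGB")  # placeholder input

# The detector runs MiDaS at detect_resolution and returns a PIL depth
# map resized to image_resolution. The image is passed positionally here
# because the keyword name differs across controlnet_aux versions.
control_image = midas(image, detect_resolution=512, image_resolution=512)
control_image.save("depth.png")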
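For reference, the lines the commit comments out performed the old normalization: the uint8 depth map was lifted to a float tensor in [-1, 1] of shape (2, 3, H, W), the batch of two matching the duplicated prompt batch (prompts = [prompt] * 2). A self-contained sketch of that arithmetic, reconstructed from the removed lines rather than taken from the commit:

import numpy as np
import torch

def normalize_condition(depth: np.ndarray) -> torch.Tensor:
    """Equivalent of the commented-out preprocessing: a uint8 (H, W)
    depth map becomes a float tensor in [-1, 1] of shape (2, 3, H, W),
    duplicated across the batch and broadcast to three channels."""
    t = torch.from_numpy(depth).float()   # (H, W)
    t = t.unsqueeze(0).unsqueeze(0)       # (1, 1, H, W)
    t = t.repeat(2, 3, 1, 1)              # (2, 3, H, W)
    return 2 * (t / 255 - 0.5)            # [0, 255] -> [-1, 1]

# e.g. normalize_condition(np.zeros((512, 512), dtype=np.uint8)).shape
# -> torch.Size([2, 3, 512, 512])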