RamAnanth1 commited on
Commit
e102afc
1 Parent(s): 70e3488

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +33 -9
model.py CHANGED
@@ -19,6 +19,23 @@ import shlex
19
  import subprocess
20
  import sys
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  sys.path.append('T2I-Adapter')
23
 
24
  config_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/configs/stable-diffusion/'
@@ -101,6 +118,10 @@ class Model:
101
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
102
  global current_base
103
  device = 'cuda'
 
 
 
 
104
  # if current_base != base_model:
105
  # ckpt = os.path.join("models", base_model)
106
  # pl_sd = torch.load(ckpt, map_location="cpu")
@@ -132,8 +153,8 @@ class Model:
132
  im = im.float()
133
  im_edge = tensor2img(im)
134
 
135
- c = self.model.get_learned_conditioning([prompt])
136
- nc = self.model.get_learned_conditioning([neg_prompt])
137
 
138
  with torch.no_grad():
139
  # extract condition features
@@ -155,7 +176,7 @@ class Model:
155
  mode = 'sketch',
156
  con_strength = con_strength)
157
 
158
- x_samples_ddim = self.model.decode_first_stage(samples_ddim)
159
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
160
  x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).cpu().numpy()[0]
161
  x_samples_ddim = 255.*x_samples_ddim
@@ -166,7 +187,11 @@ class Model:
166
  @torch.inference_mode()
167
  def process_pose(self, input_img, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
168
  global current_base
169
- device = 'cuda'
 
 
 
 
170
  # if current_base != base_model:
171
  # ckpt = os.path.join("models", base_model)
172
  # pl_sd = torch.load(ckpt, map_location="cpu")
@@ -186,8 +211,8 @@ class Model:
186
 
187
  im_pose = tensor2img(pose)
188
 
189
- c = self.model.get_learned_conditioning([prompt])
190
- nc = self.model.get_learned_conditioning([neg_prompt])
191
 
192
  with torch.no_grad():
193
  # extract condition features
@@ -209,11 +234,10 @@ class Model:
209
  mode = 'sketch',
210
  con_strength = con_strength)
211
 
212
- x_samples_ddim = self.model.decode_first_stage(samples_ddim)
213
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
214
  x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).cpu().numpy()[0]
215
  x_samples_ddim = 255.*x_samples_ddim
216
  x_samples_ddim = x_samples_ddim.astype(np.uint8)
217
 
218
- return [im_pose, x_samples_ddim]
219
-
 
19
  import subprocess
20
  import sys
21
 
22
+ import mmcv
23
+ from mmdet.apis import inference_detector, init_detector
24
+ from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)
25
+
26
+ skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10],
27
+ [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
28
+
29
+ pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
30
+ [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0],
31
+ [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
32
+
33
+ pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
34
+ [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0],
35
+ [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
36
+ [51, 153, 255], [51, 153, 255], [51, 153, 255]]
37
+
38
+
39
  sys.path.append('T2I-Adapter')
40
 
41
  config_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/configs/stable-diffusion/'
 
118
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
119
  global current_base
120
  device = 'cuda'
121
+ if base_model == 'sd-v1-4.ckpt':
122
+ model = self.model
123
+ else:
124
+ model = self.model_anything
125
  # if current_base != base_model:
126
  # ckpt = os.path.join("models", base_model)
127
  # pl_sd = torch.load(ckpt, map_location="cpu")
 
153
  im = im.float()
154
  im_edge = tensor2img(im)
155
 
156
+ c = model.get_learned_conditioning([prompt])
157
+ nc = model.get_learned_conditioning([neg_prompt])
158
 
159
  with torch.no_grad():
160
  # extract condition features
 
176
  mode = 'sketch',
177
  con_strength = con_strength)
178
 
179
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
180
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
181
  x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).cpu().numpy()[0]
182
  x_samples_ddim = 255.*x_samples_ddim
 
187
  @torch.inference_mode()
188
  def process_pose(self, input_img, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
189
  global current_base
190
+ device = 'cuda'
191
+ if base_model == 'sd-v1-4.ckpt':
192
+ model = self.model
193
+ else:
194
+ model = self.model_anything
195
  # if current_base != base_model:
196
  # ckpt = os.path.join("models", base_model)
197
  # pl_sd = torch.load(ckpt, map_location="cpu")
 
211
 
212
  im_pose = tensor2img(pose)
213
 
214
+ c = model.get_learned_conditioning([prompt])
215
+ nc = model.get_learned_conditioning([neg_prompt])
216
 
217
  with torch.no_grad():
218
  # extract condition features
 
234
  mode = 'sketch',
235
  con_strength = con_strength)
236
 
237
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
238
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
239
  x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).cpu().numpy()[0]
240
  x_samples_ddim = 255.*x_samples_ddim
241
  x_samples_ddim = x_samples_ddim.astype(np.uint8)
242
 
243
+ return [im_pose, x_samples_ddim]