RamAnanth1 committed on
Commit 98fc92a
Parent: 6a8f88e

Add mmpose

Files changed (1): model.py (+130, -4)
model.py CHANGED
@@ -41,6 +41,73 @@ sys.path.append('T2I-Adapter')
 config_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/configs/stable-diffusion/'
 model_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/models/'
 
+
+def imshow_keypoints(img,
+                     pose_result,
+                     skeleton=None,
+                     kpt_score_thr=0.1,
+                     pose_kpt_color=None,
+                     pose_link_color=None,
+                     radius=4,
+                     thickness=1):
+    """Draw keypoints and links on an image.
+
+    Args:
+        img (ndarray): The image to draw poses on.
+        pose_result (list[kpts]): The poses to draw. Each element kpts is
+            a set of K keypoints as a Kx3 numpy.ndarray, where each
+            keypoint is represented as x, y, score.
+        skeleton (list[tuple], optional): Pairs of keypoint indices
+            defining the links to draw. Default: None.
+        kpt_score_thr (float, optional): Minimum score of keypoints
+            to be shown. Default: 0.1.
+        pose_kpt_color (np.ndarray[Nx3]): Color of N keypoints. If None,
+            the keypoints will not be drawn.
+        pose_link_color (np.ndarray[Mx3]): Color of M links. If None, the
+            links will not be drawn.
+        radius (int): Radius of keypoint circles.
+        thickness (int): Thickness of lines.
+    """
+
+    img_h, img_w, _ = img.shape
+    # draw on a black canvas of the same size, not on the input image
+    img = np.zeros(img.shape)
+
+    for idx, kpts in enumerate(pose_result):
+        # only render the first two detected persons
+        if idx > 1:
+            continue
+        kpts = np.array(kpts['keypoints'], copy=False)
+
+        # draw each point on image
+        if pose_kpt_color is not None:
+            assert len(pose_kpt_color) == len(kpts)
+
+            for kid, kpt in enumerate(kpts):
+                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
+
+                if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
+                    # skip the point that should not be drawn
+                    continue
+
+                color = tuple(int(c) for c in pose_kpt_color[kid])
+                cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
+
+        # draw links
+        if skeleton is not None and pose_link_color is not None:
+            assert len(pose_link_color) == len(skeleton)
+
+            for sk_id, sk in enumerate(skeleton):
+                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+
+                if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h
+                        or pos2[0] <= 0 or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h
+                        or kpts[sk[0], 2] < kpt_score_thr or kpts[sk[1], 2] < kpt_score_thr
+                        or pose_link_color[sk_id] is None):
+                    # skip the link that should not be drawn
+                    continue
+                color = tuple(int(c) for c in pose_link_color[sk_id])
+                cv2.line(img, pos1, pos2, color, thickness=thickness)
+
+    return img
+
+
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
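A quick way to sanity-check the new imshow_keypoints helper in isolation. Everything in the snippet (the keypoint values, the one-link skeleton, the colors) is made up for illustration; the actual call site further down uses a COCO skeleton and color palette that this diff does not show (presumably defined elsewhere in the file):

    import cv2
    import numpy as np

    # three fake keypoints for one person: two confident, one below the
    # default kpt_score_thr of 0.1 (it should be skipped)
    kpts = np.array([[100.0, 120.0, 0.9],
                     [160.0, 200.0, 0.8],
                     [40.0, 60.0, 0.05]])
    pose_result = [{'keypoints': kpts}]

    canvas = imshow_keypoints(
        np.zeros((256, 256, 3), dtype=np.uint8),   # only the shape is used
        pose_result,
        skeleton=[(0, 1)],                         # one link between points 0 and 1
        pose_kpt_color=np.array([[0, 255, 0]] * 3),
        pose_link_color=np.array([[0, 0, 255]]),
        radius=2,
        thickness=2)
    cv2.imwrite('pose_debug.png', canvas.astype(np.uint8))

Note that the function deliberately draws on np.zeros(img.shape) rather than on the input frame, so the result is the bare pose map the adapter expects, not an overlay.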
 
@@ -71,10 +138,36 @@ class Model:
         self.device = torch.device(
             'cuda:0' if torch.cuda.is_available() else 'cpu')
         self.model_dir = pathlib.Path(model_dir)
-
+        self.download_pose_models()
         self.download_models()
 
 
+    def download_pose_models(self) -> None:
+        ## mmpose
+        device = self.device
+        det_config_file = model_path + "faster_rcnn_r50_fpn_coco.py"
+        subprocess.run(shlex.split(f'wget {det_config_file} -O models/faster_rcnn_r50_fpn_coco.py'))
+        det_config = 'models/faster_rcnn_r50_fpn_coco.py'
+
+        det_checkpoint_file = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
+        subprocess.run(shlex.split(f'wget {det_checkpoint_file} -O models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'))
+        det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
+
+        pose_config_file = model_path + "hrnet_w48_coco_256x192.py"
+        subprocess.run(shlex.split(f'wget {pose_config_file} -O models/hrnet_w48_coco_256x192.py'))
+        pose_config = 'models/hrnet_w48_coco_256x192.py'
+
+        pose_checkpoint_file = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
+        subprocess.run(shlex.split(f'wget {pose_checkpoint_file} -O models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'))
+        pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
+
+        # stored on self so the inference code below can reuse them
+        self.det_cat_id = 1
+        self.bbox_thr = 0.2
+        ## detector
+        det_config_mmcv = mmcv.Config.fromfile(det_config)
+        self.det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
+        pose_config_mmcv = mmcv.Config.fromfile(pose_config)
+        self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
 
     def download_models(self) -> None:
         self.model_dir.mkdir(exist_ok=True, parents=True)
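One caveat: the three hunks shown account for the commit's entire +130/-4, so no import changes ship with it. The new code therefore assumes that mmcv, the mmdet/mmpose inference helpers, and the subprocess/shlex modules used for the wget calls are already importable in model.py. Under mmpose 0.x and mmdet 2.x, the names used here resolve to:

    import shlex
    import subprocess

    import mmcv
    from mmdet.apis import inference_detector, init_detector
    from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
                             process_mmdet_results)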
 
@@ -206,16 +299,49 @@ class Model:
         seed_everything(42)
 
         im = cv2.resize(input_img, (512, 512))
-        pose = img2tensor(im, bgr2rgb=True, float32=True)/255.
-        pose = pose.unsqueeze(0)
 
-        im_pose = tensor2img(pose)
+        image = im.copy()
+        im = img2tensor(im).unsqueeze(0) / 255.
+        mmdet_results = inference_detector(self.det_model, image)
+        # keep the person class bounding boxes
+        person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
+
+        # optional
+        return_heatmap = False
+        dataset = self.pose_model.cfg.data['test']['type']
+
+        # e.g. use ('backbone', ) to return backbone feature
+        output_layer_names = None
+        pose_results, returned_outputs = inference_top_down_pose_model(
+            self.pose_model,
+            image,
+            person_results,
+            bbox_thr=self.bbox_thr,
+            format='xyxy',
+            dataset=dataset,
+            dataset_info=None,
+            return_heatmap=return_heatmap,
+            outputs=output_layer_names)
+
+        # render the detected poses onto a black canvas
+        im_pose = imshow_keypoints(
+            image,
+            pose_results,
+            skeleton=skeleton,
+            pose_kpt_color=pose_kpt_color,
+            pose_link_color=pose_link_color,
+            radius=2,
+            thickness=2)
+
+        im_pose = cv2.resize(im_pose, (512, 512))
 
         c = model.get_learned_conditioning([prompt])
         nc = model.get_learned_conditioning([neg_prompt])
 
         with torch.no_grad():
             # extract condition features
+            pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
+            pose = pose.unsqueeze(0)
             features_adapter = self.model_ad_pose(pose.to(device))
 
             shape = [4, 64, 64]
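For readers unfamiliar with the mmpose glue: det_cat_id = 1 selects the COCO 'person' class (mmpose indexes mmdet's per-class result list with 1-based category ids), and process_mmdet_results does little more than filter and re-wrap boxes. A simplified sketch of its behavior, paraphrased from mmpose 0.x rather than quoted verbatim:

    def process_mmdet_results_sketch(mmdet_results, cat_id=1):
        # detectors with a mask head return (bbox_results, segm_results)
        det_results = mmdet_results[0] if isinstance(mmdet_results, tuple) else mmdet_results
        bboxes = det_results[cat_id - 1]            # Nx5 array: x1, y1, x2, y2, score
        return [{'bbox': bbox} for bbox in bboxes]

Boxes scoring below bbox_thr (0.2 here) are then discarded inside inference_top_down_pose_model before HRNet runs on each person crop.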