RamAnanth1 committed • Commit 98fc92a • Parent: 6a8f88e

Add mmpose
model.py
CHANGED
@@ -41,6 +41,73 @@ sys.path.append('T2I-Adapter')
 config_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/configs/stable-diffusion/'
 model_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/models/'
 
+
+def imshow_keypoints(img,
+                     pose_result,
+                     skeleton=None,
+                     kpt_score_thr=0.1,
+                     pose_kpt_color=None,
+                     pose_link_color=None,
+                     radius=4,
+                     thickness=1):
+    """Draw keypoints and links on an image.
+
+    Args:
+        img (ndarray): The image to draw poses on.
+        pose_result (list[kpts]): The poses to draw. Each element kpts is
+            a set of K keypoints as a Kx3 numpy.ndarray, where each
+            keypoint is represented as x, y, score.
+        kpt_score_thr (float, optional): Minimum score of keypoints
+            to be shown. Default: 0.1.
+        pose_kpt_color (np.array[Nx3]): Color of N keypoints. If None,
+            the keypoints will not be drawn.
+        pose_link_color (np.array[Mx3]): Color of M links. If None, the
+            links will not be drawn.
+        radius (int): Radius of the keypoint circles.
+        thickness (int): Thickness of lines.
+    """
+
+    img_h, img_w, _ = img.shape
+    # draw on a black canvas of the same size rather than on the input image
+    img = np.zeros(img.shape)
+
+    for idx, kpts in enumerate(pose_result):
+        # only render the first two detected persons
+        if idx > 1:
+            continue
+        kpts = kpts['keypoints']
+        kpts = np.array(kpts, copy=False)
+
+        # draw each point on image
+        if pose_kpt_color is not None:
+            assert len(pose_kpt_color) == len(kpts)
+
+            for kid, kpt in enumerate(kpts):
+                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
+
+                if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
+                    # skip the point that should not be drawn
+                    continue
+
+                color = tuple(int(c) for c in pose_kpt_color[kid])
+                cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
+
+        # draw links
+        if skeleton is not None and pose_link_color is not None:
+            assert len(pose_link_color) == len(skeleton)
+
+            for sk_id, sk in enumerate(skeleton):
+                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+
+                if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h
+                        or pos2[0] <= 0 or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h
+                        or kpts[sk[0], 2] < kpt_score_thr or kpts[sk[1], 2] < kpt_score_thr
+                        or pose_link_color[sk_id] is None):
+                    # skip the link that should not be drawn
+                    continue
+                color = tuple(int(c) for c in pose_link_color[sk_id])
+                cv2.line(img, pos1, pos2, color, thickness=thickness)
+
+    return img
+
+
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
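For reference, `imshow_keypoints` expects each element of `pose_result` to be a dict whose `'keypoints'` value is a Kx3 array of (x, y, score) rows, which is the structure mmpose's top-down inference returns. A minimal sketch of a call (the two-keypoint skeleton and BGR colors below are invented for illustration; `np` and `cv2` are the imports model.py already uses):

    import numpy as np

    # one hypothetical person with two keypoints, one (x, y, score) row each
    pose_result = [{'keypoints': np.array([[100., 120., 0.9],
                                           [140., 200., 0.8]])}]
    skeleton = [(0, 1)]                        # a single link between the keypoints
    pose_kpt_color = np.array([[0, 255, 0],    # one BGR color per keypoint
                               [0, 0, 255]])
    pose_link_color = np.array([[255, 0, 0]])  # one BGR color per link

    canvas = imshow_keypoints(np.zeros((512, 512, 3), dtype=np.uint8),
                              pose_result,
                              skeleton=skeleton,
                              pose_kpt_color=pose_kpt_color,
                              pose_link_color=pose_link_color)
    # canvas is a black 512x512 image with the two points and one link drawn on it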
@@ -71,10 +138,36 @@ class Model:
         self.device = torch.device(
             'cuda:0' if torch.cuda.is_available() else 'cpu')
         self.model_dir = pathlib.Path(model_dir)
-
+        self.download_pose_models()
         self.download_models()
 
 
+    def download_pose_models(self) -> None:
+        ## mmpose
+        device = "cuda"
+        det_config_file = model_path + "faster_rcnn_r50_fpn_coco.py"
+        subprocess.run(shlex.split(f'wget {det_config_file} -O models/faster_rcnn_r50_fpn_coco.py'))
+        det_config = 'models/faster_rcnn_r50_fpn_coco.py'
+
+        det_checkpoint_file = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
+        subprocess.run(shlex.split(f'wget {det_checkpoint_file} -O models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'))
+        det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
+
+        # the config must be saved under the same name it is read back from below
+        pose_config_file = model_path + "hrnet_w48_coco_256x192.py"
+        subprocess.run(shlex.split(f'wget {pose_config_file} -O models/hrnet_w48_coco_256x192.py'))
+        pose_config = 'models/hrnet_w48_coco_256x192.py'
+
+        pose_checkpoint_file = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
+        subprocess.run(shlex.split(f'wget {pose_checkpoint_file} -O models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'))
+        pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
+
+        # stored on self so the inference code can reach them later
+        self.det_cat_id = 1
+        self.bbox_thr = 0.2
+        ## detector
+        det_config_mmcv = mmcv.Config.fromfile(det_config)
+        self.det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
+        pose_config_mmcv = mmcv.Config.fromfile(pose_config)
+        self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
 
     def download_models(self) -> None:
         self.model_dir.mkdir(exist_ok=True, parents=True)
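Note that `download_pose_models` shells out to `wget` and re-downloads every file on each startup. A hedged alternative sketch using only the standard library and skipping files that are already present (the URL is one the commit uses; the `fetch` helper is mine, not part of the Space):

    import pathlib
    import urllib.request

    def fetch(url: str, dest: str) -> str:
        """Download url into dest unless the file already exists."""
        path = pathlib.Path(dest)
        path.parent.mkdir(exist_ok=True, parents=True)
        if not path.exists():
            urllib.request.urlretrieve(url, str(path))
        return str(path)

    det_checkpoint = fetch(
        'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
        'faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
        'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth')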
@@ -206,16 +299,49 @@ class Model:
         seed_everything(42)
 
         im = cv2.resize(input_img,(512,512))
-        pose = img2tensor(im, bgr2rgb=True, float32=True)/255.
-        pose = pose.unsqueeze(0)
 
-
+        image = im.copy()
+        im = img2tensor(im).unsqueeze(0)/255.
+        mmdet_results = inference_detector(self.det_model, image)
+        # keep the person class bounding boxes.
+        person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
+
+        # optional
+        return_heatmap = False
+        dataset = self.pose_model.cfg.data['test']['type']
+
+        # e.g. use ('backbone', ) to return backbone feature
+        output_layer_names = None
+        pose_results, returned_outputs = inference_top_down_pose_model(
+            self.pose_model,
+            image,
+            person_results,
+            bbox_thr=self.bbox_thr,
+            format='xyxy',
+            dataset=dataset,
+            dataset_info=None,
+            return_heatmap=return_heatmap,
+            outputs=output_layer_names)
+
+        # show the results
+        im_pose = imshow_keypoints(
+            image,
+            pose_results,
+            skeleton=skeleton,
+            pose_kpt_color=pose_kpt_color,
+            pose_link_color=pose_link_color,
+            radius=2,
+            thickness=2)
+
+        im_pose = cv2.resize(im_pose,(512,512))
 
         c = model.get_learned_conditioning([prompt])
         nc = model.get_learned_conditioning([neg_prompt])
 
         with torch.no_grad():
             # extract condition features
+            pose = img2tensor(im_pose, bgr2rgb=True, float32=True)/255.
+            pose = pose.unsqueeze(0)
             features_adapter = self.model_ad_pose(pose.to(device))
 
             shape = [4, 64, 64]
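The tensor handed to the adapter is now the rendered keypoint map, not the input photo: it is converted from an HWC BGR image to a normalized CHW tensor and given a batch dimension before `self.model_ad_pose` extracts the conditioning features. A quick shape check of that conversion (assuming `img2tensor` is basicsr's, which is what the T2I-Adapter code imports; the zero image is a stand-in for `im_pose`):

    import numpy as np
    import torch
    from basicsr.utils import img2tensor

    im_pose = np.zeros((512, 512, 3), dtype=np.float32)  # stand-in keypoint map
    pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.  # HWC BGR -> CHW RGB
    pose = pose.unsqueeze(0)                             # add the batch dimension
    assert pose.shape == torch.Size([1, 3, 512, 512])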
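Taken together, the hunks implement a standard two-stage, top-down pose pipeline: a Faster R-CNN detector proposes boxes, `process_mmdet_results` keeps the person class (det_cat_id 1 in mmpose's convention), HRNet regresses keypoints inside each kept box, and `imshow_keypoints` renders them on a black canvas for the adapter. A condensed sketch of that flow, reusing the calls the commit adds (`extract_pose_map` is my name for it, not a method the Space defines; `skeleton`, `pose_kpt_color`, and `pose_link_color` are assumed to be module-level constants defined elsewhere in model.py):

    def extract_pose_map(self, image):
        # stage 1: person detection
        mmdet_results = inference_detector(self.det_model, image)
        person_results = process_mmdet_results(mmdet_results, self.det_cat_id)

        # stage 2: top-down keypoint estimation inside the kept boxes
        pose_results, _ = inference_top_down_pose_model(
            self.pose_model,
            image,
            person_results,
            bbox_thr=self.bbox_thr,
            format='xyxy',
            dataset=self.pose_model.cfg.data['test']['type'],
            dataset_info=None)

        # render the skeleton on a black canvas of the same size as image
        return imshow_keypoints(image, pose_results,
                                skeleton=skeleton,
                                pose_kpt_color=pose_kpt_color,
                                pose_link_color=pose_link_color,
                                radius=2, thickness=2)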