Spaces:
Runtime error
Runtime error
RamAnanth1
commited on
Commit
•
f016e80
1
Parent(s):
55b8905
Update model.py
Browse files
model.py
CHANGED
@@ -138,6 +138,7 @@ class Model:
|
|
138 |
self.device = torch.device(
|
139 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
140 |
self.model_dir = pathlib.Path(model_dir)
|
|
|
141 |
self.download_pose_models()
|
142 |
self.download_models()
|
143 |
|
@@ -147,14 +148,14 @@ class Model:
|
|
147 |
device = "cuda"
|
148 |
det_config_file = model_path+"faster_rcnn_r50_fpn_coco.py"
|
149 |
subprocess.run(shlex.split(f'wget {det_config_file} -O models/faster_rcnn_r50_fpn_coco.py'))
|
150 |
-
det_config = 'models/faster_rcnn_r50_fpn_coco.py'
|
151 |
|
152 |
det_checkpoint_file = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
|
153 |
subprocess.run(shlex.split(f'wget {det_checkpoint_file} -O models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'))
|
154 |
det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
|
155 |
|
156 |
pose_config_file = model_path+"hrnet_w48_coco_256x192.py"
|
157 |
-
subprocess.run(shlex.split(f'wget {pose_config_file} -O models/
|
158 |
pose_config = 'models/hrnet_w48_coco_256x192.py'
|
159 |
|
160 |
pose_checkpoint_file = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
|
@@ -170,7 +171,6 @@ class Model:
|
|
170 |
self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
|
171 |
|
172 |
def download_models(self) -> None:
|
173 |
-
self.model_dir.mkdir(exist_ok=True, parents=True)
|
174 |
device = 'cuda'
|
175 |
|
176 |
config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
|
|
|
138 |
self.device = torch.device(
|
139 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
140 |
self.model_dir = pathlib.Path(model_dir)
|
141 |
+
self.model_dir.mkdir(exist_ok=True, parents=True)
|
142 |
self.download_pose_models()
|
143 |
self.download_models()
|
144 |
|
|
|
148 |
device = "cuda"
|
149 |
det_config_file = model_path+"faster_rcnn_r50_fpn_coco.py"
|
150 |
subprocess.run(shlex.split(f'wget {det_config_file} -O models/faster_rcnn_r50_fpn_coco.py'))
|
151 |
+
det_config = 'models/faster_rcnn_r50_fpn_coco.py'
|
152 |
|
153 |
det_checkpoint_file = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
|
154 |
subprocess.run(shlex.split(f'wget {det_checkpoint_file} -O models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'))
|
155 |
det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
|
156 |
|
157 |
pose_config_file = model_path+"hrnet_w48_coco_256x192.py"
|
158 |
+
subprocess.run(shlex.split(f'wget {pose_config_file} -O models/hrnet_w48_coco_256x192.py'))
|
159 |
pose_config = 'models/hrnet_w48_coco_256x192.py'
|
160 |
|
161 |
pose_checkpoint_file = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
|
|
|
171 |
self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
|
172 |
|
173 |
def download_models(self) -> None:
|
|
|
174 |
device = 'cuda'
|
175 |
|
176 |
config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
|