RamAnanth1 commited on
Commit
f016e80
1 Parent(s): 55b8905

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -3
model.py CHANGED
@@ -138,6 +138,7 @@ class Model:
138
  self.device = torch.device(
139
  'cuda:0' if torch.cuda.is_available() else 'cpu')
140
  self.model_dir = pathlib.Path(model_dir)
 
141
  self.download_pose_models()
142
  self.download_models()
143
 
@@ -147,14 +148,14 @@ class Model:
147
  device = "cuda"
148
  det_config_file = model_path+"faster_rcnn_r50_fpn_coco.py"
149
  subprocess.run(shlex.split(f'wget {det_config_file} -O models/faster_rcnn_r50_fpn_coco.py'))
150
- det_config = 'models/faster_rcnn_r50_fpn_coco.py'
151
 
152
  det_checkpoint_file = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
153
  subprocess.run(shlex.split(f'wget {det_checkpoint_file} -O models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'))
154
  det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
155
 
156
  pose_config_file = model_path+"hrnet_w48_coco_256x192.py"
157
- subprocess.run(shlex.split(f'wget {pose_config_file} -O models/rnet_w48_coco_256x192.py'))
158
  pose_config = 'models/hrnet_w48_coco_256x192.py'
159
 
160
  pose_checkpoint_file = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
@@ -170,7 +171,6 @@ class Model:
170
  self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
171
 
172
  def download_models(self) -> None:
173
- self.model_dir.mkdir(exist_ok=True, parents=True)
174
  device = 'cuda'
175
 
176
  config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
 
138
  self.device = torch.device(
139
  'cuda:0' if torch.cuda.is_available() else 'cpu')
140
  self.model_dir = pathlib.Path(model_dir)
141
+ self.model_dir.mkdir(exist_ok=True, parents=True)
142
  self.download_pose_models()
143
  self.download_models()
144
 
 
148
  device = "cuda"
149
  det_config_file = model_path+"faster_rcnn_r50_fpn_coco.py"
150
  subprocess.run(shlex.split(f'wget {det_config_file} -O models/faster_rcnn_r50_fpn_coco.py'))
151
+ det_config = 'models/faster_rcnn_r50_fpn_coco.py'
152
 
153
  det_checkpoint_file = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
154
  subprocess.run(shlex.split(f'wget {det_checkpoint_file} -O models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'))
155
  det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
156
 
157
  pose_config_file = model_path+"hrnet_w48_coco_256x192.py"
158
+ subprocess.run(shlex.split(f'wget {pose_config_file} -O models/hrnet_w48_coco_256x192.py'))
159
  pose_config = 'models/hrnet_w48_coco_256x192.py'
160
 
161
  pose_checkpoint_file = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
 
171
  self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
172
 
173
  def download_models(self) -> None:
 
174
  device = 'cuda'
175
 
176
  config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")