WindVChen commited on
Commit
a9ce22e
β€’
1 Parent(s): ef47508

Upload 2 files

Browse files
efficient_inference_for_square_image.py CHANGED
@@ -335,7 +335,7 @@ def main_process(opt, composite_image=None, mask=None):
335
 
336
  model = build_model(opt).to(opt.device)
337
 
338
- load_dict = torch.load(opt.pretrained)['model']
339
  for k in load_dict.keys():
340
  if k not in model.state_dict().keys():
341
  print(f"Skip {k}")
 
335
 
336
  model = build_model(opt).to(opt.device)
337
 
338
+ load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
339
  for k in load_dict.keys():
340
  if k not in model.state_dict().keys():
341
  print(f"Skip {k}")
inference_for_arbitrary_resolution_image.py CHANGED
@@ -327,7 +327,7 @@ def main_process(opt, composite_image=None, mask=None):
327
 
328
  model = build_model(opt).to(opt.device)
329
 
330
- load_dict = torch.load(opt.pretrained)['model']
331
  for k in load_dict.keys():
332
  if k not in model.state_dict().keys():
333
  print(f"Skip {k}")
 
327
 
328
  model = build_model(opt).to(opt.device)
329
 
330
+ load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
331
  for k in load_dict.keys():
332
  if k not in model.state_dict().keys():
333
  print(f"Skip {k}")