zaidmehdi commited on
Commit
9d724cb
1 Parent(s): a025587

specify map_location

Browse files
Files changed (1) hide show
  1. src/main.py +2 -1
src/main.py CHANGED
@@ -17,7 +17,8 @@ models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
17
  model_file = os.path.join(models_dir, 'best_model_checkpoint.pth')
18
  if os.path.exists(model_file):
19
  with open(model_file, "rb") as f:
20
- checkpoint = torch.load(model_file)
 
21
  model.load_state_dict(checkpoint)
22
  else:
23
  print(f"Error: {model_file} not found.")
 
17
  model_file = os.path.join(models_dir, 'best_model_checkpoint.pth')
18
  if os.path.exists(model_file):
19
  with open(model_file, "rb") as f:
20
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
21
+ checkpoint = torch.load(model_file, map_location=device)
22
  model.load_state_dict(checkpoint)
23
  else:
24
  print(f"Error: {model_file} not found.")