WillHeld commited on
Commit
6b02748
1 Parent(s): db669a0

No Cuda Stuff for Zero

Browse files
Files changed (1) hide show
  1. models/salmonn.py +2 -2
models/salmonn.py CHANGED
@@ -66,7 +66,7 @@ class SALMONN(nn.Module):
66
 
67
  # beats
68
  self.beats_ckpt = beats_path
69
- beats_checkpoint = torch.load(self.beats_ckpt, map_location=device)
70
  beats_cfg = BEATsConfig(beats_checkpoint["cfg"])
71
  beats = BEATs(beats_cfg)
72
  beats.load_state_dict(beats_checkpoint["model"])
@@ -130,7 +130,7 @@ class SALMONN(nn.Module):
130
  ).to(device)
131
 
132
  # load ckpt
133
- ckpt_dict = torch.load(ckpt)["model"]
134
  self.load_state_dict(ckpt_dict, strict=False)
135
 
136
  def generate(
 
66
 
67
  # beats
68
  self.beats_ckpt = beats_path
69
+ beats_checkpoint = torch.load(self.beats_ckpt)
70
  beats_cfg = BEATsConfig(beats_checkpoint["cfg"])
71
  beats = BEATs(beats_cfg)
72
  beats.load_state_dict(beats_checkpoint["model"])
 
130
  ).to(device)
131
 
132
  # load ckpt
133
+ ckpt_dict = torch.load(ckpt, map_location="cpu")["model"]
134
  self.load_state_dict(ckpt_dict, strict=False)
135
 
136
  def generate(