Spaces:
Running
on
Zero
Running
on
Zero
No Cuda Stuff for Zero
Browse files- models/salmonn.py +2 -2
models/salmonn.py
CHANGED
@@ -66,7 +66,7 @@ class SALMONN(nn.Module):
|
|
66 |
|
67 |
# beats
|
68 |
self.beats_ckpt = beats_path
|
69 |
-
beats_checkpoint = torch.load(self.beats_ckpt
|
70 |
beats_cfg = BEATsConfig(beats_checkpoint["cfg"])
|
71 |
beats = BEATs(beats_cfg)
|
72 |
beats.load_state_dict(beats_checkpoint["model"])
|
@@ -130,7 +130,7 @@ class SALMONN(nn.Module):
|
|
130 |
).to(device)
|
131 |
|
132 |
# load ckpt
|
133 |
-
ckpt_dict = torch.load(ckpt)["model"]
|
134 |
self.load_state_dict(ckpt_dict, strict=False)
|
135 |
|
136 |
def generate(
|
|
|
66 |
|
67 |
# beats
|
68 |
self.beats_ckpt = beats_path
|
69 |
+
beats_checkpoint = torch.load(self.beats_ckpt)
|
70 |
beats_cfg = BEATsConfig(beats_checkpoint["cfg"])
|
71 |
beats = BEATs(beats_cfg)
|
72 |
beats.load_state_dict(beats_checkpoint["model"])
|
|
|
130 |
).to(device)
|
131 |
|
132 |
# load ckpt
|
133 |
+
ckpt_dict = torch.load(ckpt, map_location="cpu")["model"]
|
134 |
self.load_state_dict(ckpt_dict, strict=False)
|
135 |
|
136 |
def generate(
|