Spaces:
Runtime error
Runtime error
zerogpu
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import argparse
|
2 |
import glob
|
3 |
import json
|
@@ -94,11 +95,12 @@ def create_msg(name, data):
|
|
94 |
def send_msgs(msgs):
|
95 |
return json.dumps(msgs)
|
96 |
|
97 |
-
|
98 |
def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
|
99 |
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
100 |
gen_events, temp, top_p, top_k, allow_cc):
|
101 |
model = models[model_name]
|
|
|
102 |
tokenizer = model.tokenizer
|
103 |
bpm = int(bpm)
|
104 |
if time_sig == "auto":
|
@@ -300,10 +302,9 @@ if __name__ == "__main__":
|
|
300 |
for name, (repo_id, path, config) in models_info.items():
|
301 |
model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
|
302 |
model = MIDIModel(config=MIDIModelConfig.from_name(config))
|
303 |
-
ckpt = torch.load(model_path, map_location="cpu")
|
304 |
state_dict = ckpt.get("state_dict", ckpt)
|
305 |
model.load_state_dict(state_dict, strict=False)
|
306 |
-
model.to(device=opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
|
307 |
models[name] = model
|
308 |
|
309 |
load_javascript()
|
|
|
1 |
+
import spaces
|
2 |
import argparse
|
3 |
import glob
|
4 |
import json
|
|
|
95 |
def send_msgs(msgs):
|
96 |
return json.dumps(msgs)
|
97 |
|
98 |
+
@spaces.GPU()
|
99 |
def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
|
100 |
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
101 |
gen_events, temp, top_p, top_k, allow_cc):
|
102 |
model = models[model_name]
|
103 |
+
model.to(device=opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
|
104 |
tokenizer = model.tokenizer
|
105 |
bpm = int(bpm)
|
106 |
if time_sig == "auto":
|
|
|
302 |
for name, (repo_id, path, config) in models_info.items():
|
303 |
model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
|
304 |
model = MIDIModel(config=MIDIModelConfig.from_name(config))
|
305 |
+
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
306 |
state_dict = ckpt.get("state_dict", ckpt)
|
307 |
model.load_state_dict(state_dict, strict=False)
|
|
|
308 |
models[name] = model
|
309 |
|
310 |
load_javascript()
|