skytnt commited on
Commit
8fdb145
1 Parent(s): fd012a7
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import argparse
2
  import glob
3
  import json
@@ -94,11 +95,12 @@ def create_msg(name, data):
94
  def send_msgs(msgs):
95
  return json.dumps(msgs)
96
 
97
-
98
  def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
99
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
100
  gen_events, temp, top_p, top_k, allow_cc):
101
  model = models[model_name]
 
102
  tokenizer = model.tokenizer
103
  bpm = int(bpm)
104
  if time_sig == "auto":
@@ -300,10 +302,9 @@ if __name__ == "__main__":
300
  for name, (repo_id, path, config) in models_info.items():
301
  model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
302
  model = MIDIModel(config=MIDIModelConfig.from_name(config))
303
- ckpt = torch.load(model_path, map_location="cpu")
304
  state_dict = ckpt.get("state_dict", ckpt)
305
  model.load_state_dict(state_dict, strict=False)
306
- model.to(device=opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
307
  models[name] = model
308
 
309
  load_javascript()
 
1
+ import spaces
2
  import argparse
3
  import glob
4
  import json
 
95
  def send_msgs(msgs):
96
  return json.dumps(msgs)
97
 
98
+ @spaces.GPU()
99
  def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
100
  reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
101
  gen_events, temp, top_p, top_k, allow_cc):
102
  model = models[model_name]
103
+ model.to(device=opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
104
  tokenizer = model.tokenizer
105
  bpm = int(bpm)
106
  if time_sig == "auto":
 
302
  for name, (repo_id, path, config) in models_info.items():
303
  model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
304
  model = MIDIModel(config=MIDIModelConfig.from_name(config))
305
+ ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
306
  state_dict = ckpt.get("state_dict", ckpt)
307
  model.load_state_dict(state_dict, strict=False)
 
308
  models[name] = model
309
 
310
  load_javascript()