skytnt commited on
Commit
91c8ccf
1 Parent(s): 4540a78

fix streaming

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -159,9 +159,8 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, see
159
  for token_seq in mid:
160
  mid_seq.append(token_seq.tolist())
161
  max_len += len(mid)
162
- init_msgs = [create_msg("visualizer_clear", False)]
163
- for tokens in mid_seq:
164
- init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
165
  t = time.time()
166
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
167
  model = models[model_name]
@@ -174,7 +173,7 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, see
174
  mid_seq.append(token_seq)
175
  events.append(tokenizer.tokens2event(token_seq))
176
  ct = time.time()
177
- if ct - t > 1:
178
  yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", events), create_msg("progress", [i + 1, gen_events])])
179
  t = ct
180
  events = []
@@ -242,8 +241,8 @@ if __name__ == "__main__":
242
  opt = parser.parse_args()
243
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
244
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
245
- "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
246
- "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
247
  }
248
  models = {}
249
  tokenizer = MIDITokenizer()
 
159
  for token_seq in mid:
160
  mid_seq.append(token_seq.tolist())
161
  max_len += len(mid)
162
+ events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
163
+ init_msgs = [create_msg("visualizer_clear", None), create_msg("visualizer_append", events)]
 
164
  t = time.time()
165
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
166
  model = models[model_name]
 
173
  mid_seq.append(token_seq)
174
  events.append(tokenizer.tokens2event(token_seq))
175
  ct = time.time()
176
+ if ct - t > 0.1:
177
  yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", events), create_msg("progress", [i + 1, gen_events])])
178
  t = ct
179
  events = []
 
241
  opt = parser.parse_args()
242
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
243
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
244
+ # "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
245
+ # "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
246
  }
247
  models = {}
248
  tokenizer = MIDITokenizer()