skytnt commited on
Commit
4540a78
1 Parent(s): 9958d06

fix streaming

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -162,19 +162,19 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, see
162
  init_msgs = [create_msg("visualizer_clear", False)]
163
  for tokens in mid_seq:
164
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
 
165
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
166
  model = models[model_name]
167
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
168
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
169
  disable_channels=disable_channels, generator=generator)
170
- t = time.time()
171
  events = []
172
  for i, token_seq in enumerate(midi_generator):
173
  token_seq = token_seq.tolist()
174
  mid_seq.append(token_seq)
175
  events.append(tokenizer.tokens2event(token_seq))
176
  ct = time.time()
177
- if ct - t > 0.2:
178
  yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", events), create_msg("progress", [i + 1, gen_events])])
179
  t = ct
180
  events = []
 
162
  init_msgs = [create_msg("visualizer_clear", False)]
163
  for tokens in mid_seq:
164
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
165
+ t = time.time()
166
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
167
  model = models[model_name]
168
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
169
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
170
  disable_channels=disable_channels, generator=generator)
 
171
  events = []
172
  for i, token_seq in enumerate(midi_generator):
173
  token_seq = token_seq.tolist()
174
  mid_seq.append(token_seq)
175
  events.append(tokenizer.tokens2event(token_seq))
176
  ct = time.time()
177
+ if ct - t > 1:
178
  yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", events), create_msg("progress", [i + 1, gen_events])])
179
  t = ct
180
  events = []