skytnt committed on
Commit
5825808
1 Parent(s): 3c03946

fix midi visualizer

Browse files
Files changed (2) hide show
  1. app.py +20 -29
  2. javascript/app.js +4 -0
app.py CHANGED
@@ -111,7 +111,15 @@ def create_msg(name, data):
111
  return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
112
 
113
 
 
 
 
 
 
 
 
114
  def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
 
115
  mid_seq = []
116
  gen_events = int(gen_events)
117
  max_len = gen_events
@@ -146,7 +154,7 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
146
  init_msgs = [create_msg("visualizer_clear", None)]
147
  for tokens in mid_seq:
148
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
149
- yield mid_seq, None, None, init_msgs
150
  model = models[model_name]
151
  generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
152
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
@@ -155,22 +163,22 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
155
  token_seq = token_seq.tolist()
156
  mid_seq.append(token_seq)
157
  event = tokenizer.tokens2event(token_seq)
158
- yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
159
  mid = tokenizer.detokenize(mid_seq)
160
  with open(f"output.mid", 'wb') as f:
161
  f.write(MIDI.score2midi(mid))
162
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
163
- yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
164
 
165
 
166
- def cancel_run(mid_seq):
167
  if mid_seq is None:
168
  return None, None, []
169
  mid = tokenizer.detokenize(mid_seq)
170
  with open(f"output.mid", 'wb') as f:
171
  f.write(MIDI.score2midi(mid))
172
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
173
- return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
174
 
175
 
176
  def load_javascript(dir="javascript"):
@@ -191,25 +199,6 @@ def load_javascript(dir="javascript"):
191
  gr.routes.templates.TemplateResponse = template_response
192
 
193
 
194
- # JSMsgReceiver
195
- Textbox_postprocess_ori = gr.Textbox.postprocess
196
-
197
- msg_history = []
198
-
199
-
200
- # the change event may not trigger every time, so send msg history to avoid msg missing.
201
- def JSMsgReceiver_postprocess(self, y):
202
- global msg_history
203
- if self.elem_id == "msg_receiver" and y:
204
- msg_history.append(y)
205
- if len(msg_history) > 50:
206
- msg_history = msg_history[1:]
207
- y = json.dumps(msg_history)
208
- return Textbox_postprocess_ori(self, y)
209
-
210
-
211
- gr.Textbox.postprocess = JSMsgReceiver_postprocess
212
-
213
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
214
  40: "Blush", 48: "Orchestra"}
215
  patch2number = {v: k for k, v in MIDI.Number2patch.items()}
@@ -223,8 +212,8 @@ if __name__ == "__main__":
223
  opt = parser.parse_args()
224
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
225
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
226
- "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
227
- "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
228
  }
229
  models = {}
230
  tokenizer = MIDITokenizer()
@@ -247,6 +236,7 @@ if __name__ == "__main__":
247
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
248
  " for faster running and longer generation"
249
  )
 
250
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
251
  js_msg.change(None, [js_msg], [], js="""
252
  (msg_json) =>{
@@ -302,6 +292,7 @@ if __name__ == "__main__":
302
  run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
303
  input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
304
  input_allow_cc],
305
- [output_midi_seq, output_midi, output_audio, js_msg])
306
- stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
307
- app.queue(2).launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
 
111
  return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
112
 
113
 
114
+ def send_msgs(msgs, msgs_history):
115
+ msgs_history.append(msgs)
116
+ if len(msgs_history) > 50:
117
+ msgs_history.pop(0)
118
+ return json.dumps(msgs_history)
119
+
120
+
121
  def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
122
+ msgs_history = []
123
  mid_seq = []
124
  gen_events = int(gen_events)
125
  max_len = gen_events
 
154
  init_msgs = [create_msg("visualizer_clear", None)]
155
  for tokens in mid_seq:
156
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
157
+ yield mid_seq, None, None, send_msgs(init_msgs, msgs_history), msgs_history
158
  model = models[model_name]
159
  generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
160
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
 
163
  token_seq = token_seq.tolist()
164
  mid_seq.append(token_seq)
165
  event = tokenizer.tokens2event(token_seq)
166
+ yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history), msgs_history
167
  mid = tokenizer.detokenize(mid_seq)
168
  with open(f"output.mid", 'wb') as f:
169
  f.write(MIDI.score2midi(mid))
170
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
171
+ yield mid_seq, "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", None)], msgs_history), msgs_history
172
 
173
 
174
+ def cancel_run(mid_seq, msgs_history):
175
  if mid_seq is None:
176
  return None, None, []
177
  mid = tokenizer.detokenize(mid_seq)
178
  with open(f"output.mid", 'wb') as f:
179
  f.write(MIDI.score2midi(mid))
180
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
181
+ return "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", None)], msgs_history)
182
 
183
 
184
  def load_javascript(dir="javascript"):
 
199
  gr.routes.templates.TemplateResponse = template_response
200
 
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
203
  40: "Blush", 48: "Orchestra"}
204
  patch2number = {v: k for k, v in MIDI.Number2patch.items()}
 
212
  opt = parser.parse_args()
213
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
214
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
215
+ # "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
216
+ # "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
217
  }
218
  models = {}
219
  tokenizer = MIDITokenizer()
 
236
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
237
  " for faster running and longer generation"
238
  )
239
+ js_msg_history_state = gr.State(value=[])
240
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
241
  js_msg.change(None, [js_msg], [], js="""
242
  (msg_json) =>{
 
292
  run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
293
  input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
294
  input_allow_cc],
295
+ [output_midi_seq, output_midi, output_audio, js_msg, js_msg_history_state],
296
+ concurrency_limit=3)
297
+ stop_btn.click(cancel_run, [output_midi_seq, js_msg_history_state], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
298
+ app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
javascript/app.js CHANGED
@@ -316,6 +316,10 @@ class MidiVisualizer extends HTMLElement{
316
  audio.addEventListener("pause", (event)=>{
317
  this.pause()
318
  })
 
 
 
 
319
  }
320
 
321
  bindWaveformCursor(cursor){
 
316
  audio.addEventListener("pause", (event)=>{
317
  this.pause()
318
  })
319
+ audio.addEventListener("loadedmetadata", (event)=>{
320
+ //I don't know why the calculated totalTimeMs is different from audio.duration*10**3
321
+ this.totalTimeMs = audio.duration*10**3;
322
+ })
323
  }
324
 
325
  bindWaveformCursor(cursor){