Spaces:
Runtime error
Runtime error
fix midi visualizer
Browse files- app.py +20 -29
- javascript/app.js +4 -0
app.py
CHANGED
@@ -111,7 +111,15 @@ def create_msg(name, data):
|
|
111 |
return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
|
112 |
|
113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
|
|
115 |
mid_seq = []
|
116 |
gen_events = int(gen_events)
|
117 |
max_len = gen_events
|
@@ -146,7 +154,7 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
146 |
init_msgs = [create_msg("visualizer_clear", None)]
|
147 |
for tokens in mid_seq:
|
148 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
149 |
-
yield mid_seq, None, None, init_msgs
|
150 |
model = models[model_name]
|
151 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
152 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
@@ -155,22 +163,22 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
155 |
token_seq = token_seq.tolist()
|
156 |
mid_seq.append(token_seq)
|
157 |
event = tokenizer.tokens2event(token_seq)
|
158 |
-
yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
|
159 |
mid = tokenizer.detokenize(mid_seq)
|
160 |
with open(f"output.mid", 'wb') as f:
|
161 |
f.write(MIDI.score2midi(mid))
|
162 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
163 |
-
yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
164 |
|
165 |
|
166 |
-
def cancel_run(mid_seq):
|
167 |
if mid_seq is None:
|
168 |
return None, None, []
|
169 |
mid = tokenizer.detokenize(mid_seq)
|
170 |
with open(f"output.mid", 'wb') as f:
|
171 |
f.write(MIDI.score2midi(mid))
|
172 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
173 |
-
return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
174 |
|
175 |
|
176 |
def load_javascript(dir="javascript"):
|
@@ -191,25 +199,6 @@ def load_javascript(dir="javascript"):
|
|
191 |
gr.routes.templates.TemplateResponse = template_response
|
192 |
|
193 |
|
194 |
-
# JSMsgReceiver
|
195 |
-
Textbox_postprocess_ori = gr.Textbox.postprocess
|
196 |
-
|
197 |
-
msg_history = []
|
198 |
-
|
199 |
-
|
200 |
-
# the change event may not trigger every time, so send msg history to avoid msg missing.
|
201 |
-
def JSMsgReceiver_postprocess(self, y):
|
202 |
-
global msg_history
|
203 |
-
if self.elem_id == "msg_receiver" and y:
|
204 |
-
msg_history.append(y)
|
205 |
-
if len(msg_history) > 50:
|
206 |
-
msg_history = msg_history[1:]
|
207 |
-
y = json.dumps(msg_history)
|
208 |
-
return Textbox_postprocess_ori(self, y)
|
209 |
-
|
210 |
-
|
211 |
-
gr.Textbox.postprocess = JSMsgReceiver_postprocess
|
212 |
-
|
213 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
214 |
40: "Blush", 48: "Orchestra"}
|
215 |
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
|
@@ -223,8 +212,8 @@ if __name__ == "__main__":
|
|
223 |
opt = parser.parse_args()
|
224 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
225 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
226 |
-
"j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
227 |
-
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
228 |
}
|
229 |
models = {}
|
230 |
tokenizer = MIDITokenizer()
|
@@ -247,6 +236,7 @@ if __name__ == "__main__":
|
|
247 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
248 |
" for faster running and longer generation"
|
249 |
)
|
|
|
250 |
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
251 |
js_msg.change(None, [js_msg], [], js="""
|
252 |
(msg_json) =>{
|
@@ -302,6 +292,7 @@ if __name__ == "__main__":
|
|
302 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
|
303 |
input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
|
304 |
input_allow_cc],
|
305 |
-
[output_midi_seq, output_midi, output_audio, js_msg]
|
306 |
-
|
307 |
-
|
|
|
|
111 |
return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
|
112 |
|
113 |
|
114 |
+
def send_msgs(msgs, msgs_history):
|
115 |
+
msgs_history.append(msgs)
|
116 |
+
if len(msgs_history) > 50:
|
117 |
+
msgs_history.pop(0)
|
118 |
+
return json.dumps(msgs_history)
|
119 |
+
|
120 |
+
|
121 |
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
122 |
+
msgs_history = []
|
123 |
mid_seq = []
|
124 |
gen_events = int(gen_events)
|
125 |
max_len = gen_events
|
|
|
154 |
init_msgs = [create_msg("visualizer_clear", None)]
|
155 |
for tokens in mid_seq:
|
156 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
157 |
+
yield mid_seq, None, None, send_msgs(init_msgs, msgs_history), msgs_history
|
158 |
model = models[model_name]
|
159 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
160 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
|
|
163 |
token_seq = token_seq.tolist()
|
164 |
mid_seq.append(token_seq)
|
165 |
event = tokenizer.tokens2event(token_seq)
|
166 |
+
yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history), msgs_history
|
167 |
mid = tokenizer.detokenize(mid_seq)
|
168 |
with open(f"output.mid", 'wb') as f:
|
169 |
f.write(MIDI.score2midi(mid))
|
170 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
171 |
+
yield mid_seq, "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", None)], msgs_history), msgs_history
|
172 |
|
173 |
|
174 |
+
def cancel_run(mid_seq, msgs_history):
|
175 |
if mid_seq is None:
|
176 |
return None, None, []
|
177 |
mid = tokenizer.detokenize(mid_seq)
|
178 |
with open(f"output.mid", 'wb') as f:
|
179 |
f.write(MIDI.score2midi(mid))
|
180 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
181 |
+
return "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", None)], msgs_history)
|
182 |
|
183 |
|
184 |
def load_javascript(dir="javascript"):
|
|
|
199 |
gr.routes.templates.TemplateResponse = template_response
|
200 |
|
201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
203 |
40: "Blush", 48: "Orchestra"}
|
204 |
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
|
|
|
212 |
opt = parser.parse_args()
|
213 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
214 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
215 |
+
# "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
216 |
+
# "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
217 |
}
|
218 |
models = {}
|
219 |
tokenizer = MIDITokenizer()
|
|
|
236 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
237 |
" for faster running and longer generation"
|
238 |
)
|
239 |
+
js_msg_history_state = gr.State(value=[])
|
240 |
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
241 |
js_msg.change(None, [js_msg], [], js="""
|
242 |
(msg_json) =>{
|
|
|
292 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
|
293 |
input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
|
294 |
input_allow_cc],
|
295 |
+
[output_midi_seq, output_midi, output_audio, js_msg, js_msg_history_state],
|
296 |
+
concurrency_limit=3)
|
297 |
+
stop_btn.click(cancel_run, [output_midi_seq, js_msg_history_state], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
298 |
+
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
javascript/app.js
CHANGED
@@ -316,6 +316,10 @@ class MidiVisualizer extends HTMLElement{
|
|
316 |
audio.addEventListener("pause", (event)=>{
|
317 |
this.pause()
|
318 |
})
|
|
|
|
|
|
|
|
|
319 |
}
|
320 |
|
321 |
bindWaveformCursor(cursor){
|
|
|
316 |
audio.addEventListener("pause", (event)=>{
|
317 |
this.pause()
|
318 |
})
|
319 |
+
audio.addEventListener("loadedmetadata", (event)=>{
|
320 |
+
//I don't know why the calculated totalTimeMs is different from audio.duration*10**3
|
321 |
+
this.totalTimeMs = audio.duration*10**3;
|
322 |
+
})
|
323 |
}
|
324 |
|
325 |
bindWaveformCursor(cursor){
|