skytnt commited on
Commit
a52dad5
·
1 Parent(s): a3e7293

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -206,7 +206,7 @@ if __name__ == "__main__":
206
  parser = argparse.ArgumentParser()
207
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
208
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
209
- parser.add_argument("--max-gen", type=int, default=512, help="max")
210
  opt = parser.parse_args()
211
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
212
  model_base_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx")
@@ -257,17 +257,17 @@ if __name__ == "__main__":
257
  tab1.select(lambda: 0, None, tab_select, queue=False)
258
  tab2.select(lambda: 1, None, tab_select, queue=False)
259
  input_gen_events = gr.Slider(label="generate n midi events", minimum=1, maximum=opt.max_gen,
260
- step=1, value=opt.max_gen)
261
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
262
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
263
  input_top_k = gr.Slider(label="top k", minimum=1, maximum=50, step=1, value=20)
264
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
265
  run_btn = gr.Button("generate", variant="primary")
266
- stop_btn = gr.Button("stop")
267
  output_midi_seq = gr.Variable()
268
  output_midi_img = gr.Image(label="output image")
269
  output_midi = gr.File(label="output midi", file_types=[".mid"])
270
- output_audio = gr.Audio(label="output audio", format="wav")
271
  run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
272
  input_gen_events, input_temp, input_top_p, input_top_k,
273
  input_allow_cc],
 
206
  parser = argparse.ArgumentParser()
207
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
208
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
209
+ parser.add_argument("--max-gen", type=int, default=1024, help="max")
210
  opt = parser.parse_args()
211
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
212
  model_base_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx")
 
257
  tab1.select(lambda: 0, None, tab_select, queue=False)
258
  tab2.select(lambda: 1, None, tab_select, queue=False)
259
  input_gen_events = gr.Slider(label="generate n midi events", minimum=1, maximum=opt.max_gen,
260
+ step=1, value=opt.max_gen // 2)
261
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
262
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
263
  input_top_k = gr.Slider(label="top k", minimum=1, maximum=50, step=1, value=20)
264
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
265
  run_btn = gr.Button("generate", variant="primary")
266
+ stop_btn = gr.Button("stop and output")
267
  output_midi_seq = gr.Variable()
268
  output_midi_img = gr.Image(label="output image")
269
  output_midi = gr.File(label="output midi", file_types=[".mid"])
270
+ output_audio = gr.Audio(label="output audio", format="mp3")
271
  run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
272
  input_gen_events, input_temp, input_top_p, input_top_k,
273
  input_allow_cc],