skytnt commited on
Commit
1f0da43
1 Parent(s): aa6fbf4
Files changed (1) hide show
  1. app.py +29 -18
app.py CHANGED
@@ -14,6 +14,7 @@ import MIDI
14
  from midi_synthesizer import synthesis
15
  from midi_tokenizer import MIDITokenizer
16
 
 
17
  in_space = os.getenv("SYSTEM") == "spaces"
18
 
19
 
@@ -23,7 +24,9 @@ def softmax(x, axis):
23
  return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
24
 
25
 
26
- def sample_top_p_k(probs, p, k):
 
 
27
  probs_idx = np.argsort(-probs, axis=-1)
28
  probs_sort = np.take_along_axis(probs, probs_idx, -1)
29
  probs_sum = np.cumsum(probs_sort, axis=-1)
@@ -36,17 +39,19 @@ def sample_top_p_k(probs, p, k):
36
  shape = probs_sort.shape
37
  probs_sort_flat = probs_sort.reshape(-1, shape[-1])
38
  probs_idx_flat = probs_idx.reshape(-1, shape[-1])
39
- next_token = np.stack([np.random.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
40
  next_token = next_token.reshape(*shape[:-1])
41
  return next_token
42
 
43
 
44
  def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
45
- disable_patch_change=False, disable_control_change=False, disable_channels=None):
46
  if disable_channels is not None:
47
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
48
  else:
49
  disable_channels = []
 
 
50
  max_token_seq = tokenizer.max_token_seq
51
  if prompt is None:
52
  input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
@@ -83,7 +88,7 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
83
  mask[mask_ids] = 1
84
  logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
85
  scores = softmax(logits / temp, -1) * mask
86
- sample = sample_top_p_k(scores, top_p, top_k)
87
  if i == 0:
88
  next_token_seq = sample
89
  eid = sample.item()
@@ -120,13 +125,16 @@ def send_msgs(msgs, msgs_history=None):
120
  return json.dumps(msgs_history)
121
 
122
 
123
- def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
 
124
  msgs_history = []
125
  mid_seq = []
126
  bpm = int(bpm)
127
  gen_events = int(gen_events)
128
  max_len = gen_events
129
-
 
 
130
  disable_patch_change = False
131
  disable_channels = None
132
  if tab == 0:
@@ -159,22 +167,22 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, gen_event
159
  init_msgs = [create_msg("visualizer_clear", False)]
160
  for tokens in mid_seq:
161
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
162
- yield mid_seq, None, None, send_msgs(init_msgs, msgs_history)
163
  model = models[model_name]
164
- generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
165
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
166
- disable_channels=disable_channels)
167
- for i, token_seq in enumerate(generator):
168
  token_seq = token_seq.tolist()
169
  mid_seq.append(token_seq)
170
  event = tokenizer.tokens2event(token_seq)
171
- yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
172
  mid = tokenizer.detokenize(mid_seq)
173
  with open(f"output.mid", 'wb') as f:
174
  f.write(MIDI.score2midi(mid))
175
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
176
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
177
- yield mid_seq, "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", events)])
178
 
179
 
180
  def cancel_run(mid_seq):
@@ -232,8 +240,8 @@ if __name__ == "__main__":
232
  opt = parser.parse_args()
233
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
234
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
235
- "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
236
- "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
237
  }
238
  models = {}
239
  tokenizer = MIDITokenizer()
@@ -301,7 +309,10 @@ if __name__ == "__main__":
301
 
302
  tab1.select(lambda: 0, None, tab_select, queue=False)
303
  tab2.select(lambda: 1, None, tab_select, queue=False)
304
- input_gen_events = gr.Slider(label="generate n midi events", minimum=1, maximum=opt.max_gen,
 
 
 
305
  step=1, value=opt.max_gen // 2)
306
  with gr.Accordion("options", open=False):
307
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
@@ -316,9 +327,9 @@ if __name__ == "__main__":
316
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
317
  output_midi = gr.File(label="output midi", file_types=[".mid"])
318
  run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
319
- input_midi, input_midi_events, input_gen_events, input_temp,
320
- input_top_p, input_top_k, input_allow_cc],
321
- [output_midi_seq, output_midi, output_audio, js_msg],
322
  concurrency_limit=3)
323
  stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
324
  app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
14
  from midi_synthesizer import synthesis
15
  from midi_tokenizer import MIDITokenizer
16
 
17
+ MAX_SEED = np.iinfo(np.int32).max
18
  in_space = os.getenv("SYSTEM") == "spaces"
19
 
20
 
 
24
  return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
25
 
26
 
27
+ def sample_top_p_k(probs, p, k, generator=None):
28
+ if generator is None:
29
+ generator = np.random
30
  probs_idx = np.argsort(-probs, axis=-1)
31
  probs_sort = np.take_along_axis(probs, probs_idx, -1)
32
  probs_sum = np.cumsum(probs_sort, axis=-1)
 
39
  shape = probs_sort.shape
40
  probs_sort_flat = probs_sort.reshape(-1, shape[-1])
41
  probs_idx_flat = probs_idx.reshape(-1, shape[-1])
42
+ next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
43
  next_token = next_token.reshape(*shape[:-1])
44
  return next_token
45
 
46
 
47
  def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
48
+ disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
49
  if disable_channels is not None:
50
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
51
  else:
52
  disable_channels = []
53
+ if generator is None:
54
+ generator = np.random
55
  max_token_seq = tokenizer.max_token_seq
56
  if prompt is None:
57
  input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
 
88
  mask[mask_ids] = 1
89
  logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
90
  scores = softmax(logits / temp, -1) * mask
91
+ sample = sample_top_p_k(scores, top_p, top_k, generator)
92
  if i == 0:
93
  next_token_seq = sample
94
  eid = sample.item()
 
125
  return json.dumps(msgs_history)
126
 
127
 
128
+ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, seed_rand,
129
+ gen_events, temp, top_p, top_k, allow_cc):
130
  msgs_history = []
131
  mid_seq = []
132
  bpm = int(bpm)
133
  gen_events = int(gen_events)
134
  max_len = gen_events
135
+ if seed_rand:
136
+ seed = np.random.randint(0, MAX_SEED)
137
+ generator = np.random.RandomState(seed)
138
  disable_patch_change = False
139
  disable_channels = None
140
  if tab == 0:
 
167
  init_msgs = [create_msg("visualizer_clear", False)]
168
  for tokens in mid_seq:
169
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
170
+ yield mid_seq, None, None, seed, send_msgs(init_msgs, msgs_history)
171
  model = models[model_name]
172
+ midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
173
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
174
+ disable_channels=disable_channels, generator=generator)
175
+ for i, token_seq in enumerate(midi_generator):
176
  token_seq = token_seq.tolist()
177
  mid_seq.append(token_seq)
178
  event = tokenizer.tokens2event(token_seq)
179
+ yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
180
  mid = tokenizer.detokenize(mid_seq)
181
  with open(f"output.mid", 'wb') as f:
182
  f.write(MIDI.score2midi(mid))
183
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
184
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
185
+ yield mid_seq, "output.mid", (44100, audio), seed, send_msgs([create_msg("visualizer_end", events)])
186
 
187
 
188
  def cancel_run(mid_seq):
 
240
  opt = parser.parse_args()
241
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
242
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
243
+ # "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
244
+ # "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
245
  }
246
  models = {}
247
  tokenizer = MIDITokenizer()
 
309
 
310
  tab1.select(lambda: 0, None, tab_select, queue=False)
311
  tab2.select(lambda: 1, None, tab_select, queue=False)
312
+ input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
313
+ step=1, value=0)
314
+ input_seed_rand = gr.Checkbox(label="random seed", value=True)
315
+ input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
316
  step=1, value=opt.max_gen // 2)
317
  with gr.Accordion("options", open=False):
318
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
 
327
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
328
  output_midi = gr.File(label="output midi", file_types=[".mid"])
329
  run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
330
+ input_midi, input_midi_events, input_seed, input_seed_rand, input_gen_events,
331
+ input_temp, input_top_p, input_top_k, input_allow_cc],
332
+ [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
333
  concurrency_limit=3)
334
  stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
335
  app.launch(server_port=opt.port, share=opt.share, inbrowser=True)