skytnt commited on
Commit
86f7f0a
1 Parent(s): f0ad71a
Files changed (4) hide show
  1. README.md +2 -2
  2. app.py +41 -60
  3. midi_model.py +129 -0
  4. requirements.txt +4 -1
README.md CHANGED
@@ -3,8 +3,8 @@ title: Midi Music Generator
3
  emoji: 🎼🎶
4
  colorFrom: red
5
  colorTo: indigo
6
- sdk: docker
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
 
3
  emoji: 🎼🎶
4
  colorFrom: red
5
  colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.43.0
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
app.py CHANGED
@@ -1,79 +1,54 @@
1
  import argparse
2
  import glob
3
- import os.path
 
4
  import time
5
- import uuid
6
 
7
  import gradio as gr
8
  import numpy as np
9
- import onnxruntime as rt
 
 
10
  import tqdm
11
- import json
12
- from huggingface_hub import hf_hub_download
13
 
14
  import MIDI
15
- from midi_synthesizer import synthesis
16
  from midi_tokenizer import MIDITokenizer
 
 
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  in_space = os.getenv("SYSTEM") == "spaces"
20
 
21
 
22
- def softmax(x, axis):
23
- x_max = np.amax(x, axis=axis, keepdims=True)
24
- exp_x_shifted = np.exp(x - x_max)
25
- return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
26
-
27
-
28
- def sample_top_p_k(probs, p, k, generator=None):
29
- if generator is None:
30
- generator = np.random
31
- probs_idx = np.argsort(-probs, axis=-1)
32
- probs_sort = np.take_along_axis(probs, probs_idx, -1)
33
- probs_sum = np.cumsum(probs_sort, axis=-1)
34
- mask = probs_sum - probs_sort > p
35
- probs_sort[mask] = 0.0
36
- mask = np.zeros(probs_sort.shape[-1])
37
- mask[:k] = 1
38
- probs_sort = probs_sort * mask
39
- probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
40
- shape = probs_sort.shape
41
- probs_sort_flat = probs_sort.reshape(-1, shape[-1])
42
- probs_idx_flat = probs_idx.reshape(-1, shape[-1])
43
- next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
44
- next_token = next_token.reshape(*shape[:-1])
45
- return next_token
46
-
47
-
48
  def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
49
- disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
50
  if disable_channels is not None:
51
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
52
  else:
53
  disable_channels = []
54
- if generator is None:
55
- generator = np.random
56
  max_token_seq = tokenizer.max_token_seq
57
  if prompt is None:
58
- input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
59
  input_tensor[0, 0] = tokenizer.bos_id # bos
60
  else:
61
  prompt = prompt[:, :max_token_seq]
62
  if prompt.shape[-1] < max_token_seq:
63
  prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
64
  mode="constant", constant_values=tokenizer.pad_id)
65
- input_tensor = prompt
66
- input_tensor = input_tensor[None, :, :]
67
  cur_len = input_tensor.shape[1]
68
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
69
- with bar:
70
  while cur_len < max_len:
71
  end = False
72
- hidden = model[0].run(None, {'x': input_tensor})[0][:, -1]
73
- next_token_seq = np.empty((1, 0), dtype=np.int64)
74
  event_name = ""
75
  for i in range(max_token_seq):
76
- mask = np.zeros(tokenizer.vocab_size, dtype=np.int64)
77
  if i == 0:
78
  mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
79
  if disable_patch_change:
@@ -87,9 +62,9 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
87
  if param_name == "channel":
88
  mask_ids = [i for i in mask_ids if i not in disable_channels]
89
  mask[mask_ids] = 1
90
- logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
91
- scores = softmax(logits / temp, -1) * mask
92
- sample = sample_top_p_k(scores, top_p, top_k, generator)
93
  if i == 0:
94
  next_token_seq = sample
95
  eid = sample.item()
@@ -98,17 +73,17 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
98
  break
99
  event_name = tokenizer.id_events[eid]
100
  else:
101
- next_token_seq = np.concatenate([next_token_seq, sample], axis=1)
102
  if len(tokenizer.events[event_name]) == i:
103
  break
104
  if next_token_seq.shape[1] < max_token_seq:
105
- next_token_seq = np.pad(next_token_seq, ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
106
- mode="constant", constant_values=tokenizer.pad_id)
107
- next_token_seq = next_token_seq[None, :, :]
108
- input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
109
  cur_len += 1
110
  bar.update(1)
111
- yield next_token_seq.reshape(-1)
112
  if end:
113
  break
114
 
@@ -129,7 +104,7 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, midi_opt,
129
  max_len = gen_events
130
  if seed_rand:
131
  seed = np.random.randint(0, MAX_SEED)
132
- generator = np.random.RandomState(seed)
133
  disable_patch_change = False
134
  disable_channels = None
135
  if tab == 0:
@@ -160,14 +135,16 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, midi_opt,
160
  for token_seq in mid:
161
  mid_seq.append(token_seq.tolist())
162
  max_len += len(mid)
 
163
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
164
  init_msgs = [create_msg("visualizer_clear", None), create_msg("visualizer_append", events)]
165
  t = time.time() + 1
166
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
167
  model = models[model_name]
 
168
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
169
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
170
- disable_channels=disable_channels, generator=generator)
171
  events = []
172
  for i, token_seq in enumerate(midi_generator):
173
  token_seq = token_seq.tolist()
@@ -245,15 +222,18 @@ if __name__ == "__main__":
245
  "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
246
  "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
247
  }
 
248
  models = {}
249
  tokenizer = MIDITokenizer()
250
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
251
  for name, (repo_id, path) in models_info.items():
252
- model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
253
- model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
254
- model_base = rt.InferenceSession(model_base_path, providers=providers)
255
- model_token = rt.InferenceSession(model_token_path, providers=providers)
256
- models[name] = [model_base, model_token]
 
 
 
257
 
258
  load_javascript()
259
  app = gr.Blocks()
@@ -265,7 +245,8 @@ if __name__ == "__main__":
265
  "[Open In Colab]"
266
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
267
  " for faster running and longer generation\n\n"
268
- "**Update v1.2**: Optimise the tokenizer and dataset"
 
269
  )
270
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
271
  js_msg.change(None, [js_msg], [], js="""
 
1
  import argparse
2
  import glob
3
+ import json
4
+ import os
5
  import time
 
6
 
7
  import gradio as gr
8
  import numpy as np
9
+ import torch
10
+
11
+ import torch.nn.functional as F
12
  import tqdm
 
 
13
 
14
  import MIDI
15
+ from midi_model import MIDIModel
16
  from midi_tokenizer import MIDITokenizer
17
+ from midi_synthesizer import synthesis
18
+ from huggingface_hub import hf_hub_download
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  in_space = os.getenv("SYSTEM") == "spaces"
22
 
23
 
24
+ @torch.inference_mode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
26
+ disable_patch_change=False, disable_control_change=False, disable_channels=None, amp=True, generator=None):
27
  if disable_channels is not None:
28
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
29
  else:
30
  disable_channels = []
 
 
31
  max_token_seq = tokenizer.max_token_seq
32
  if prompt is None:
33
+ input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
34
  input_tensor[0, 0] = tokenizer.bos_id # bos
35
  else:
36
  prompt = prompt[:, :max_token_seq]
37
  if prompt.shape[-1] < max_token_seq:
38
  prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
39
  mode="constant", constant_values=tokenizer.pad_id)
40
+ input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
41
+ input_tensor = input_tensor.unsqueeze(0)
42
  cur_len = input_tensor.shape[1]
43
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
44
+ with bar, torch.amp.autocast(device_type=model.device, enabled=amp):
45
  while cur_len < max_len:
46
  end = False
47
+ hidden = model.forward(input_tensor)[0, -1].unsqueeze(0)
48
+ next_token_seq = None
49
  event_name = ""
50
  for i in range(max_token_seq):
51
+ mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=model.device)
52
  if i == 0:
53
  mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
54
  if disable_patch_change:
 
62
  if param_name == "channel":
63
  mask_ids = [i for i in mask_ids if i not in disable_channels]
64
  mask[mask_ids] = 1
65
+ logits = model.forward_token(hidden, next_token_seq)[:, -1:]
66
+ scores = torch.softmax(logits / temp, dim=-1) * mask
67
+ sample = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
68
  if i == 0:
69
  next_token_seq = sample
70
  eid = sample.item()
 
73
  break
74
  event_name = tokenizer.id_events[eid]
75
  else:
76
+ next_token_seq = torch.cat([next_token_seq, sample], dim=1)
77
  if len(tokenizer.events[event_name]) == i:
78
  break
79
  if next_token_seq.shape[1] < max_token_seq:
80
+ next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
81
+ "constant", value=tokenizer.pad_id)
82
+ next_token_seq = next_token_seq.unsqueeze(1)
83
+ input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
84
  cur_len += 1
85
  bar.update(1)
86
+ yield next_token_seq.reshape(-1).cpu().numpy()
87
  if end:
88
  break
89
 
 
104
  max_len = gen_events
105
  if seed_rand:
106
  seed = np.random.randint(0, MAX_SEED)
107
+ generator = torch.Generator(device).manual_seed(seed)
108
  disable_patch_change = False
109
  disable_channels = None
110
  if tab == 0:
 
135
  for token_seq in mid:
136
  mid_seq.append(token_seq.tolist())
137
  max_len += len(mid)
138
+
139
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
140
  init_msgs = [create_msg("visualizer_clear", None), create_msg("visualizer_append", events)]
141
  t = time.time() + 1
142
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
143
  model = models[model_name]
144
+ amp = device == "cuda"
145
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
146
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
147
+ disable_channels=disable_channels, amp=amp, generator=generator)
148
  events = []
149
  for i, token_seq in enumerate(midi_generator):
150
  token_seq = token_seq.tolist()
 
222
  "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
223
  "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
224
  }
225
+ device = "cuda" if torch.cuda.is_available() else "cpu"
226
  models = {}
227
  tokenizer = MIDITokenizer()
 
228
  for name, (repo_id, path) in models_info.items():
229
+
230
+ model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
231
+ model = MIDIModel(tokenizer).to(device=device)
232
+ ckpt = torch.load(model_path)
233
+ state_dict = ckpt.get("state_dict", ckpt)
234
+ model.load_state_dict(state_dict, strict=False)
235
+ model.eval()
236
+ models[name] = model
237
 
238
  load_javascript()
239
  app = gr.Blocks()
 
245
  "[Open In Colab]"
246
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
247
  " for faster running and longer generation\n\n"
248
+ "**Update v1.2**: Optimise the tokenizer and dataset\n\n"
249
+ f"Device: {device}"
250
  )
251
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
252
  js_msg.change(None, [js_msg], [], js="""
midi_model.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import tqdm
6
+ from transformers import LlamaModel, LlamaConfig
7
+
8
+ from midi_tokenizer import MIDITokenizer
9
+
10
+
11
+ class MIDIModel(nn.Module):
12
+ def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
13
+ *args, **kwargs):
14
+ super(MIDIModel, self).__init__()
15
+ self.tokenizer = tokenizer
16
+ self.net = LlamaModel(LlamaConfig(vocab_size=tokenizer.vocab_size,
17
+ hidden_size=n_embd, num_attention_heads=n_head,
18
+ num_hidden_layers=n_layer, intermediate_size=n_inner,
19
+ pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
20
+ self.net_token = LlamaModel(LlamaConfig(vocab_size=tokenizer.vocab_size,
21
+ hidden_size=n_embd, num_attention_heads=n_head // 4,
22
+ num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
23
+ pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
24
+ if flash:
25
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
26
+ torch.backends.cuda.enable_flash_sdp(True)
27
+ self.lm_head = nn.Linear(n_embd, tokenizer.vocab_size, bias=False)
28
+ self.device = "cpu"
29
+
30
+ def to(self, *args, **kwargs):
31
+ if "device" in kwargs:
32
+ self.device = kwargs["device"]
33
+ return super(MIDIModel, self).to(*args, **kwargs)
34
+
35
+ def forward_token(self, hidden_state, x=None):
36
+ """
37
+
38
+ :param hidden_state: (batch_size, n_embd)
39
+ :param x: (batch_size, token_sequence_length)
40
+ :return: (batch_size, 1 + token_sequence_length, vocab_size)
41
+ """
42
+ hidden_state = hidden_state.unsqueeze(1) # (batch_size, 1, n_embd)
43
+ if x is not None:
44
+ x = self.net_token.embed_tokens(x)
45
+ hidden_state = torch.cat([hidden_state, x], dim=1)
46
+ hidden_state = self.net_token.forward(inputs_embeds=hidden_state).last_hidden_state
47
+ return self.lm_head(hidden_state)
48
+
49
+ def forward(self, x):
50
+ """
51
+ :param x: (batch_size, midi_sequence_length, token_sequence_length)
52
+ :return: hidden (batch_size, midi_sequence_length, n_embd)
53
+ """
54
+
55
+ # merge token sequence
56
+ x = self.net.embed_tokens(x)
57
+ x = x.sum(dim=-2)
58
+ x = self.net.forward(inputs_embeds=x)
59
+ return x.last_hidden_state
60
+
61
+ def sample_top_p_k(self, probs, p, k, generator=None):
62
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
63
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
64
+ mask = probs_sum - probs_sort > p
65
+ probs_sort[mask] = 0.0
66
+ mask = torch.zeros(probs_sort.shape[-1], device=probs_sort.device)
67
+ mask[:k] = 1
68
+ probs_sort = probs_sort * mask
69
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
70
+ shape = probs_sort.shape
71
+ next_token = torch.multinomial(probs_sort.reshape(-1, shape[-1]),
72
+ num_samples=1, generator=generator).reshape(*shape[:-1], 1)
73
+ next_token = torch.gather(probs_idx, -1, next_token).reshape(*shape[:-1])
74
+ return next_token
75
+
76
+ @torch.inference_mode()
77
+ def generate(self, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20, amp=True, generator=None):
78
+ tokenizer = self.tokenizer
79
+ max_token_seq = tokenizer.max_token_seq
80
+ if prompt is None:
81
+ input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
82
+ input_tensor[0, 0] = tokenizer.bos_id # bos
83
+ else:
84
+ prompt = prompt[:, :max_token_seq]
85
+ if prompt.shape[-1] < max_token_seq:
86
+ prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
87
+ mode="constant", constant_values=tokenizer.pad_id)
88
+ input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)
89
+ input_tensor = input_tensor.unsqueeze(0)
90
+ cur_len = input_tensor.shape[1]
91
+ bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
92
+ with bar, torch.cuda.amp.autocast(enabled=amp):
93
+ while cur_len < max_len:
94
+ end = False
95
+ hidden = self.forward(input_tensor)[0, -1].unsqueeze(0)
96
+ next_token_seq = None
97
+ event_name = ""
98
+ for i in range(max_token_seq):
99
+ mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=self.device)
100
+ if i == 0:
101
+ mask[list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
102
+ else:
103
+ param_name = tokenizer.events[event_name][i - 1]
104
+ mask[tokenizer.parameter_ids[param_name]] = 1
105
+
106
+ logits = self.forward_token(hidden, next_token_seq)[:, -1:]
107
+ scores = torch.softmax(logits / temp, dim=-1) * mask
108
+ sample = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
109
+ if i == 0:
110
+ next_token_seq = sample
111
+ eid = sample.item()
112
+ if eid == tokenizer.eos_id:
113
+ end = True
114
+ break
115
+ event_name = tokenizer.id_events[eid]
116
+ else:
117
+ next_token_seq = torch.cat([next_token_seq, sample], dim=1)
118
+ if len(tokenizer.events[event_name]) == i:
119
+ break
120
+ if next_token_seq.shape[1] < max_token_seq:
121
+ next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
122
+ "constant", value=tokenizer.pad_id)
123
+ next_token_seq = next_token_seq.unsqueeze(1)
124
+ input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
125
+ cur_len += 1
126
+ bar.update(1)
127
+ if end:
128
+ break
129
+ return input_tensor[0].cpu().numpy()
requirements.txt CHANGED
@@ -1,5 +1,8 @@
1
  Pillow
2
  numpy
3
- onnxruntime-gpu
 
4
  gradio==4.43.0
5
  pyfluidsynth
 
 
 
1
  Pillow
2
  numpy
3
+ torch
4
+ transformers>=4.36
5
  gradio==4.43.0
6
  pyfluidsynth
7
+ tqdm
8
+ huggingface_hub