skytnt commited on
Commit
e47ecb8
1 Parent(s): d660a99
Files changed (2) hide show
  1. midi_synthesizer.py +1 -1
  2. midi_tokenizer.py +77 -6
midi_synthesizer.py CHANGED
@@ -14,7 +14,7 @@ def synthesis(midi_opus, soundfont_path, sample_rate=44100):
14
  event_list.append(event_new)
15
  event_list = sorted(event_list, key=lambda e: e[1])
16
 
17
- tempo = int((60 / 140) * 10 ** 6) # default 140 bpm
18
  ss = np.empty((0, 2), dtype=np.int16)
19
  fl = fluidsynth.Synth(samplerate=float(sample_rate))
20
  sfid = fl.sfload(soundfont_path)
 
14
  event_list.append(event_new)
15
  event_list = sorted(event_list, key=lambda e: e[1])
16
 
17
+ tempo = int((60 / 120) * 10 ** 6) # default 120 bpm
18
  ss = np.empty((0, 2), dtype=np.int16)
19
  fl = fluidsynth.Synth(samplerate=float(sample_rate))
20
  sfid = fl.sfload(soundfont_path)
midi_tokenizer.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import PIL
2
  import numpy as np
3
 
@@ -43,19 +45,31 @@ class MIDITokenizer:
43
  def tokenize(self, midi_score, add_bos_eos=True):
44
  ticks_per_beat = midi_score[0]
45
  event_list = {}
46
- track_num = len(midi_score[1:])
47
  for track_idx, track in enumerate(midi_score[1:129]):
 
48
  for event in track:
49
- t = round(16 * event[1] / ticks_per_beat)
50
  new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
51
  if event[0] == "note":
52
  new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
53
  elif event[0] == "set_tempo":
54
  new_event[4] = int(self.tempo2bpm(new_event[4]))
55
- key = hash(tuple(new_event[:-1]))
 
 
 
 
 
 
 
 
 
 
 
 
56
  event_list[key] = new_event
57
  event_list = list(event_list.values())
58
- event_list = sorted(event_list, key=lambda e: (e[1] * 16 + e[2]) * track_num + e[3])
59
  midi_seq = []
60
 
61
  last_t1 = 0
@@ -113,18 +127,24 @@ class MIDITokenizer:
113
  tracks_dict[track_idx] = []
114
  tracks_dict[track_idx].append([event[0], t] + event[4:])
115
  tracks = list(tracks_dict.values())
116
- for i in range(len(tracks)):
 
117
  track = tracks[i]
118
  track = sorted(track, key=lambda e: e[1])
119
  last_note_t = {}
 
120
  for e in reversed(track):
121
  if e[0] == "note":
122
  t, d, c, p = e[1:5]
123
  key = (c, p)
124
  if key in last_note_t:
125
- d = min(d, max(last_note_t[key] - t, 0)) # to avoid note overlap
126
  last_note_t[key] = t
127
  e[2] = d
 
 
 
 
128
  tracks[i] = track
129
  return [ticks_per_beat, *tracks]
130
 
@@ -148,3 +168,54 @@ class MIDITokenizer:
148
  img[p, t: t + d] = colors[(tr, c)]
149
  img = PIL.Image.fromarray(np.flip(img, 0))
150
  return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
  import PIL
4
  import numpy as np
5
 
 
45
  def tokenize(self, midi_score, add_bos_eos=True):
46
  ticks_per_beat = midi_score[0]
47
  event_list = {}
 
48
  for track_idx, track in enumerate(midi_score[1:129]):
49
+ last_notes = {}
50
  for event in track:
51
+ t = round(16 * event[1] / ticks_per_beat) # quantization
52
  new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
53
  if event[0] == "note":
54
  new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
55
  elif event[0] == "set_tempo":
56
  new_event[4] = int(self.tempo2bpm(new_event[4]))
57
+ if event[0] == "note":
58
+ key = tuple(new_event[:4] + new_event[5:-1])
59
+ else:
60
+ key = tuple(new_event[:-1])
61
+ if event[0] == "note": # to eliminate note overlap due to quantization
62
+ cp = tuple(new_event[5:7])
63
+ if cp in last_notes:
64
+ last_note_key, last_note = last_notes[cp]
65
+ last_t = last_note[1] * 16 + last_note[2]
66
+ last_note[4] = max(0, min(last_note[4], t - last_t))
67
+ if last_note[4] == 0:
68
+ event_list.pop(last_note_key)
69
+ last_notes[cp] = (key, new_event)
70
  event_list[key] = new_event
71
  event_list = list(event_list.values())
72
+ event_list = sorted(event_list, key=lambda e: e[1:4])
73
  midi_seq = []
74
 
75
  last_t1 = 0
 
127
  tracks_dict[track_idx] = []
128
  tracks_dict[track_idx].append([event[0], t] + event[4:])
129
  tracks = list(tracks_dict.values())
130
+
131
+ for i in range(len(tracks)): # to eliminate note overlap
132
  track = tracks[i]
133
  track = sorted(track, key=lambda e: e[1])
134
  last_note_t = {}
135
+ zero_len_notes = []
136
  for e in reversed(track):
137
  if e[0] == "note":
138
  t, d, c, p = e[1:5]
139
  key = (c, p)
140
  if key in last_note_t:
141
+ d = min(d, max(last_note_t[key] - t, 0))
142
  last_note_t[key] = t
143
  e[2] = d
144
+ if d == 0:
145
+ zero_len_notes.append(e)
146
+ for e in zero_len_notes:
147
+ track.remove(e)
148
  tracks[i] = track
149
  return [ticks_per_beat, *tracks]
150
 
 
168
  img[p, t: t + d] = colors[(tr, c)]
169
  img = PIL.Image.fromarray(np.flip(img, 0))
170
  return img
171
+
172
+ def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10):
173
+ pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
174
+ vel_shift = random.randint(-max_vel_shift, max_vel_shift)
175
+ cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
176
+ bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
177
+ midi_seq_new = []
178
+ for tokens in midi_seq:
179
+ tokens_new = [*tokens]
180
+ if tokens[0] in self.id_events:
181
+ name = self.id_events[tokens[0]]
182
+ if name == "note":
183
+ c = tokens[5] - self.parameter_ids["channel"][0]
184
+ p = tokens[6] - self.parameter_ids["pitch"][0]
185
+ v = tokens[7] - self.parameter_ids["velocity"][0]
186
+ if c != 9: # no shift for drums
187
+ p += pitch_shift
188
+ if not 0 <= p < 128:
189
+ return midi_seq
190
+ v += vel_shift
191
+ v = max(1, min(127, v))
192
+ tokens_new[6] = self.parameter_ids["pitch"][p]
193
+ tokens_new[7] = self.parameter_ids["velocity"][v]
194
+ elif name == "control_change":
195
+ cc = tokens[5] - self.parameter_ids["controller"][0]
196
+ val = tokens[6] - self.parameter_ids["value"][0]
197
+ if cc in [1, 2, 7, 11]:
198
+ val += cc_val_shift
199
+ val = max(1, min(127, val))
200
+ tokens_new[6] = self.parameter_ids["value"][val]
201
+ elif name == "set_tempo":
202
+ bpm = tokens[4] - self.parameter_ids["bpm"][0]
203
+ bpm += bpm_shift
204
+ bpm = max(1, min(255, bpm))
205
+ tokens_new[4] = self.parameter_ids["bpm"][bpm]
206
+ midi_seq_new.append(tokens_new)
207
+ return midi_seq_new
208
+
209
+ def check_alignment(self, midi_seq, threshold=0.4):
210
+ total = 0
211
+ hist = [0] * 16
212
+ for tokens in midi_seq:
213
+ if tokens[0] in self.id_events and self.id_events[tokens[0]] == "note":
214
+ t2 = tokens[2] - self.parameter_ids["time2"][0]
215
+ total += 1
216
+ hist[t2] += 1
217
+ if total == 0:
218
+ return False
219
+ hist = sorted(hist, reverse=True)
220
+ p = sum(hist[:2]) / total
221
+ return p > threshold