Spaces:
Running
Running
import argparse | |
import glob | |
import os.path | |
import gradio as gr | |
import pickle | |
import tqdm | |
import json | |
import MIDI | |
from midi_synthesizer import synthesis | |
import copy | |
from collections import Counter | |
import random | |
import statistics | |
import matplotlib.pyplot as plt | |
#========================================================================================================== | |
# True when the SYSTEM environment variable is exactly "spaces"
# (presumably set by the Hugging Face Spaces runtime -- confirm); unset -> False.
in_space = os.environ.get("SYSTEM") == "spaces"
#========================================================================================================== | |
def _trim_pitch_counts(pitch_counts, cutoff_ratio):
    """Keep the [pitch, count] pairs whose count reaches cutoff_ratio * max count.

    Args:
        pitch_counts: Sequence of [pitch, count] pairs.
        cutoff_ratio: Fraction of the largest count a pair needs to be kept
            (0 keeps everything).

    Returns:
        (trimmed_pairs, total_count), where total_count sums the kept counts.
        An empty (or None) input yields ([], 0) instead of raising.
    """
    if not pitch_counts:
        return [], 0
    cutoff = max(pc[1] for pc in pitch_counts) * cutoff_ratio
    trimmed = [pc for pc in pitch_counts if pc[1] >= cutoff]
    return trimmed, sum(pc[1] for pc in trimmed)


def _min_max_ratio(a, b):
    """Return min(a, b) / max(a, b): 1.0 for equal values, smaller as they diverge.

    Two zeros count as identical (1.0) rather than raising ZeroDivisionError.
    Intended for non-negative inputs.
    """
    largest = max(a, b)
    if largest == 0:
        return 1.0
    return min(a, b) / largest


def _mean_min_max_ratio(source_stats, match_stats):
    """Average _min_max_ratio over paired statistics; 0 when nothing can be paired."""
    ratios = [_min_max_ratio(s, m) for s, m in zip(source_stats, match_stats)]
    return sum(ratios) / len(ratios) if ratios else 0


def _transposed_pitch_counts(score, shifts):
    """Count note pitches of a MIDI.py score once per transposition in `shifts`.

    score[0] (the tick resolution) is skipped; notes are gathered from score[1:].
    Pitches are folded into 0..127 and shifted by the transposition, except
    channel 9 notes (presumably GM drums), which are never transposed and are
    moved to 128..255 so they cannot collide with melodic pitches.

    Returns:
        A list aligned with `shifts`; each item is a list of [pitch, count] pairs.
    """
    notes = [event for track in score[1:] for event in track if event[0] == 'note']
    all_counts = []
    for shift in shifts:
        pitches = [(n[4] % 128) + (128 if n[3] == 9 else shift) for n in notes]
        all_counts.append([[pitch, count] for pitch, count in Counter(pitches).most_common()])
    return all_counts


def _timing_stats(ms_score):
    """Return timing statistics of the notes in a MIDI.py ms_score, or None if it has none.

    Returns:
        ([avg, median, mode] of onset deltas, [avg, median, mode] of durations).
        Onset deltas are the gaps between consecutive distinct note start times,
        preceded by a 0 for the first note. Averages and medians are truncated to int.
    """
    notes = [event for track in ms_score[1:] for event in track if event[0] == 'note']
    if not notes:
        return None
    # Stable sort keeps track order among simultaneous notes; statistics.mode
    # returns the first-seen value on ties (Python 3.8+), so the order matters.
    notes.sort(key=lambda e: e[1])
    starts = [n[1] for n in notes]
    times = [0] + [cur - prev for prev, cur in zip(starts, starts[1:]) if cur != prev]
    durs = [n[2] for n in notes]
    time_stats = [int(sum(times) / len(times)), int(statistics.median(times)), statistics.mode(times)]
    dur_stats = [int(sum(durs) / len(durs)), int(statistics.median(durs)), statistics.mode(durs)]
    return time_stats, dur_stats


def _same_pitches_counts_ratio(shared, cand_trimmed, cand_total, src_trimmed, src_total):
    """Mean, over shared pitches, of the min/max ratio of each pitch's share of its total.

    Returns 0 when no pitch is shared.
    """
    cand_same = sorted([pc for pc in cand_trimmed if pc[0] in shared], reverse=True)
    src_same = sorted([pc for pc in src_trimmed if pc[0] in shared], reverse=True)
    ratios = [_min_max_ratio(c[1] / cand_total, s[1] / src_total)
              for c, s in zip(cand_same, src_same)]
    return sum(ratios) / len(ratios) if ratios else 0


def _match_report(md, match_ratio, fn):
    """Render the text summary of a matched meta-data entry.

    The ratio/hash header is repeated at the bottom. Between the two headers are
    the first 16 [key, value] pairs of md[1] and every non-note event from md[1][16:].
    """
    separator = '=============================================================='
    header = [
        separator,
        'MIDI MATCH RATIO: ' + str(match_ratio),
        separator,
        'MIDI MATCH MD5 HASH: ' + str(fn),
        separator,
    ]
    fields = [str(m[0]) + ': ' + str(m[1]) for m in md[1][:16]]
    events = [str(m) for m in md[1][16:] if m[0] != 'note']
    lines = header + fields + [separator] + events + header
    return ''.join(line + chr(10) for line in lines)


# Marker colour for each MIDI channel 0..15 in the match plot.
_CHANNEL_COLORS = ['red', 'yellow', 'green', 'cyan',
                   'blue', 'pink', 'orange', 'purple',
                   'gray', 'white', 'gold', 'silver',
                   'lightgreen', 'indigo', 'maroon', 'turquoise']


def _plot_notes(mid_seq):
    """Scatter note onsets (ticks) against pitch, coloured by channel; return the figure."""
    notes = [e for e in mid_seq if e[0] == 'note']
    plt.close()  # drop the previous call's figure so figures do not pile up
    fig, ax = plt.subplots(figsize=(14, 5))
    ax.set_title('MIDI Match Plot')
    ax.set_facecolor('black')
    ax.scatter([n[1] for n in notes], [n[4] for n in notes],
               c=[_CHANNEL_COLORS[n[3]] for n in notes])
    ax.set_xlabel("Time in MIDI ticks")
    ax.set_ylabel("MIDI Pitch")
    return fig


def match_midi(midi, max_match_ratio, progress=gr.Progress()):
    """Find the meta-data entry whose pitch profile best matches an uploaded MIDI file.

    Gradio generator callback. It reads the module-level globals ``meta_data``
    and ``soundfont_path``, which are only defined when this file runs as
    ``__main__``. It writes the matched sequence to "MIDI-Match-Sample.mid" in
    the working directory.

    Args:
        midi: Raw bytes of the uploaded MIDI file (the input uses type="binary").
        max_match_ratio: Candidates scoring above this ratio are treated as 0,
            which lets users skip (near-)exact matches.
        progress: Gradio progress tracker wrapped around the search loop.

    Yields:
        (metadata text, sample MIDI path, (44100, audio) for gr.Audio, matplotlib figure).

    Raises:
        gr.Error: if the uploaded file contains no notes.
    """
    print('=' * 70)
    print('Loading MIDI file...')
    score = MIDI.midi2score(midi)
    ms_score = MIDI.midi2ms_score(midi)
    stats = _timing_stats(ms_score)
    if stats is None:
        raise gr.Error('The uploaded MIDI file contains no notes to match.')
    source_times, source_durs = stats
    shifts = range(-6, 6)
    all_pitch_counts = _transposed_pitch_counts(score, shifts)
    print('=' * 70)
    print('Done!')
    print('=' * 70)

    # Search options carried over from the Colab notebook form (fixed here).
    maximum_match_ratio_to_search_for = max_match_ratio
    pitches_counts_cutoff_threshold_ratio = 0
    search_transposed_pitches = False
    skip_exact_matches = False
    add_pitches_counts_ratios = False
    add_timings_ratios = False
    add_durations_ratios = False

    print('=' * 70)
    print('MIDI Pitches Search')
    print('=' * 70)

    if search_transposed_pitches:
        search_pitches = all_pitch_counts
    else:
        search_pitches = [all_pitch_counts[shifts.index(0)]]
    # The source side does not depend on the candidate: trim it once, not per entry.
    source_sets = []
    for counts in search_pitches:
        trimmed, total = _trim_pitch_counts(counts, pitches_counts_cutoff_threshold_ratio)
        source_sets.append((trimmed, {pc[0] for pc in trimmed}, total))

    final_ratios = []
    for entry in progress.tqdm(meta_data):
        # NOTE(review): meta-data layout taken from the indexing used here:
        # entry[0] is the file id, entry[1][10][1] holds [pitch, count] pairs and
        # entry[1][3][1] / entry[1][4][1] hold timing / duration stats -- confirm.
        cand_trimmed, cand_total = _trim_pitch_counts(
            entry[1][10][1], pitches_counts_cutoff_threshold_ratio)
        if not cand_trimmed:
            # An entry without notes cannot match; keep indices aligned with meta_data.
            final_ratios.append(0)
            continue
        cand_pitches = {pc[0] for pc in cand_trimmed}

        # Timing/duration similarity depends on the candidate only (not on the
        # transposition being tried), so compute it once per candidate.
        extra_ratios = []
        if add_timings_ratios:
            extra_ratios.append(_mean_min_max_ratio(source_times, entry[1][3][1]))
        if add_durations_ratios:
            extra_ratios.append(_mean_min_max_ratio(source_durs, entry[1][4][1]))

        avg_ratios = []
        for src_trimmed, src_pitches, src_total in source_sets:
            shared = cand_pitches & src_pitches
            # Shared distinct pitches over the larger pitch list. When every
            # source pitch is shared, the candidate list is the larger one, so
            # this single formula covers that case too.
            same_ratio = len(shared) / max(len(cand_trimmed), len(src_trimmed))
            if skip_exact_matches and same_ratio == 1:
                same_ratio = 0
            ratios = [same_ratio]
            if add_pitches_counts_ratios:
                ratios.append(_same_pitches_counts_ratio(
                    shared, cand_trimmed, cand_total, src_trimmed, src_total))
            ratios.extend(extra_ratios)
            avg_ratios.append(sum(ratios) / len(ratios))

        final_ratio = max(avg_ratios)
        if final_ratio > maximum_match_ratio_to_search_for:
            final_ratio = 0
        final_ratios.append(final_ratio)

    max_ratio = max(final_ratios)
    max_ratio_index = final_ratios.index(max_ratio)
    md = meta_data[max_ratio_index]
    fn = md[0]
    print('FOUND')
    print('=' * 70)
    print('Match ratio', max_ratio)
    print('MIDI file name', fn)
    print('=' * 70)

    # NOTE(review): slice bounds kept as-is; presumably md[1][16] carries the
    # tick resolution and md[1][17:-1] the event list -- confirm against the data.
    mid_seq = md[1][17:-1]
    mid_seq_ticks = md[1][16][1]
    txt_mdata = _match_report(md, max_ratio, fn)
    fig = _plot_notes(mid_seq)

    with open("MIDI-Match-Sample.mid", 'wb') as f:
        f.write(MIDI.score2midi([mid_seq_ticks, mid_seq]))
    audio = synthesis(MIDI.score2opus([mid_seq_ticks, mid_seq]), soundfont_path)
    yield txt_mdata, "MIDI-Match-Sample.mid", (44100, audio), fig
#========================================================================================================== | |
if __name__ == "__main__":
    # Command-line options.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    cli_parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    cli_parser.add_argument("--max-gen", type=int, default=1024, help="max")  # parsed but unused below
    cli_args = cli_parser.parse_args()

    # match_midi reads these two names as module-level globals, so they must
    # stay assigned at module scope (not inside a function).
    soundfont_path = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
    meta_data_path = "meta-data/LAMDa_META_DATA_81000.pickle"

    print('Loading meta-data...')
    with open(meta_data_path, 'rb') as meta_file:
        # NOTE: pickle.load can execute arbitrary code; acceptable only while
        # this path points at a trusted local file.
        meta_data = pickle.load(meta_file)
    print('Done!')

    page_headings = (
        "<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Match</h1>",
        "<h1 style='text-align: center; margin-bottom: 1rem'>Upload any MIDI file to find its closest match</h1>",
    )
    intro_markdown = (
        "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.MIDI-Match&style=flat)\n\n"
        "Los Angeles MIDI Dataset Search and Explore Demo\n\n"
        "Please see [Los Angeles MIDI Dataset](https://github.com/asigalov61/Los-Angeles-MIDI-Dataset) for more information and features\n\n"
        "[Open In Colab]"
        "(https://colab.research.google.com/github/asigalov61/Los-Angeles-MIDI-Dataset/blob/main/Los_Angeles_MIDI_Dataset_Search_and_Explore.ipynb)"
        " for faster execution"
    )

    demo = gr.Blocks()
    with demo:
        for heading_html in page_headings:
            gr.Markdown(heading_html)
        gr.Markdown(intro_markdown)

        # Inputs.
        gr.Markdown("# Upload MIDI")
        maximum_match_ratio = gr.Slider(
            0.5, 1,
            value=1.0,
            label="Maximum match ratio to search for",
            info="Lower this value to see less precise matches",
        )
        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"], type="binary")

        # Outputs, in the same order as match_midi's yielded tuple (reordered below).
        gr.Markdown("# Match results")
        output_audio = gr.Audio(label="Output MIDI match sample audio", format="mp3", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI match sample plot")
        output_midi = gr.File(label="Output MIDI match sample MIDI", file_types=[".mid"])
        output_midi_seq = gr.Textbox(label="Output MIDI match metadata")

        match_outputs = [output_midi_seq, output_midi, output_audio, output_plot]
        input_midi.upload(match_midi, [input_midi, maximum_match_ratio], match_outputs)

    demo.queue(1).launch(server_port=cli_args.port, share=cli_args.share, inbrowser=True)