# https://huggingface.co/spaces/asigalov61/LAKH-MIDI-Dataset-Search

import os
import time as reqtime
import datetime
from pytz import timezone

from sentence_transformers import SentenceTransformer
from sentence_transformers import util

import numpy as np

from datasets import load_dataset

import gradio as gr

import copy
import random
import pickle
import zlib

from midi_to_colab_audio import midi_to_colab_audio

import TMIDIX

import matplotlib.pyplot as plt

#==========================================================================================================


def _decode_score_tokens(song):
    """Decode an integer-token score into TMIDIX 'note' events.

    Token ranges handled below:
      0-255   : start-time delta; the running time advances by token * 16
      256-511 : duration; (token - 256) * 16
      512-640 : patch; 512-639 -> patch 0-127, 640 -> patch 128 (routed to channel 9)
      641-767 : pitch; token - 640
      769-895 : velocity; token - 768 -- each velocity token emits one note

    Melodic patches are assigned to the first free channel (channel 9 is
    reserved); once all channels are taken, new patches reuse channel 15.

    Args:
        song: iterable of integer tokens.

    Returns:
        List of ['note', time, dur, channel, pitch, vel, patch] events
        (empty when ``song`` is empty).
    """
    song_f = []

    time = 0
    dur = 0
    vel = 90
    # Defaults for notes emitted before any pitch / patch token is seen.
    # Previously the pitch was read from a never-initialised name (`ptc`)
    # and `patch` could be unbound, raising UnboundLocalError.
    pitch = 0
    patch = 0
    channel = 0

    patches = [-1] * 16
    channels = [0] * 16
    channels[9] = 1  # keep channel 9 free for patch 128 (percussion)

    for ss in song:

        if 0 <= ss < 256:
            time += ss * 16

        if 256 <= ss < 512:
            dur = (ss - 256) * 16

        if 512 <= ss <= 640:
            patch = ss - 512

            if patch < 128:
                if patch not in patches:
                    if 0 in channels:
                        cha = channels.index(0)
                        channels[cha] = 1
                    else:
                        # No free channel left: overwrite channel 15's patch.
                        cha = 15

                    patches[cha] = patch
                    channel = patches.index(patch)
                else:
                    channel = patches.index(patch)

            if patch == 128:
                channel = 9

        if 640 < ss < 768:
            pitch = ss - 640

        if 768 < ss < 896:
            vel = ss - 768
            song_f.append(['note', time, dur, channel, pitch, vel, patch])

    return song_f


def find_midi(input_search_string):
    """Find the LAKH MIDI whose MidiCaps caption best matches a text query.

    The query is embedded with the global sentence-transformer ``model`` and
    compared (cosine similarity) against ``corpus_embeddings``. The best
    matching MidiCaps row is then mapped to its LAKH hash. The stored token
    score is loaded from the matching zlib/pickle file, decoded, written to
    ``<hash>.mid``, rendered to audio and plotted.

    Relies on module globals set in the ``__main__`` block: ``PDT``, ``model``,
    ``corpus_embeddings``, ``mc_dataset``, ``all_MIDI_files_names``,
    ``MIDI_files_names`` and ``soundfont``.

    Args:
        input_search_string: free-text song description.

    Returns:
        Tuple of (title, caption, metadata summary, MidiCaps record string,
        MIDI file name, (16000, audio), score plot).
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()

    print('-' * 70)
    print('Req search str:', input_search_string)
    print('-' * 70)

    print('Searching...')

    query_embedding = model.encode([input_search_string])

    # Cosine similarity between the query and every corpus caption; row 0 holds the scores
    similarities = util.cos_sim(query_embedding, corpus_embeddings)
    scores = similarities[0]

    # Most similar caption (the unused full-corpus top-10 argsort was removed)
    closest_index = int(scores.argmax())
    closest_index_match_ratio = float(scores[closest_index])

    best_corpus_match = mc_dataset['train'][closest_index]

    print('Done!')
    print('=' * 70)
    print('Match corpus index', closest_index)
    print('Match corpus ratio', closest_index_match_ratio)
    print('=' * 70)
    print('Done!')
    print('=' * 70)

    LAKH_hash = best_corpus_match['location'].split('/')[-1].split('.mid')[0]
    LAKH_caption = str(best_corpus_match['caption'])

    zlib_file_name = all_MIDI_files_names[MIDI_files_names.index(LAKH_hash)][1]

    print('Fetching MIDI score...')

    with open(zlib_file_name, 'rb') as f:
        compressed_data = f.read()

    # NOTE: pickle.loads can execute arbitrary code; this is only acceptable
    # because these files are the app's own bundled data, never user uploads.
    scores_data = pickle.loads(zlib.decompress(compressed_data))

    fnames = [s[0] for s in scores_data]
    fnameidx = fnames.index(LAKH_hash)

    MIDI_score_metadata = scores_data[fnameidx][1]
    MIDI_score_data = scores_data[fnameidx][2]

    print('Rendering results...')

    print('=' * 70)
    print('MIDi Title:', LAKH_hash)
    print('Sample INTs', MIDI_score_data[:12])
    print('=' * 70)

    # Decode tokens into notes; an empty score yields an empty note list
    # (previously `song_f` was left undefined in that case).
    # Only the first 3000 notes are rendered (presumably to keep rendering fast).
    song_f = _decode_score_tokens(MIDI_score_data)[:3000]

    print('=' * 70)

    #===============================================================================

    # Patch list returned here is what the MIDI writer uses
    output_score, patches, _ = TMIDIX.patch_enhanced_score_notes(song_f)

    # Writes the MIDI file named after the hash (read back below as new_fn)
    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature='LAKH MIDI Dataset Search',
                                             output_file_name=LAKH_hash,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches
                                             )

    new_fn = LAKH_hash + '.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi_title = str(LAKH_hash)
    output_midi_caption = LAKH_caption
    output_midi_summary = str(MIDI_score_metadata)
    output_midi_caps = str(best_corpus_match)
    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score, plot_title=output_midi_title, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI caption string:', output_midi_caption)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('Output MidiCaps information:', output_midi_caps)
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_caption, output_midi_summary, output_midi_caps, output_midi, output_audio, output_plot

#==========================================================================================================

if __name__ == "__main__":

    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    print('Loading MidiCaps dataset...')
    mc_dataset = load_dataset("amaai-lab/MidiCaps")
    print('=' * 70)

    print('Loading files list...')
    all_MIDI_files_names = TMIDIX.Tegridy_Any_Pickle_File_Reader('LAKH_all_files_names')
    MIDI_files_names = [f[0] for f in all_MIDI_files_names]
    print('=' * 70)

    print('Loading MIDI corpus embeddings...')
    corpus_embeddings = np.load('MIDI_corpus_embeddings_all-mpnet-base-v2.npz')['data']
    print('Done!')
    print('=' * 70)

    print('Loading Sentence Transformer model...')
    model = SentenceTransformer('all-mpnet-base-v2')
    print('Done!')
    print('=' * 70)

    app = gr.Blocks()

    with app:
        # Single-line literals: the previous strings contained raw newlines
        # inside "..." quotes, which is a SyntaxError.
        gr.Markdown("# LAKH MIDI Dataset Search")
        gr.Markdown("## Search and explore LAKH MIDI dataset with MidiCaps dataset and sentence transformer")
        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.LAKH-MIDI-Dataset-Search&style=flat)\n\n"
                    "This is a demo for MidiCaps dataset\n\n"
                    "Check out [MidiCaps Dataset](https://huggingface.co/datasets/amaai-lab/MidiCaps) on Hugging Face!\n\n"
                    )

        gr.Markdown("# Enter any desired song description\n\n")

        input_search_string = gr.Textbox(label="Search string", value="Cheery pop song about love and happiness")

        submit = gr.Button(value='Search')

        gr.ClearButton(components=[input_search_string])

        gr.Markdown("# Search results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_caption = gr.Textbox(label="MIDI caption string")
        output_midi_summary = gr.Textbox(label="Aggregated MIDI file text metadata")
        output_midi_caps = gr.Textbox(label="MidiCaps dataset information")
        output_audio = gr.Audio(label="Output MIDI audio", format="mp3", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        submit.click(find_midi,
                     [input_search_string],
                     [output_midi_title,
                      output_midi_caption,
                      output_midi_summary,
                      output_midi_caps,
                      output_midi,
                      output_audio,
                      output_plot
                      ])

    app.launch()