diff --git a/README.md b/README.md index 176fdbd2acb9adb502e31eea2cf90a987c699ac8..30b8f8d7d9b70be65252890375481ed7d178d98d 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ --- title: Ultimate Vocal Remover WebUI -emoji: 🌍 -colorFrom: gray -colorTo: gray +emoji: 🗣️-🎵 +colorFrom: purple +colorTo: yellow sdk: gradio -sdk_version: 4.26.0 +sdk_version: 3.44.2 app_file: app.py pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +license: mit +short_description: Remove Vocal and Instrument from Music! +--- \ No newline at end of file diff --git a/UVR_interface.py b/UVR_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..9f31d4fecd574e68fdb37a4031c97779000f43c0 --- /dev/null +++ b/UVR_interface.py @@ -0,0 +1,852 @@ +import audioread +import librosa + +import os +import sys +import json +import time +from tqdm import tqdm +import pickle +import hashlib +import logging +import traceback +import shutil +import soundfile as sf + +import torch + +from gui_data.constants import * +from gui_data.old_data_check import file_check, remove_unneeded_yamls, remove_temps +from lib_v5.vr_network.model_param_init import ModelParameters +from lib_v5 import spec_utils +from pathlib import Path +from separate import SeperateAttributes, SeperateDemucs, SeperateMDX, SeperateVR, save_format +from typing import List + + +logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO) +logging.info('UVR BEGIN') + +PREVIOUS_PATCH_WIN = 'UVR_Patch_1_12_23_14_54' + +is_dnd_compatible = True +banner_placement = -2 + +def save_data(data): + """ + Saves given data as a .pkl (pickle) file + + Paramters: + data(dict): + Dictionary containing all the necessary data to save + """ + # Open data file, create it if it does not exist + with open('data.pkl', 'wb') as data_file: + pickle.dump(data, data_file) + +def load_data() -> dict: + """ + Loads saved pkl file and returns the stored data + + Returns(dict): + Dictionary containing all the saved data + """ + try: + with open('data.pkl', 'rb') as data_file: # Open data file + data = pickle.load(data_file) + + return data + except (ValueError, FileNotFoundError): + # Data File is corrupted or not found so recreate it + + save_data(data=DEFAULT_DATA) + + return load_data() + +def load_model_hash_data(dictionary): + '''Get the model hash dictionary''' + + with open(dictionary) as d: + data = d.read() + + return json.loads(data) + +# Change the current working directory to the directory +# this file sits in +if getattr(sys, 'frozen', False): + # If the application is run as a bundle, the PyInstaller bootloader + # extends the sys module by a flag frozen=True and sets the app + # path into variable _MEIPASS'. + BASE_PATH = sys._MEIPASS +else: + BASE_PATH = os.path.dirname(os.path.abspath(__file__)) + +os.chdir(BASE_PATH) # Change the current working directory to the base path + +debugger = [] + +#--Constants-- +#Models +MODELS_DIR = os.path.join(BASE_PATH, 'models') +VR_MODELS_DIR = os.path.join(MODELS_DIR, 'VR_Models') +MDX_MODELS_DIR = os.path.join(MODELS_DIR, 'MDX_Net_Models') +DEMUCS_MODELS_DIR = os.path.join(MODELS_DIR, 'Demucs_Models') +DEMUCS_NEWER_REPO_DIR = os.path.join(DEMUCS_MODELS_DIR, 'v3_v4_repo') +MDX_MIXER_PATH = os.path.join(BASE_PATH, 'lib_v5', 'mixer.ckpt') + +#Cache & Parameters +VR_HASH_DIR = os.path.join(VR_MODELS_DIR, 'model_data') +VR_HASH_JSON = os.path.join(VR_MODELS_DIR, 'model_data', 'model_data.json') +MDX_HASH_DIR = os.path.join(MDX_MODELS_DIR, 'model_data') +MDX_HASH_JSON = os.path.join(MDX_MODELS_DIR, 'model_data', 'model_data.json') +DEMUCS_MODEL_NAME_SELECT = os.path.join(DEMUCS_MODELS_DIR, 'model_data', 'model_name_mapper.json') +MDX_MODEL_NAME_SELECT = os.path.join(MDX_MODELS_DIR, 'model_data', 'model_name_mapper.json') +ENSEMBLE_CACHE_DIR = os.path.join(BASE_PATH, 'gui_data', 'saved_ensembles') +SETTINGS_CACHE_DIR = os.path.join(BASE_PATH, 'gui_data', 'saved_settings') +VR_PARAM_DIR = os.path.join(BASE_PATH, 'lib_v5', 'vr_network', 'modelparams') +SAMPLE_CLIP_PATH = os.path.join(BASE_PATH, 'temp_sample_clips') +ENSEMBLE_TEMP_PATH = os.path.join(BASE_PATH, 'ensemble_temps') + +#Style +ICON_IMG_PATH = os.path.join(BASE_PATH, 'gui_data', 'img', 'GUI-Icon.ico') +FONT_PATH = os.path.join(BASE_PATH, 'gui_data', 'fonts', 'centurygothic', 'GOTHIC.TTF')#ensemble_temps + +#Other +COMPLETE_CHIME = os.path.join(BASE_PATH, 'gui_data', 'complete_chime.wav') +FAIL_CHIME = os.path.join(BASE_PATH, 'gui_data', 'fail_chime.wav') +CHANGE_LOG = os.path.join(BASE_PATH, 'gui_data', 'change_log.txt') +SPLASH_DOC = os.path.join(BASE_PATH, 'tmp', 'splash.txt') + +file_check(os.path.join(MODELS_DIR, 'Main_Models'), VR_MODELS_DIR) +file_check(os.path.join(DEMUCS_MODELS_DIR, 'v3_repo'), DEMUCS_NEWER_REPO_DIR) +remove_unneeded_yamls(DEMUCS_MODELS_DIR) + +remove_temps(ENSEMBLE_TEMP_PATH) +remove_temps(SAMPLE_CLIP_PATH) +remove_temps(os.path.join(BASE_PATH, 'img')) + +if not os.path.isdir(ENSEMBLE_TEMP_PATH): + os.mkdir(ENSEMBLE_TEMP_PATH) + +if not os.path.isdir(SAMPLE_CLIP_PATH): + os.mkdir(SAMPLE_CLIP_PATH) + +model_hash_table = {} +data = load_data() + +class ModelData(): + def __init__(self, model_name: str, + selected_process_method=ENSEMBLE_MODE, + is_secondary_model=False, + primary_model_primary_stem=None, + is_primary_model_primary_stem_only=False, + is_primary_model_secondary_stem_only=False, + is_pre_proc_model=False, + is_dry_check=False): + + self.is_gpu_conversion = 0 if root.is_gpu_conversion_var.get() else -1 + self.is_normalization = root.is_normalization_var.get() + self.is_primary_stem_only = root.is_primary_stem_only_var.get() + self.is_secondary_stem_only = root.is_secondary_stem_only_var.get() + self.is_denoise = root.is_denoise_var.get() + self.mdx_batch_size = 1 if root.mdx_batch_size_var.get() == DEF_OPT else int(root.mdx_batch_size_var.get()) + self.is_mdx_ckpt = False + self.wav_type_set = root.wav_type_set + self.mp3_bit_set = root.mp3_bit_set_var.get() + self.save_format = root.save_format_var.get() + self.is_invert_spec = root.is_invert_spec_var.get() + self.is_mixer_mode = root.is_mixer_mode_var.get() + self.demucs_stems = root.demucs_stems_var.get() + self.demucs_source_list = [] + self.demucs_stem_count = 0 + self.mixer_path = MDX_MIXER_PATH + self.model_name = model_name + self.process_method = selected_process_method + self.model_status = False if self.model_name == CHOOSE_MODEL or self.model_name == NO_MODEL else True + self.primary_stem = None + self.secondary_stem = None + self.is_ensemble_mode = False + self.ensemble_primary_stem = None + self.ensemble_secondary_stem = None + self.primary_model_primary_stem = primary_model_primary_stem + self.is_secondary_model = is_secondary_model + self.secondary_model = None + self.secondary_model_scale = None + self.demucs_4_stem_added_count = 0 + self.is_demucs_4_stem_secondaries = False + self.is_4_stem_ensemble = False + self.pre_proc_model = None + self.pre_proc_model_activated = False + self.is_pre_proc_model = is_pre_proc_model + self.is_dry_check = is_dry_check + self.model_samplerate = 44100 + self.model_capacity = 32, 128 + self.is_vr_51_model = False + self.is_demucs_pre_proc_model_inst_mix = False + self.manual_download_Button = None + self.secondary_model_4_stem = [] + self.secondary_model_4_stem_scale = [] + self.secondary_model_4_stem_names = [] + self.secondary_model_4_stem_model_names_list = [] + self.all_models = [] + self.secondary_model_other = None + self.secondary_model_scale_other = None + self.secondary_model_bass = None + self.secondary_model_scale_bass = None + self.secondary_model_drums = None + self.secondary_model_scale_drums = None + + if selected_process_method == ENSEMBLE_MODE: + partitioned_name = model_name.partition(ENSEMBLE_PARTITION) + self.process_method = partitioned_name[0] + self.model_name = partitioned_name[2] + self.model_and_process_tag = model_name + self.ensemble_primary_stem, self.ensemble_secondary_stem = root.return_ensemble_stems() + self.is_ensemble_mode = True if not is_secondary_model and not is_pre_proc_model else False + self.is_4_stem_ensemble = True if root.ensemble_main_stem_var.get() == FOUR_STEM_ENSEMBLE and self.is_ensemble_mode else False + self.pre_proc_model_activated = root.is_demucs_pre_proc_model_activate_var.get() if not self.ensemble_primary_stem == VOCAL_STEM else False + + if self.process_method == VR_ARCH_TYPE: + self.is_secondary_model_activated = root.vr_is_secondary_model_activate_var.get() if not self.is_secondary_model else False + self.aggression_setting = float(int(root.aggression_setting_var.get())/100) + self.is_tta = root.is_tta_var.get() + self.is_post_process = root.is_post_process_var.get() + self.window_size = int(root.window_size_var.get()) + self.batch_size = 1 if root.batch_size_var.get() == DEF_OPT else int(root.batch_size_var.get()) + self.crop_size = int(root.crop_size_var.get()) + self.is_high_end_process = 'mirroring' if root.is_high_end_process_var.get() else 'None' + self.post_process_threshold = float(root.post_process_threshold_var.get()) + self.model_capacity = 32, 128 + self.model_path = os.path.join(VR_MODELS_DIR, f"{self.model_name}.pth") + self.get_model_hash() + if self.model_hash: + self.model_data = self.get_model_data(VR_HASH_DIR, root.vr_hash_MAPPER) if not self.model_hash == WOOD_INST_MODEL_HASH else WOOD_INST_PARAMS + if self.model_data: + vr_model_param = os.path.join(VR_PARAM_DIR, "{}.json".format(self.model_data["vr_model_param"])) + self.primary_stem = self.model_data["primary_stem"] + self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem] + self.vr_model_param = ModelParameters(vr_model_param) + self.model_samplerate = self.vr_model_param.param['sr'] + if "nout" in self.model_data.keys() and "nout_lstm" in self.model_data.keys(): + self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"] + self.is_vr_51_model = True + else: + self.model_status = False + + if self.process_method == MDX_ARCH_TYPE: + self.is_secondary_model_activated = root.mdx_is_secondary_model_activate_var.get() if not is_secondary_model else False + self.margin = int(root.margin_var.get()) + self.chunks = root.determine_auto_chunks(root.chunks_var.get(), self.is_gpu_conversion) if root.is_chunk_mdxnet_var.get() else 0 + self.get_mdx_model_path() + self.get_model_hash() + if self.model_hash: + self.model_data = self.get_model_data(MDX_HASH_DIR, root.mdx_hash_MAPPER) + if self.model_data: + self.compensate = self.model_data["compensate"] if root.compensate_var.get() == AUTO_SELECT else float(root.compensate_var.get()) + self.mdx_dim_f_set = self.model_data["mdx_dim_f_set"] + self.mdx_dim_t_set = self.model_data["mdx_dim_t_set"] + self.mdx_n_fft_scale_set = self.model_data["mdx_n_fft_scale_set"] + self.primary_stem = self.model_data["primary_stem"] + self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem] + else: + self.model_status = False + + if self.process_method == DEMUCS_ARCH_TYPE: + self.is_secondary_model_activated = root.demucs_is_secondary_model_activate_var.get() if not is_secondary_model else False + if not self.is_ensemble_mode: + self.pre_proc_model_activated = root.is_demucs_pre_proc_model_activate_var.get() if not root.demucs_stems_var.get() in [VOCAL_STEM, INST_STEM] else False + self.overlap = float(root.overlap_var.get()) + self.margin_demucs = int(root.margin_demucs_var.get()) + self.chunks_demucs = root.determine_auto_chunks(root.chunks_demucs_var.get(), self.is_gpu_conversion) + self.shifts = int(root.shifts_var.get()) + self.is_split_mode = root.is_split_mode_var.get() + self.segment = root.segment_var.get() + self.is_chunk_demucs = root.is_chunk_demucs_var.get() + self.is_demucs_combine_stems = root.is_demucs_combine_stems_var.get() + self.is_primary_stem_only = root.is_primary_stem_only_var.get() if self.is_ensemble_mode else root.is_primary_stem_only_Demucs_var.get() + self.is_secondary_stem_only = root.is_secondary_stem_only_var.get() if self.is_ensemble_mode else root.is_secondary_stem_only_Demucs_var.get() + self.get_demucs_model_path() + self.get_demucs_model_data() + + self.model_basename = os.path.splitext(os.path.basename(self.model_path))[0] if self.model_status else None + self.pre_proc_model_activated = self.pre_proc_model_activated if not self.is_secondary_model else False + + self.is_primary_model_primary_stem_only = is_primary_model_primary_stem_only + self.is_primary_model_secondary_stem_only = is_primary_model_secondary_stem_only + + if self.is_secondary_model_activated and self.model_status: + if (not self.is_ensemble_mode and root.demucs_stems_var.get() == ALL_STEMS and self.process_method == DEMUCS_ARCH_TYPE) or self.is_4_stem_ensemble: + for key in DEMUCS_4_SOURCE_LIST: + self.secondary_model_data(key) + self.secondary_model_4_stem.append(self.secondary_model) + self.secondary_model_4_stem_scale.append(self.secondary_model_scale) + self.secondary_model_4_stem_names.append(key) + self.demucs_4_stem_added_count = sum(i is not None for i in self.secondary_model_4_stem) + self.is_secondary_model_activated = False if all(i is None for i in self.secondary_model_4_stem) else True + self.demucs_4_stem_added_count = self.demucs_4_stem_added_count - 1 if self.is_secondary_model_activated else self.demucs_4_stem_added_count + if self.is_secondary_model_activated: + self.secondary_model_4_stem_model_names_list = [None if i is None else i.model_basename for i in self.secondary_model_4_stem] + self.is_demucs_4_stem_secondaries = True + else: + primary_stem = self.ensemble_primary_stem if self.is_ensemble_mode and self.process_method == DEMUCS_ARCH_TYPE else self.primary_stem + self.secondary_model_data(primary_stem) + + if self.process_method == DEMUCS_ARCH_TYPE and not is_secondary_model: + if self.demucs_stem_count >= 3 and self.pre_proc_model_activated: + self.pre_proc_model_activated = True + self.pre_proc_model = root.process_determine_demucs_pre_proc_model(self.primary_stem) + self.is_demucs_pre_proc_model_inst_mix = root.is_demucs_pre_proc_model_inst_mix_var.get() if self.pre_proc_model else False + + def secondary_model_data(self, primary_stem): + secondary_model_data = root.process_determine_secondary_model(self.process_method, primary_stem, self.is_primary_stem_only, self.is_secondary_stem_only) + self.secondary_model = secondary_model_data[0] + self.secondary_model_scale = secondary_model_data[1] + self.is_secondary_model_activated = False if not self.secondary_model else True + if self.secondary_model: + self.is_secondary_model_activated = False if self.secondary_model.model_basename == self.model_basename else True + + def get_mdx_model_path(self): + + if self.model_name.endswith(CKPT): + # self.chunks = 0 + # self.is_mdx_batch_mode = True + self.is_mdx_ckpt = True + + ext = '' if self.is_mdx_ckpt else ONNX + + for file_name, chosen_mdx_model in root.mdx_name_select_MAPPER.items(): + if self.model_name in chosen_mdx_model: + self.model_path = os.path.join(MDX_MODELS_DIR, f"{file_name}{ext}") + break + else: + self.model_path = os.path.join(MDX_MODELS_DIR, f"{self.model_name}{ext}") + + self.mixer_path = os.path.join(MDX_MODELS_DIR, f"mixer_val.ckpt") + + def get_demucs_model_path(self): + + demucs_newer = [True for x in DEMUCS_NEWER_TAGS if x in self.model_name] + demucs_model_dir = DEMUCS_NEWER_REPO_DIR if demucs_newer else DEMUCS_MODELS_DIR + + for file_name, chosen_model in root.demucs_name_select_MAPPER.items(): + if self.model_name in chosen_model: + self.model_path = os.path.join(demucs_model_dir, file_name) + break + else: + self.model_path = os.path.join(DEMUCS_NEWER_REPO_DIR, f'{self.model_name}.yaml') + + def get_demucs_model_data(self): + + self.demucs_version = DEMUCS_V4 + + for key, value in DEMUCS_VERSION_MAPPER.items(): + if value in self.model_name: + self.demucs_version = key + + self.demucs_source_list = DEMUCS_2_SOURCE if DEMUCS_UVR_MODEL in self.model_name else DEMUCS_4_SOURCE + self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER if DEMUCS_UVR_MODEL in self.model_name else DEMUCS_4_SOURCE_MAPPER + self.demucs_stem_count = 2 if DEMUCS_UVR_MODEL in self.model_name else 4 + + if not self.is_ensemble_mode: + self.primary_stem = PRIMARY_STEM if self.demucs_stems == ALL_STEMS else self.demucs_stems + self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem] + + def get_model_data(self, model_hash_dir, hash_mapper): + model_settings_json = os.path.join(model_hash_dir, "{}.json".format(self.model_hash)) + + if os.path.isfile(model_settings_json): + return json.load(open(model_settings_json)) + else: + for hash, settings in hash_mapper.items(): + if self.model_hash in hash: + return settings + else: + return self.get_model_data_from_popup() + + def get_model_data_from_popup(self): + return None + + def get_model_hash(self): + self.model_hash = None + + if not os.path.isfile(self.model_path): + self.model_status = False + self.model_hash is None + else: + if model_hash_table: + for (key, value) in model_hash_table.items(): + if self.model_path == key: + self.model_hash = value + break + + if not self.model_hash: + try: + with open(self.model_path, 'rb') as f: + f.seek(- 10000 * 1024, 2) + self.model_hash = hashlib.md5(f.read()).hexdigest() + except: + self.model_hash = hashlib.md5(open(self.model_path,'rb').read()).hexdigest() + + table_entry = {self.model_path: self.model_hash} + model_hash_table.update(table_entry) + + +class Ensembler(): + def __init__(self, is_manual_ensemble=False): + self.is_save_all_outputs_ensemble = root.is_save_all_outputs_ensemble_var.get() + chosen_ensemble_name = '{}'.format(root.chosen_ensemble_var.get().replace(" ", "_")) if not root.chosen_ensemble_var.get() == CHOOSE_ENSEMBLE_OPTION else 'Ensembled' + ensemble_algorithm = root.ensemble_type_var.get().partition("/") + ensemble_main_stem_pair = root.ensemble_main_stem_var.get().partition("/") + time_stamp = round(time.time()) + self.audio_tool = MANUAL_ENSEMBLE + self.main_export_path = Path(root.export_path_var.get()) + self.chosen_ensemble = f"_{chosen_ensemble_name}" if root.is_append_ensemble_name_var.get() else '' + ensemble_folder_name = self.main_export_path if self.is_save_all_outputs_ensemble else ENSEMBLE_TEMP_PATH + self.ensemble_folder_name = os.path.join(ensemble_folder_name, '{}_Outputs_{}'.format(chosen_ensemble_name, time_stamp)) + self.is_testing_audio = f"{time_stamp}_" if root.is_testing_audio_var.get() else '' + self.primary_algorithm = ensemble_algorithm[0] + self.secondary_algorithm = ensemble_algorithm[2] + self.ensemble_primary_stem = ensemble_main_stem_pair[0] + self.ensemble_secondary_stem = ensemble_main_stem_pair[2] + self.is_normalization = root.is_normalization_var.get() + self.wav_type_set = root.wav_type_set + self.mp3_bit_set = root.mp3_bit_set_var.get() + self.save_format = root.save_format_var.get() + if not is_manual_ensemble: + os.mkdir(self.ensemble_folder_name) + + def ensemble_outputs(self, audio_file_base, export_path, stem, is_4_stem=False, is_inst_mix=False): + """Processes the given outputs and ensembles them with the chosen algorithm""" + + if is_4_stem: + algorithm = root.ensemble_type_var.get() + stem_tag = stem + else: + if is_inst_mix: + algorithm = self.secondary_algorithm + stem_tag = f"{self.ensemble_secondary_stem} {INST_STEM}" + else: + algorithm = self.primary_algorithm if stem == PRIMARY_STEM else self.secondary_algorithm + stem_tag = self.ensemble_primary_stem if stem == PRIMARY_STEM else self.ensemble_secondary_stem + + stem_outputs = self.get_files_to_ensemble(folder=export_path, prefix=audio_file_base, suffix=f"_({stem_tag}).wav") + audio_file_output = f"{self.is_testing_audio}{audio_file_base}{self.chosen_ensemble}_({stem_tag})" + stem_save_path = os.path.join('{}'.format(self.main_export_path),'{}.wav'.format(audio_file_output)) + + if stem_outputs: + spec_utils.ensemble_inputs(stem_outputs, algorithm, self.is_normalization, self.wav_type_set, stem_save_path) + save_format(stem_save_path, self.save_format, self.mp3_bit_set) + + if self.is_save_all_outputs_ensemble: + for i in stem_outputs: + save_format(i, self.save_format, self.mp3_bit_set) + else: + for i in stem_outputs: + try: + os.remove(i) + except Exception as e: + print(e) + + def ensemble_manual(self, audio_inputs, audio_file_base, is_bulk=False): + """Processes the given outputs and ensembles them with the chosen algorithm""" + + is_mv_sep = True + + if is_bulk: + number_list = list(set([os.path.basename(i).split("_")[0] for i in audio_inputs])) + for n in number_list: + current_list = [i for i in audio_inputs if os.path.basename(i).startswith(n)] + audio_file_base = os.path.basename(current_list[0]).split('.wav')[0] + stem_testing = "instrum" if "Instrumental" in audio_file_base else "vocals" + if is_mv_sep: + audio_file_base = audio_file_base.split("_") + audio_file_base = f"{audio_file_base[1]}_{audio_file_base[2]}_{stem_testing}" + self.ensemble_manual_process(current_list, audio_file_base, is_bulk) + else: + self.ensemble_manual_process(audio_inputs, audio_file_base, is_bulk) + + def ensemble_manual_process(self, audio_inputs, audio_file_base, is_bulk): + + algorithm = root.choose_algorithm_var.get() + algorithm_text = "" if is_bulk else f"_({root.choose_algorithm_var.get()})" + stem_save_path = os.path.join('{}'.format(self.main_export_path),'{}{}{}.wav'.format(self.is_testing_audio, audio_file_base, algorithm_text)) + spec_utils.ensemble_inputs(audio_inputs, algorithm, self.is_normalization, self.wav_type_set, stem_save_path) + save_format(stem_save_path, self.save_format, self.mp3_bit_set) + + def get_files_to_ensemble(self, folder="", prefix="", suffix=""): + """Grab all the files to be ensembled""" + + return [os.path.join(folder, i) for i in os.listdir(folder) if i.startswith(prefix) and i.endswith(suffix)] + + +def secondary_stem(stem): + """Determines secondary stem""" + + for key, value in STEM_PAIR_MAPPER.items(): + if stem in key: + secondary_stem = value + + return secondary_stem + + +class UVRInterface: + def __init__(self) -> None: + pass + + def assemble_model_data(self, model=None, arch_type=ENSEMBLE_MODE, is_dry_check=False) -> List[ModelData]: + if arch_type == ENSEMBLE_STEM_CHECK: + model_data = self.model_data_table + missing_models = [model.model_status for model in model_data if not model.model_status] + + if missing_models or not model_data: + model_data: List[ModelData] = [ModelData(model_name, is_dry_check=is_dry_check) for model_name in self.ensemble_model_list] + self.model_data_table = model_data + + if arch_type == ENSEMBLE_MODE: + model_data: List[ModelData] = [ModelData(model_name) for model_name in self.ensemble_listbox_get_all_selected_models()] + if arch_type == ENSEMBLE_CHECK: + model_data: List[ModelData] = [ModelData(model)] + if arch_type == VR_ARCH_TYPE or arch_type == VR_ARCH_PM: + model_data: List[ModelData] = [ModelData(model, VR_ARCH_TYPE)] + if arch_type == MDX_ARCH_TYPE: + model_data: List[ModelData] = [ModelData(model, MDX_ARCH_TYPE)] + if arch_type == DEMUCS_ARCH_TYPE: + model_data: List[ModelData] = [ModelData(model, DEMUCS_ARCH_TYPE)]# + + return model_data + + def create_sample(self, audio_file, sample_path=SAMPLE_CLIP_PATH): + try: + with audioread.audio_open(audio_file) as f: + track_length = int(f.duration) + except Exception as e: + print('Audioread failed to get duration. Trying Librosa...') + y, sr = librosa.load(audio_file, mono=False, sr=44100) + track_length = int(librosa.get_duration(y=y, sr=sr)) + + clip_duration = int(root.model_sample_mode_duration_var.get()) + + if track_length >= clip_duration: + offset_cut = track_length//3 + off_cut = offset_cut + track_length + if not off_cut >= clip_duration: + offset_cut = 0 + name_apped = f'{clip_duration}_second_' + else: + offset_cut, clip_duration = 0, track_length + name_apped = '' + + sample = librosa.load(audio_file, offset=offset_cut, duration=clip_duration, mono=False, sr=44100)[0].T + audio_sample = os.path.join(sample_path, f'{os.path.splitext(os.path.basename(audio_file))[0]}_{name_apped}sample.wav') + sf.write(audio_sample, sample, 44100) + + return audio_sample + + def verify_audio(self, audio_file, is_process=True, sample_path=None): + is_good = False + error_data = '' + + if os.path.isfile(audio_file): + try: + librosa.load(audio_file, duration=3, mono=False, sr=44100) if not type(sample_path) is str else self.create_sample(audio_file, sample_path) + is_good = True + except Exception as e: + error_name = f'{type(e).__name__}' + traceback_text = ''.join(traceback.format_tb(e.__traceback__)) + message = f'{error_name}: "{e}"\n{traceback_text}"' + if is_process: + audio_base_name = os.path.basename(audio_file) + self.error_log_var.set(f'Error Loading the Following File:\n\n\"{audio_base_name}\"\n\nRaw Error Details:\n\n{message}') + else: + error_data = AUDIO_VERIFICATION_CHECK(audio_file, message) + + if is_process: + return is_good + else: + return is_good, error_data + + def cached_sources_clear(self): + self.vr_cache_source_mapper = {} + self.mdx_cache_source_mapper = {} + self.demucs_cache_source_mapper = {} + + def cached_model_source_holder(self, process_method, sources, model_name=None): + if process_method == VR_ARCH_TYPE: + self.vr_cache_source_mapper = {**self.vr_cache_source_mapper, **{model_name: sources}} + if process_method == MDX_ARCH_TYPE: + self.mdx_cache_source_mapper = {**self.mdx_cache_source_mapper, **{model_name: sources}} + if process_method == DEMUCS_ARCH_TYPE: + self.demucs_cache_source_mapper = {**self.demucs_cache_source_mapper, **{model_name: sources}} + + def cached_source_callback(self, process_method, model_name=None): + model, sources = None, None + + if process_method == VR_ARCH_TYPE: + mapper = self.vr_cache_source_mapper + if process_method == MDX_ARCH_TYPE: + mapper = self.mdx_cache_source_mapper + if process_method == DEMUCS_ARCH_TYPE: + mapper = self.demucs_cache_source_mapper + + for key, value in mapper.items(): + if model_name in key: + model = key + sources = value + + return model, sources + + def cached_source_model_list_check(self, model_list: List[ModelData]): + model: ModelData + primary_model_names = lambda process_method:[model.model_basename if model.process_method == process_method else None for model in model_list] + secondary_model_names = lambda process_method:[model.secondary_model.model_basename if model.is_secondary_model_activated and model.process_method == process_method else None for model in model_list] + + self.vr_primary_model_names = primary_model_names(VR_ARCH_TYPE) + self.mdx_primary_model_names = primary_model_names(MDX_ARCH_TYPE) + self.demucs_primary_model_names = primary_model_names(DEMUCS_ARCH_TYPE) + self.vr_secondary_model_names = secondary_model_names(VR_ARCH_TYPE) + self.mdx_secondary_model_names = secondary_model_names(MDX_ARCH_TYPE) + self.demucs_secondary_model_names = [model.secondary_model.model_basename if model.is_secondary_model_activated and model.process_method == DEMUCS_ARCH_TYPE and not model.secondary_model is None else None for model in model_list] + self.demucs_pre_proc_model_name = [model.pre_proc_model.model_basename if model.pre_proc_model else None for model in model_list]#list(dict.fromkeys()) + + for model in model_list: + if model.process_method == DEMUCS_ARCH_TYPE and model.is_demucs_4_stem_secondaries: + if not model.is_4_stem_ensemble: + self.demucs_secondary_model_names = model.secondary_model_4_stem_model_names_list + break + else: + for i in model.secondary_model_4_stem_model_names_list: + self.demucs_secondary_model_names.append(i) + + self.all_models = self.vr_primary_model_names + self.mdx_primary_model_names + self.demucs_primary_model_names + self.vr_secondary_model_names + self.mdx_secondary_model_names + self.demucs_secondary_model_names + self.demucs_pre_proc_model_name + + def process(self, model_name, arch_type, audio_file, export_path, is_model_sample_mode=False, is_4_stem_ensemble=False, set_progress_func=None, console_write=print) -> SeperateAttributes: + stime = time.perf_counter() + time_elapsed = lambda:f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}' + + if arch_type==ENSEMBLE_MODE: + model_list, ensemble = self.assemble_model_data(), Ensembler() + export_path = ensemble.ensemble_folder_name + is_ensemble = True + else: + model_list = self.assemble_model_data(model_name, arch_type) + is_ensemble = False + self.cached_source_model_list_check(model_list) + model = model_list[0] + + if self.verify_audio(audio_file): + audio_file = self.create_sample(audio_file) if is_model_sample_mode else audio_file + else: + print(f'"{os.path.basename(audio_file)}\" is missing or currupted.\n') + exit() + + audio_file_base = f"{os.path.splitext(os.path.basename(audio_file))[0]}" + audio_file_base = audio_file_base if is_ensemble else f"{round(time.time())}_{audio_file_base}" + audio_file_base = audio_file_base if not is_ensemble else f"{audio_file_base}_{model.model_basename}" + if not is_ensemble: + audio_file_base = f"{audio_file_base}_{model.model_basename}" + + if not is_ensemble: + export_path = os.path.join(Path(export_path), model.model_basename, os.path.splitext(os.path.basename(audio_file))[0]) + if not os.path.isdir(export_path): + os.makedirs(export_path) + + if set_progress_func is None: + pbar = tqdm(total=1) + self._progress = 0 + def set_progress_func(step, inference_iterations=0): + progress_curr = step + inference_iterations + pbar.update(progress_curr-self._progress) + self._progress = progress_curr + + def postprocess(): + pbar.close() + else: + def postprocess(): + pass + + process_data = { + 'model_data': model, + 'export_path': export_path, + 'audio_file_base': audio_file_base, + 'audio_file': audio_file, + 'set_progress_bar': set_progress_func, + 'write_to_console': lambda progress_text, base_text='': console_write(base_text + progress_text), + 'process_iteration': lambda:None, + 'cached_source_callback': self.cached_source_callback, + 'cached_model_source_holder': self.cached_model_source_holder, + 'list_all_models': self.all_models, + 'is_ensemble_master': is_ensemble, + 'is_4_stem_ensemble': is_ensemble and is_4_stem_ensemble + } + if model.process_method == VR_ARCH_TYPE: + seperator = SeperateVR(model, process_data) + if model.process_method == MDX_ARCH_TYPE: + seperator = SeperateMDX(model, process_data) + if model.process_method == DEMUCS_ARCH_TYPE: + seperator = SeperateDemucs(model, process_data) + + seperator.seperate() + postprocess() + + if is_ensemble: + audio_file_base = audio_file_base.replace(f"_{model.model_basename}", "") + console_write(ENSEMBLING_OUTPUTS) + + if is_4_stem_ensemble: + for output_stem in DEMUCS_4_SOURCE_LIST: + ensemble.ensemble_outputs(audio_file_base, export_path, output_stem, is_4_stem=True) + else: + if not root.is_secondary_stem_only_var.get(): + ensemble.ensemble_outputs(audio_file_base, export_path, PRIMARY_STEM) + if not root.is_primary_stem_only_var.get(): + ensemble.ensemble_outputs(audio_file_base, export_path, SECONDARY_STEM) + ensemble.ensemble_outputs(audio_file_base, export_path, SECONDARY_STEM, is_inst_mix=True) + + console_write(DONE) + + if is_model_sample_mode: + if os.path.isfile(audio_file): + os.remove(audio_file) + + torch.cuda.empty_cache() + + if is_ensemble and len(os.listdir(export_path)) == 0: + shutil.rmtree(export_path) + console_write(f'Process Complete, using time: {time_elapsed()}\nOutput path: {export_path}') + self.cached_sources_clear() + return seperator + + +class RootWrapper: + def __init__(self, var) -> None: + self.var=var + + def set(self, val): + self.var=val + + def get(self): + return self.var + +class FakeRoot: + def __init__(self) -> None: + self.wav_type_set = 'PCM_16' + self.vr_hash_MAPPER = load_model_hash_data(VR_HASH_JSON) + self.mdx_hash_MAPPER = load_model_hash_data(MDX_HASH_JSON) + self.mdx_name_select_MAPPER = load_model_hash_data(MDX_MODEL_NAME_SELECT) + self.demucs_name_select_MAPPER = load_model_hash_data(DEMUCS_MODEL_NAME_SELECT) + + def __getattribute__(self, __name: str): + try: + return super().__getattribute__(__name) + except AttributeError: + wrapped=RootWrapper(None) + super().__setattr__(__name, wrapped) + return wrapped + + def load_saved_settings(self, loaded_setting: dict, process_method=None): + """Loads user saved application settings or resets to default""" + + for key, value in DEFAULT_DATA.items(): + if not key in loaded_setting.keys(): + loaded_setting = {**loaded_setting, **{key:value}} + loaded_setting['batch_size'] = DEF_OPT + + is_ensemble = True if process_method == ENSEMBLE_MODE else False + + if not process_method or process_method == VR_ARCH_PM or is_ensemble: + self.vr_model_var.set(loaded_setting['vr_model']) + self.aggression_setting_var.set(loaded_setting['aggression_setting']) + self.window_size_var.set(loaded_setting['window_size']) + self.batch_size_var.set(loaded_setting['batch_size']) + self.crop_size_var.set(loaded_setting['crop_size']) + self.is_tta_var.set(loaded_setting['is_tta']) + self.is_output_image_var.set(loaded_setting['is_output_image']) + self.is_post_process_var.set(loaded_setting['is_post_process']) + self.is_high_end_process_var.set(loaded_setting['is_high_end_process']) + self.post_process_threshold_var.set(loaded_setting['post_process_threshold']) + self.vr_voc_inst_secondary_model_var.set(loaded_setting['vr_voc_inst_secondary_model']) + self.vr_other_secondary_model_var.set(loaded_setting['vr_other_secondary_model']) + self.vr_bass_secondary_model_var.set(loaded_setting['vr_bass_secondary_model']) + self.vr_drums_secondary_model_var.set(loaded_setting['vr_drums_secondary_model']) + self.vr_is_secondary_model_activate_var.set(loaded_setting['vr_is_secondary_model_activate']) + self.vr_voc_inst_secondary_model_scale_var.set(loaded_setting['vr_voc_inst_secondary_model_scale']) + self.vr_other_secondary_model_scale_var.set(loaded_setting['vr_other_secondary_model_scale']) + self.vr_bass_secondary_model_scale_var.set(loaded_setting['vr_bass_secondary_model_scale']) + self.vr_drums_secondary_model_scale_var.set(loaded_setting['vr_drums_secondary_model_scale']) + + if not process_method or process_method == DEMUCS_ARCH_TYPE or is_ensemble: + self.demucs_model_var.set(loaded_setting['demucs_model']) + self.segment_var.set(loaded_setting['segment']) + self.overlap_var.set(loaded_setting['overlap']) + self.shifts_var.set(loaded_setting['shifts']) + self.chunks_demucs_var.set(loaded_setting['chunks_demucs']) + self.margin_demucs_var.set(loaded_setting['margin_demucs']) + self.is_chunk_demucs_var.set(loaded_setting['is_chunk_demucs']) + self.is_chunk_mdxnet_var.set(loaded_setting['is_chunk_mdxnet']) + self.is_primary_stem_only_Demucs_var.set(loaded_setting['is_primary_stem_only_Demucs']) + self.is_secondary_stem_only_Demucs_var.set(loaded_setting['is_secondary_stem_only_Demucs']) + self.is_split_mode_var.set(loaded_setting['is_split_mode']) + self.is_demucs_combine_stems_var.set(loaded_setting['is_demucs_combine_stems']) + self.demucs_voc_inst_secondary_model_var.set(loaded_setting['demucs_voc_inst_secondary_model']) + self.demucs_other_secondary_model_var.set(loaded_setting['demucs_other_secondary_model']) + self.demucs_bass_secondary_model_var.set(loaded_setting['demucs_bass_secondary_model']) + self.demucs_drums_secondary_model_var.set(loaded_setting['demucs_drums_secondary_model']) + self.demucs_is_secondary_model_activate_var.set(loaded_setting['demucs_is_secondary_model_activate']) + self.demucs_voc_inst_secondary_model_scale_var.set(loaded_setting['demucs_voc_inst_secondary_model_scale']) + self.demucs_other_secondary_model_scale_var.set(loaded_setting['demucs_other_secondary_model_scale']) + self.demucs_bass_secondary_model_scale_var.set(loaded_setting['demucs_bass_secondary_model_scale']) + self.demucs_drums_secondary_model_scale_var.set(loaded_setting['demucs_drums_secondary_model_scale']) + self.demucs_stems_var.set(loaded_setting['demucs_stems']) + # self.update_stem_checkbox_labels(self.demucs_stems_var.get(), demucs=True) + self.demucs_pre_proc_model_var.set(data['demucs_pre_proc_model']) + self.is_demucs_pre_proc_model_activate_var.set(data['is_demucs_pre_proc_model_activate']) + self.is_demucs_pre_proc_model_inst_mix_var.set(data['is_demucs_pre_proc_model_inst_mix']) + + if not process_method or process_method == MDX_ARCH_TYPE or is_ensemble: + self.mdx_net_model_var.set(loaded_setting['mdx_net_model']) + self.chunks_var.set(loaded_setting['chunks']) + self.margin_var.set(loaded_setting['margin']) + self.compensate_var.set(loaded_setting['compensate']) + self.is_denoise_var.set(loaded_setting['is_denoise']) + self.is_invert_spec_var.set(loaded_setting['is_invert_spec']) + self.is_mixer_mode_var.set(loaded_setting['is_mixer_mode']) + self.mdx_batch_size_var.set(loaded_setting['mdx_batch_size']) + self.mdx_voc_inst_secondary_model_var.set(loaded_setting['mdx_voc_inst_secondary_model']) + self.mdx_other_secondary_model_var.set(loaded_setting['mdx_other_secondary_model']) + self.mdx_bass_secondary_model_var.set(loaded_setting['mdx_bass_secondary_model']) + self.mdx_drums_secondary_model_var.set(loaded_setting['mdx_drums_secondary_model']) + self.mdx_is_secondary_model_activate_var.set(loaded_setting['mdx_is_secondary_model_activate']) + self.mdx_voc_inst_secondary_model_scale_var.set(loaded_setting['mdx_voc_inst_secondary_model_scale']) + self.mdx_other_secondary_model_scale_var.set(loaded_setting['mdx_other_secondary_model_scale']) + self.mdx_bass_secondary_model_scale_var.set(loaded_setting['mdx_bass_secondary_model_scale']) + self.mdx_drums_secondary_model_scale_var.set(loaded_setting['mdx_drums_secondary_model_scale']) + + if not process_method or is_ensemble: + self.is_save_all_outputs_ensemble_var.set(loaded_setting['is_save_all_outputs_ensemble']) + self.is_append_ensemble_name_var.set(loaded_setting['is_append_ensemble_name']) + self.chosen_audio_tool_var.set(loaded_setting['chosen_audio_tool']) + self.choose_algorithm_var.set(loaded_setting['choose_algorithm']) + self.time_stretch_rate_var.set(loaded_setting['time_stretch_rate']) + self.pitch_rate_var.set(loaded_setting['pitch_rate']) + self.is_primary_stem_only_var.set(loaded_setting['is_primary_stem_only']) + self.is_secondary_stem_only_var.set(loaded_setting['is_secondary_stem_only']) + self.is_testing_audio_var.set(loaded_setting['is_testing_audio']) + self.is_add_model_name_var.set(loaded_setting['is_add_model_name']) + self.is_accept_any_input_var.set(loaded_setting["is_accept_any_input"]) + self.is_task_complete_var.set(loaded_setting['is_task_complete']) + self.is_create_model_folder_var.set(loaded_setting['is_create_model_folder']) + self.mp3_bit_set_var.set(loaded_setting['mp3_bit_set']) + self.save_format_var.set(loaded_setting['save_format']) + self.wav_type_set_var.set(loaded_setting['wav_type_set']) + self.user_code_var.set(loaded_setting['user_code']) + + self.is_gpu_conversion_var.set(loaded_setting['is_gpu_conversion']) + self.is_normalization_var.set(loaded_setting['is_normalization']) + self.help_hints_var.set(loaded_setting['help_hints_var']) + + self.model_sample_mode_var.set(loaded_setting['model_sample_mode']) + self.model_sample_mode_duration_var.set(loaded_setting['model_sample_mode_duration']) + + +root = FakeRoot() +root.load_saved_settings(DEFAULT_DATA) \ No newline at end of file diff --git a/app.py b/app.py index b7b9ede96c1f1b67e0701480cd277d2d6212ca80..4dbcdf59f121093d857455078386d15ff0342aeb 100644 --- a/app.py +++ b/app.py @@ -1,2 +1,3 @@ import os + os.system("python webUI.py") \ No newline at end of file diff --git a/demucs/__init__.py b/demucs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5656d59e07f3fa33dd3bad1a0f9279ff4b8a6128 --- /dev/null +++ b/demucs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/demucs/__main__.py b/demucs/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..5de878f198029af74afa2db7e6c06b9b306bf99d --- /dev/null +++ b/demucs/__main__.py @@ -0,0 +1,272 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import sys +import time +from dataclasses import dataclass, field +from fractions import Fraction + +import torch as th +from torch import distributed, nn +from torch.nn.parallel.distributed import DistributedDataParallel + +from .augment import FlipChannels, FlipSign, Remix, Shift +from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks +from .model import Demucs +from .parser import get_name, get_parser +from .raw import Rawset +from .tasnet import ConvTasNet +from .test import evaluate +from .train import train_model, validate_model +from .utils import human_seconds, load_model, save_model, sizeof_fmt + + +@dataclass +class SavedState: + metrics: list = field(default_factory=list) + last_state: dict = None + best_state: dict = None + optimizer: dict = None + + +def main(): + parser = get_parser() + args = parser.parse_args() + name = get_name(parser, args) + print(f"Experiment {name}") + + if args.musdb is None and args.rank == 0: + print( + "You must provide the path to the MusDB dataset with the --musdb flag. " + "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.", + file=sys.stderr) + sys.exit(1) + + eval_folder = args.evals / name + eval_folder.mkdir(exist_ok=True, parents=True) + args.logs.mkdir(exist_ok=True) + metrics_path = args.logs / f"{name}.json" + eval_folder.mkdir(exist_ok=True, parents=True) + args.checkpoints.mkdir(exist_ok=True, parents=True) + args.models.mkdir(exist_ok=True, parents=True) + + if args.device is None: + device = "cpu" + if th.cuda.is_available(): + device = "cuda" + else: + device = args.device + + th.manual_seed(args.seed) + # Prevents too many threads to be started when running `museval` as it can be quite + # inefficient on NUMA architectures. + os.environ["OMP_NUM_THREADS"] = "1" + + if args.world_size > 1: + if device != "cuda" and args.rank == 0: + print("Error: distributed training is only available with cuda device", file=sys.stderr) + sys.exit(1) + th.cuda.set_device(args.rank % th.cuda.device_count()) + distributed.init_process_group(backend="nccl", + init_method="tcp://" + args.master, + rank=args.rank, + world_size=args.world_size) + + checkpoint = args.checkpoints / f"{name}.th" + checkpoint_tmp = args.checkpoints / f"{name}.th.tmp" + if args.restart and checkpoint.exists(): + checkpoint.unlink() + + if args.test: + args.epochs = 1 + args.repeat = 0 + model = load_model(args.models / args.test) + elif args.tasnet: + model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X) + else: + model = Demucs( + audio_channels=args.audio_channels, + channels=args.channels, + context=args.context, + depth=args.depth, + glu=args.glu, + growth=args.growth, + kernel_size=args.kernel_size, + lstm_layers=args.lstm_layers, + rescale=args.rescale, + rewrite=args.rewrite, + sources=4, + stride=args.conv_stride, + upsample=args.upsample, + samplerate=args.samplerate + ) + model.to(device) + if args.show: + print(model) + size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters())) + print(f"Model size {size}") + return + + optimizer = th.optim.Adam(model.parameters(), lr=args.lr) + + try: + saved = th.load(checkpoint, map_location='cpu') + except IOError: + saved = SavedState() + else: + model.load_state_dict(saved.last_state) + optimizer.load_state_dict(saved.optimizer) + + if args.save_model: + if args.rank == 0: + model.to("cpu") + model.load_state_dict(saved.best_state) + save_model(model, args.models / f"{name}.th") + return + + if args.rank == 0: + done = args.logs / f"{name}.done" + if done.exists(): + done.unlink() + + if args.augment: + augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride), + Remix(group_size=args.remix_group_size)).to(device) + else: + augment = Shift(args.data_stride) + + if args.mse: + criterion = nn.MSELoss() + else: + criterion = nn.L1Loss() + + # Setting number of samples so that all convolution windows are full. + # Prevents hard to debug mistake with the prediction being shifted compared + # to the input mixture. + samples = model.valid_length(args.samples) + print(f"Number of training samples adjusted to {samples}") + + if args.raw: + train_set = Rawset(args.raw / "train", + samples=samples + args.data_stride, + channels=args.audio_channels, + streams=[0, 1, 2, 3, 4], + stride=args.data_stride) + + valid_set = Rawset(args.raw / "valid", channels=args.audio_channels) + else: + if not args.metadata.is_file() and args.rank == 0: + build_musdb_metadata(args.metadata, args.musdb, args.workers) + if args.world_size > 1: + distributed.barrier() + metadata = json.load(open(args.metadata)) + duration = Fraction(samples + args.data_stride, args.samplerate) + stride = Fraction(args.data_stride, args.samplerate) + train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"), + metadata, + duration=duration, + stride=stride, + samplerate=args.samplerate, + channels=args.audio_channels) + valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"), + metadata, + samplerate=args.samplerate, + channels=args.audio_channels) + + best_loss = float("inf") + for epoch, metrics in enumerate(saved.metrics): + print(f"Epoch {epoch:03d}: " + f"train={metrics['train']:.8f} " + f"valid={metrics['valid']:.8f} " + f"best={metrics['best']:.4f} " + f"duration={human_seconds(metrics['duration'])}") + best_loss = metrics['best'] + + if args.world_size > 1: + dmodel = DistributedDataParallel(model, + device_ids=[th.cuda.current_device()], + output_device=th.cuda.current_device()) + else: + dmodel = model + + for epoch in range(len(saved.metrics), args.epochs): + begin = time.time() + model.train() + train_loss = train_model(epoch, + train_set, + dmodel, + criterion, + optimizer, + augment, + batch_size=args.batch_size, + device=device, + repeat=args.repeat, + seed=args.seed, + workers=args.workers, + world_size=args.world_size) + model.eval() + valid_loss = validate_model(epoch, + valid_set, + model, + criterion, + device=device, + rank=args.rank, + split=args.split_valid, + world_size=args.world_size) + + duration = time.time() - begin + if valid_loss < best_loss: + best_loss = valid_loss + saved.best_state = { + key: value.to("cpu").clone() + for key, value in model.state_dict().items() + } + saved.metrics.append({ + "train": train_loss, + "valid": valid_loss, + "best": best_loss, + "duration": duration + }) + if args.rank == 0: + json.dump(saved.metrics, open(metrics_path, "w")) + + saved.last_state = model.state_dict() + saved.optimizer = optimizer.state_dict() + if args.rank == 0 and not args.test: + th.save(saved, checkpoint_tmp) + checkpoint_tmp.rename(checkpoint) + + print(f"Epoch {epoch:03d}: " + f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} " + f"duration={human_seconds(duration)}") + + del dmodel + model.load_state_dict(saved.best_state) + if args.eval_cpu: + device = "cpu" + model.to(device) + model.eval() + evaluate(model, + args.musdb, + eval_folder, + rank=args.rank, + world_size=args.world_size, + device=device, + save=args.save, + split=args.split_valid, + shifts=args.shifts, + workers=args.eval_workers) + model.to("cpu") + save_model(model, args.models / f"{name}.th") + if args.rank == 0: + print("done") + done.write_text("done") + + +if __name__ == "__main__": + main() diff --git a/demucs/apply.py b/demucs/apply.py new file mode 100644 index 0000000000000000000000000000000000000000..5769376c645eaa02c85eac13867ac9ed285ddac5 --- /dev/null +++ b/demucs/apply.py @@ -0,0 +1,294 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Code to apply a model to a mix. It will handle chunking with overlaps and +inteprolation between chunks, as well as the "shift trick". +""" +from concurrent.futures import ThreadPoolExecutor +import random +import typing as tp +from multiprocessing import Process,Queue,Pipe + +import torch as th +from torch import nn +from torch.nn import functional as F +import tqdm +import tkinter as tk + +from .demucs import Demucs +from .hdemucs import HDemucs +from .utils import center_trim, DummyPoolExecutor + +Model = tp.Union[Demucs, HDemucs] + +progress_bar_num = 0 + +class BagOfModels(nn.Module): + def __init__(self, models: tp.List[Model], + weights: tp.Optional[tp.List[tp.List[float]]] = None, + segment: tp.Optional[float] = None): + """ + Represents a bag of models with specific weights. + You should call `apply_model` rather than calling directly the forward here for + optimal performance. + + Args: + models (list[nn.Module]): list of Demucs/HDemucs models. + weights (list[list[float]]): list of weights. If None, assumed to + be all ones, otherwise it should be a list of N list (N number of models), + each containing S floats (S number of sources). + segment (None or float): overrides the `segment` attribute of each model + (this is performed inplace, be careful if you reuse the models passed). + """ + + super().__init__() + assert len(models) > 0 + first = models[0] + for other in models: + assert other.sources == first.sources + assert other.samplerate == first.samplerate + assert other.audio_channels == first.audio_channels + if segment is not None: + other.segment = segment + + self.audio_channels = first.audio_channels + self.samplerate = first.samplerate + self.sources = first.sources + self.models = nn.ModuleList(models) + + if weights is None: + weights = [[1. for _ in first.sources] for _ in models] + else: + assert len(weights) == len(models) + for weight in weights: + assert len(weight) == len(first.sources) + self.weights = weights + + def forward(self, x): + raise NotImplementedError("Call `apply_model` on this.") + +class TensorChunk: + def __init__(self, tensor, offset=0, length=None): + total_length = tensor.shape[-1] + assert offset >= 0 + assert offset < total_length + + if length is None: + length = total_length - offset + else: + length = min(total_length - offset, length) + + if isinstance(tensor, TensorChunk): + self.tensor = tensor.tensor + self.offset = offset + tensor.offset + else: + self.tensor = tensor + self.offset = offset + self.length = length + self.device = tensor.device + + @property + def shape(self): + shape = list(self.tensor.shape) + shape[-1] = self.length + return shape + + def padded(self, target_length): + delta = target_length - self.length + total_length = self.tensor.shape[-1] + assert delta >= 0 + + start = self.offset - delta // 2 + end = start + target_length + + correct_start = max(0, start) + correct_end = min(total_length, end) + + pad_left = correct_start - start + pad_right = end - correct_end + + out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) + assert out.shape[-1] == target_length + return out + +def tensor_chunk(tensor_or_chunk): + if isinstance(tensor_or_chunk, TensorChunk): + return tensor_or_chunk + else: + assert isinstance(tensor_or_chunk, th.Tensor) + return TensorChunk(tensor_or_chunk) + +def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1., static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None): + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. + Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + device (torch.device, str, or None): if provided, device on which to + execute the computation, otherwise `mix.device` is assumed. + When `device` is different from `mix.device`, only local computations will + be on `device`, while the entire tracks will be stored on `mix.device`. + """ + + global fut_length + global bag_num + global prog_bar + + if device is None: + device = mix.device + else: + device = th.device(device) + if pool is None: + if num_workers > 0 and device.type == 'cpu': + pool = ThreadPoolExecutor(num_workers) + else: + pool = DummyPoolExecutor() + + kwargs = { + 'shifts': shifts, + 'split': split, + 'overlap': overlap, + 'transition_power': transition_power, + 'progress': progress, + 'device': device, + 'pool': pool, + 'set_progress_bar': set_progress_bar, + 'static_shifts': static_shifts, + } + + if isinstance(model, BagOfModels): + # Special treatment for bag of model. + # We explicitely apply multiple times `apply_model` so that the random shifts + # are different for each model. + + estimates = 0 + totals = [0] * len(model.sources) + bag_num = len(model.models) + fut_length = 0 + prog_bar = 0 + current_model = 0 #(bag_num + 1) + for sub_model, weight in zip(model.models, model.weights): + original_model_device = next(iter(sub_model.parameters())).device + sub_model.to(device) + fut_length += fut_length + current_model += 1 + out = apply_model(sub_model, mix, **kwargs) + sub_model.to(original_model_device) + for k, inst_weight in enumerate(weight): + out[:, k, :, :] *= inst_weight + totals[k] += inst_weight + estimates += out + del out + + for k in range(estimates.shape[1]): + estimates[:, k, :, :] /= totals[k] + return estimates + + model.to(device) + model.eval() + assert transition_power >= 1, "transition_power < 1 leads to weird behavior." + batch, channels, length = mix.shape + + if shifts: + kwargs['shifts'] = 0 + max_shift = int(0.5 * model.samplerate) + mix = tensor_chunk(mix) + padded_mix = mix.padded(length + 2 * max_shift) + out = 0 + for _ in range(shifts): + offset = random.randint(0, max_shift) + shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) + shifted_out = apply_model(model, shifted, **kwargs) + out += shifted_out[..., max_shift - offset:] + out /= shifts + return out + elif split: + kwargs['split'] = False + out = th.zeros(batch, len(model.sources), channels, length, device=mix.device) + sum_weight = th.zeros(length, device=mix.device) + segment = int(model.samplerate * model.segment) + stride = int((1 - overlap) * segment) + offsets = range(0, length, stride) + scale = float(format(stride / model.samplerate, ".2f")) + # We start from a triangle shaped weight, with maximal weight in the middle + # of the segment. Then we normalize and take to the power `transition_power`. + # Large values of transition power will lead to sharper transitions. + weight = th.cat([th.arange(1, segment // 2 + 1, device=device), + th.arange(segment - segment // 2, 0, -1, device=device)]) + assert len(weight) == segment + # If the overlap < 50%, this will translate to linear transition when + # transition_power is 1. + weight = (weight / weight.max())**transition_power + futures = [] + for offset in offsets: + chunk = TensorChunk(mix, offset, segment) + future = pool.submit(apply_model, model, chunk, **kwargs) + futures.append((future, offset)) + offset += segment + if progress: + futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds') + for future, offset in futures: + if set_progress_bar: + fut_length = (len(futures) * bag_num * static_shifts) + prog_bar += 1 + set_progress_bar(0.1, (0.8/fut_length*prog_bar)) + chunk_out = future.result() + chunk_length = chunk_out.shape[-1] + out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device) + sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device) + assert sum_weight.min() > 0 + out /= sum_weight + return out + else: + if hasattr(model, 'valid_length'): + valid_length = model.valid_length(length) + else: + valid_length = length + mix = tensor_chunk(mix) + padded_mix = mix.padded(valid_length).to(device) + with th.no_grad(): + out = model(padded_mix) + return center_trim(out, length) + +def demucs_segments(demucs_segment, demucs_model): + + if demucs_segment == 'Default': + segment = None + if isinstance(demucs_model, BagOfModels): + if segment is not None: + for sub in demucs_model.models: + sub.segment = segment + else: + if segment is not None: + sub.segment = segment + else: + try: + segment = int(demucs_segment) + if isinstance(demucs_model, BagOfModels): + if segment is not None: + for sub in demucs_model.models: + sub.segment = segment + else: + if segment is not None: + sub.segment = segment + except: + segment = None + if isinstance(demucs_model, BagOfModels): + if segment is not None: + for sub in demucs_model.models: + sub.segment = segment + else: + if segment is not None: + sub.segment = segment + + return demucs_model \ No newline at end of file diff --git a/demucs/demucs.py b/demucs/demucs.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c08e73d65de3031a1e1be545b68afd5554f7a5 --- /dev/null +++ b/demucs/demucs.py @@ -0,0 +1,459 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import typing as tp + +import julius +import torch +from torch import nn +from torch.nn import functional as F + +from .states import capture_init +from .utils import center_trim, unfold + + +class BLSTM(nn.Module): + """ + BiLSTM with same hidden units as input dim. + If `max_steps` is not None, input will be splitting in overlapping + chunks and the LSTM applied separately on each chunk. + """ + def __init__(self, dim, layers=1, max_steps=None, skip=False): + super().__init__() + assert max_steps is None or max_steps % 4 == 0 + self.max_steps = max_steps + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + self.skip = skip + + def forward(self, x): + B, C, T = x.shape + y = x + framed = False + if self.max_steps is not None and T > self.max_steps: + width = self.max_steps + stride = width // 2 + frames = unfold(x, width, stride) + nframes = frames.shape[2] + framed = True + x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) + + x = x.permute(2, 0, 1) + + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + if framed: + out = [] + frames = x.reshape(B, -1, C, width) + limit = stride // 2 + for k in range(nframes): + if k == 0: + out.append(frames[:, k, :, :-limit]) + elif k == nframes - 1: + out.append(frames[:, k, :, limit:]) + else: + out.append(frames[:, k, :, limit:-limit]) + out = torch.cat(out, -1) + out = out[..., :T] + x = out + if self.skip: + x = x + y + return x + + +def rescale_conv(conv, reference): + """Rescale initial weight scale. It is unclear why it helps but it certainly does. + """ + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): + rescale_conv(sub, reference) + + +class LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonaly residual outputs close to 0 initially, then learnt. + """ + def __init__(self, channels: int, init: float = 0): + super().__init__() + self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) + self.scale.data[:] = init + + def forward(self, x): + return self.scale[:, None] * x + + +class DConv(nn.Module): + """ + New residual branches in each encoder layer. + This alternates dilated convolutions, potentially with LSTMs and attention. + Also before entering each residual branch, dimension is projected on a smaller subspace, + e.g. of dim `channels // compress`. + """ + def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, + norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, + kernel=3, dilate=True): + """ + Args: + channels: input/output channels for residual branch. + compress: amount of channel compression inside the branch. + depth: number of layers in the residual branch. Each layer has its own + projection, and potentially LSTM and attention. + init: initial scale for LayerNorm. + norm: use GroupNorm. + attn: use LocalAttention. + heads: number of heads for the LocalAttention. + ndecay: number of decay controls in the LocalAttention. + lstm: use LSTM. + gelu: Use GELU activation. + kernel: kernel size for the (dilated) convolutions. + dilate: if true, use dilation, increasing with the depth. + """ + + super().__init__() + assert kernel % 2 == 1 + self.channels = channels + self.compress = compress + self.depth = abs(depth) + dilate = depth > 0 + + norm_fn: tp.Callable[[int], nn.Module] + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(1, d) # noqa + + hidden = int(channels / compress) + + act: tp.Type[nn.Module] + if gelu: + act = nn.GELU + else: + act = nn.ReLU + + self.layers = nn.ModuleList([]) + for d in range(self.depth): + dilation = 2 ** d if dilate else 1 + padding = dilation * (kernel // 2) + mods = [ + nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), + norm_fn(hidden), act(), + nn.Conv1d(hidden, 2 * channels, 1), + norm_fn(2 * channels), nn.GLU(1), + LayerScale(channels, init), + ] + if attn: + mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay)) + if lstm: + mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True)) + layer = nn.Sequential(*mods) + self.layers.append(layer) + + def forward(self, x): + for layer in self.layers: + x = x + layer(x) + return x + + +class LocalState(nn.Module): + """Local state allows to have attention based only on data (no positional embedding), + but while setting a constraint on the time window (e.g. decaying penalty term). + + Also a failed experiments with trying to provide some frequency based attention. + """ + def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): + super().__init__() + assert channels % heads == 0, (channels, heads) + self.heads = heads + self.nfreqs = nfreqs + self.ndecay = ndecay + self.content = nn.Conv1d(channels, channels, 1) + self.query = nn.Conv1d(channels, channels, 1) + self.key = nn.Conv1d(channels, channels, 1) + if nfreqs: + self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) + if ndecay: + self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) + # Initialize decay close to zero (there is a sigmoid), for maximum initial window. + self.query_decay.weight.data *= 0.01 + assert self.query_decay.bias is not None # stupid type checker + self.query_decay.bias.data[:] = -2 + self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) + + def forward(self, x): + B, C, T = x.shape + heads = self.heads + indexes = torch.arange(T, device=x.device, dtype=x.dtype) + # left index are keys, right index are queries + delta = indexes[:, None] - indexes[None, :] + + queries = self.query(x).view(B, heads, -1, T) + keys = self.key(x).view(B, heads, -1, T) + # t are keys, s are queries + dots = torch.einsum("bhct,bhcs->bhts", keys, queries) + dots /= keys.shape[2]**0.5 + if self.nfreqs: + periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) + freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) + freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5 + dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) + if self.ndecay: + decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) + decay_q = self.query_decay(x).view(B, heads, -1, T) + decay_q = torch.sigmoid(decay_q) / 2 + decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5 + dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) + + # Kill self reference. + dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) + weights = torch.softmax(dots, dim=2) + + content = self.content(x).view(B, heads, -1, T) + result = torch.einsum("bhts,bhct->bhcs", weights, content) + if self.nfreqs: + time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel) + result = torch.cat([result, time_sig], 2) + result = result.reshape(B, -1, T) + return x + self.proj(result) + + +class Demucs(nn.Module): + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=64, + growth=2., + # Main structure + depth=6, + rewrite=True, + lstm_layers=0, + # Convolutions + kernel_size=8, + stride=4, + context=1, + # Activations + gelu=True, + glu=True, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Pre/post processing + normalize=True, + resample=True, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + """ + Args: + sources (list[str]): list of source names + audio_channels (int): stereo or mono + channels (int): first convolution channels + depth (int): number of encoder/decoder layers + growth (float): multiply (resp divide) number of channels by that + for each layer of the encoder (resp decoder) + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated + by default, as this is now replaced by the smaller and faster small LSTMs + in the DConv branches. + kernel_size (int): kernel size for convolutions + stride (int): stride for convolutions + context (int): kernel size of the convolution in the + decoder before the transposed convolution. If > 1, + will provide some context from neighboring time steps. + gelu: use GELU activation function. + glu (bool): use glu instead of ReLU for the 1x1 rewrite conv. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + normalize (bool): normalizes the input audio on the fly, and scales back + the output by the same amount. + resample (bool): upsample x2 the input and downsample /2 the output. + rescale (int): rescale initial weights of convolutions + to get their standard deviation closer to `rescale`. + samplerate (int): stored as meta information for easing + future evaluations of the model. + segment (float): duration of the chunks of audio to ideally evaluate the model on. + This is used by `demucs.apply.apply_model`. + """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.resample = resample + self.channels = channels + self.normalize = normalize + self.samplerate = samplerate + self.segment = segment + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + self.skip_scales = nn.ModuleList() + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + if gelu: + act2 = nn.GELU + else: + act2 = nn.ReLU + + in_channels = audio_channels + padding = 0 + for index in range(depth): + norm_fn = lambda d: nn.Identity() # noqa + if index >= norm_starts: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + + encode = [] + encode += [ + nn.Conv1d(in_channels, channels, kernel_size, stride), + norm_fn(channels), + act2(), + ] + attn = index >= dconv_attn + lstm = index >= dconv_lstm + if dconv_mode & 1: + encode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + if rewrite: + encode += [ + nn.Conv1d(channels, ch_scale * channels, 1), + norm_fn(ch_scale * channels), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + out_channels = len(self.sources) * audio_channels + if rewrite: + decode += [ + nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), + norm_fn(ch_scale * channels), activation] + if dconv_mode & 2: + decode += [DConv(channels, depth=dconv_depth, init=dconv_init, + compress=dconv_comp, attn=attn, lstm=lstm)] + decode += [nn.ConvTranspose1d(channels, out_channels, + kernel_size, stride, padding=padding)] + if index > 0: + decode += [norm_fn(out_channels), act2()] + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolution, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + Note that input are automatically padded if necessary to ensure that the output + has the same length as the input. + """ + if self.resample: + length *= 2 + + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + + for idx in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + + if self.resample: + length = math.ceil(length / 2) + return int(length) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + mean = mono.mean(dim=-1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + x = (x - mean) / (1e-5 + std) + else: + mean = 0 + std = 1 + + delta = self.valid_length(length) - length + x = F.pad(x, (delta // 2, delta - delta // 2)) + + if self.resample: + x = julius.resample_frac(x, 1, 2) + + saved = [] + for encode in self.encoder: + x = encode(x) + saved.append(x) + + if self.lstm: + x = self.lstm(x) + + for decode in self.decoder: + skip = saved.pop(-1) + skip = center_trim(skip, x) + x = decode(x + skip) + + if self.resample: + x = julius.resample_frac(x, 2, 1) + x = x * std + mean + x = center_trim(x, length) + x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) + return x + + def load_state_dict(self, state, strict=True): + # fix a mismatch with previous generation Demucs models. + for idx in range(self.depth): + for a in ['encoder', 'decoder']: + for b in ['bias', 'weight']: + new = f'{a}.{idx}.3.{b}' + old = f'{a}.{idx}.2.{b}' + if old in state and new not in state: + state[new] = state.pop(old) + super().load_state_dict(state, strict=strict) diff --git a/demucs/filtering.py b/demucs/filtering.py new file mode 100644 index 0000000000000000000000000000000000000000..08a2c17deae2dd4ace033eaa14619821e2a70b03 --- /dev/null +++ b/demucs/filtering.py @@ -0,0 +1,502 @@ +from typing import Optional +import torch +import torch.nn as nn +from torch import Tensor +from torch.utils.data import DataLoader + +def atan2(y, x): + r"""Element-wise arctangent function of y/x. + Returns a new tensor with signed angles in radians. + It is an alternative implementation of torch.atan2 + + Args: + y (Tensor): First input tensor + x (Tensor): Second input tensor [shape=y.shape] + + Returns: + Tensor: [shape=y.shape]. + """ + pi = 2 * torch.asin(torch.tensor(1.0)) + x += ((x == 0) & (y == 0)) * 1.0 + out = torch.atan(y / x) + out += ((y >= 0) & (x < 0)) * pi + out -= ((y < 0) & (x < 0)) * pi + out *= 1 - ((y > 0) & (x == 0)) * 1.0 + out += ((y > 0) & (x == 0)) * (pi / 2) + out *= 1 - ((y < 0) & (x == 0)) * 1.0 + out += ((y < 0) & (x == 0)) * (-pi / 2) + return out + + +# Define basic complex operations on torch.Tensor objects whose last dimension +# consists in the concatenation of the real and imaginary parts. + + +def _norm(x: torch.Tensor) -> torch.Tensor: + r"""Computes the norm value of a torch Tensor, assuming that it + comes as real and imaginary part in its last dimension. + + Args: + x (Tensor): Input Tensor of shape [shape=(..., 2)] + + Returns: + Tensor: shape as x excluding the last dimension. + """ + return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2 + + +def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """Element-wise multiplication of two complex Tensors described + through their real and imaginary parts. + The result is added to the `out` tensor""" + + # check `out` and allocate it if needed + target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)]) + if out is None or out.shape != target_shape: + out = torch.zeros(target_shape, dtype=a.dtype, device=a.device) + if out is a: + real_a = a[..., 0] + out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1]) + out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0]) + else: + out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]) + out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]) + return out + + +def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """Element-wise multiplication of two complex Tensors described + through their real and imaginary parts + can work in place in case out is a only""" + target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)]) + if out is None or out.shape != target_shape: + out = torch.zeros(target_shape, dtype=a.dtype, device=a.device) + if out is a: + real_a = a[..., 0] + out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1] + out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0] + else: + out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1] + out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0] + return out + + +def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """Element-wise multiplicative inverse of a Tensor with complex + entries described through their real and imaginary parts. + can work in place in case out is z""" + ez = _norm(z) + if out is None or out.shape != z.shape: + out = torch.zeros_like(z) + out[..., 0] = z[..., 0] / ez + out[..., 1] = -z[..., 1] / ez + return out + + +def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """Element-wise complex conjugate of a Tensor with complex entries + described through their real and imaginary parts. + can work in place in case out is z""" + if out is None or out.shape != z.shape: + out = torch.zeros_like(z) + out[..., 0] = z[..., 0] + out[..., 1] = -z[..., 1] + return out + + +def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Invert 1x1 or 2x2 matrices + + Will generate errors if the matrices are singular: user must handle this + through his own regularization schemes. + + Args: + M (Tensor): [shape=(..., nb_channels, nb_channels, 2)] + matrices to invert: must be square along dimensions -3 and -2 + + Returns: + invM (Tensor): [shape=M.shape] + inverses of M + """ + nb_channels = M.shape[-2] + + if out is None or out.shape != M.shape: + out = torch.empty_like(M) + + if nb_channels == 1: + # scalar case + out = _inv(M, out) + elif nb_channels == 2: + # two channels case: analytical expression + + # first compute the determinent + det = _mul(M[..., 0, 0, :], M[..., 1, 1, :]) + det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :]) + # invert it + invDet = _inv(det) + + # then fill out the matrix with the inverse + out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :]) + out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :]) + out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :]) + out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :]) + else: + raise Exception("Only 2 channels are supported for the torch version.") + return out + + +# Now define the signal-processing low-level functions used by the Separator + + +def expectation_maximization( + y: torch.Tensor, + x: torch.Tensor, + iterations: int = 2, + eps: float = 1e-10, + batch_size: int = 200, +): + r"""Expectation maximization algorithm, for refining source separation + estimates. + + This algorithm allows to make source separation results better by + enforcing multichannel consistency for the estimates. This usually means + a better perceptual quality in terms of spatial artifacts. + + The implementation follows the details presented in [1]_, taking + inspiration from the original EM algorithm proposed in [2]_ and its + weighted refinement proposed in [3]_, [4]_. + It works by iteratively: + + * Re-estimate source parameters (power spectral densities and spatial + covariance matrices) through :func:`get_local_gaussian_model`. + + * Separate again the mixture with the new parameters by first computing + the new modelled mixture covariance matrices with :func:`get_mix_model`, + prepare the Wiener filters through :func:`wiener_gain` and apply them + with :func:`apply_filter``. + + References + ---------- + .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and + N. Takahashi and Y. Mitsufuji, "Improving music source separation based + on deep neural networks through data augmentation and network + blending." 2017 IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP). IEEE, 2017. + + .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined + reverberant audio source separation using a full-rank spatial + covariance model." IEEE Transactions on Audio, Speech, and Language + Processing 18.7 (2010): 1830-1840. + + .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source + separation with deep neural networks." IEEE/ACM Transactions on Audio, + Speech, and Language Processing 24.9 (2016): 1652-1664. + + .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music + separation with deep neural networks." 2016 24th European Signal + Processing Conference (EUSIPCO). IEEE, 2016. + + .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for + source separation." IEEE Transactions on Signal Processing + 62.16 (2014): 4298-4310. + + Args: + y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)] + initial estimates for the sources + x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)] + complex STFT of the mixture signal + iterations (int): [scalar] + number of iterations for the EM algorithm. + eps (float or None): [scalar] + The epsilon value to use for regularization and filters. + + Returns: + y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)] + estimated sources after iterations + v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)] + estimated power spectral densities + R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)] + estimated spatial covariance matrices + + Notes: + * You need an initial estimate for the sources to apply this + algorithm. This is precisely what the :func:`wiener` function does. + * This algorithm *is not* an implementation of the "exact" EM + proposed in [1]_. In particular, it does compute the posterior + covariance matrices the same (exact) way. Instead, it uses the + simplified approximate scheme initially proposed in [5]_ and further + refined in [3]_, [4]_, that boils down to just take the empirical + covariance of the recent source estimates, followed by a weighted + average for the update of the spatial covariance matrix. It has been + empirically demonstrated that this simplified algorithm is more + robust for music separation. + + Warning: + It is *very* important to make sure `x.dtype` is `torch.float64` + if you want double precision, because this function will **not** + do such conversion for you from `torch.complex32`, in case you want the + smaller RAM usage on purpose. + + It is usually always better in terms of quality to have double + precision, by e.g. calling :func:`expectation_maximization` + with ``x.to(torch.float64)``. + """ + # dimensions + (nb_frames, nb_bins, nb_channels) = x.shape[:-1] + nb_sources = y.shape[-1] + + regularization = torch.cat( + ( + torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], + torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device), + ), + dim=2, + ) + regularization = torch.sqrt(torch.as_tensor(eps)) * ( + regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)) + ) + + # allocate the spatial covariance matrices + R = [ + torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) + for j in range(nb_sources) + ] + weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device) + + v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device) + for it in range(iterations): + # constructing the mixture covariance matrix. Doing it with a loop + # to avoid storing anytime in RAM the whole 6D tensor + + # update the PSD as the average spectrogram over channels + v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2) + + # update spatial covariance matrices (weighted update) + for j in range(nb_sources): + R[j] = torch.tensor(0.0, device=x.device) + weight = torch.tensor(eps, device=x.device) + pos: int = 0 + batch_size = batch_size if batch_size else nb_frames + while pos < nb_frames: + t = torch.arange(pos, min(nb_frames, pos + batch_size)) + pos = int(t[-1]) + 1 + + R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0) + weight = weight + torch.sum(v[t, ..., j], dim=0) + R[j] = R[j] / weight[..., None, None, None] + weight = torch.zeros_like(weight) + + # cloning y if we track gradient, because we're going to update it + if y.requires_grad: + y = y.clone() + + pos = 0 + while pos < nb_frames: + t = torch.arange(pos, min(nb_frames, pos + batch_size)) + pos = int(t[-1]) + 1 + + y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype) + + # compute mix covariance matrix + Cxx = regularization + for j in range(nb_sources): + Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone()) + + # invert it + inv_Cxx = _invert(Cxx) + + # separate the sources + for j in range(nb_sources): + + # create a wiener gain for this source + gain = torch.zeros_like(inv_Cxx) + + # computes multichannel Wiener gain as v_j R_j inv_Cxx + indices = torch.cartesian_prod( + torch.arange(nb_channels), + torch.arange(nb_channels), + torch.arange(nb_channels), + ) + for index in indices: + gain[:, :, index[0], index[1], :] = _mul_add( + R[j][None, :, index[0], index[2], :].clone(), + inv_Cxx[:, :, index[2], index[1], :], + gain[:, :, index[0], index[1], :], + ) + gain = gain * v[t, ..., None, None, None, j] + + # apply it to the mixture + for i in range(nb_channels): + y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j]) + + return y, v, R + + +def wiener( + targets_spectrograms: torch.Tensor, + mix_stft: torch.Tensor, + iterations: int = 1, + softmask: bool = False, + residual: bool = False, + scale_factor: float = 10.0, + eps: float = 1e-10, +): + """Wiener-based separation for multichannel audio. + + The method uses the (possibly multichannel) spectrograms of the + sources to separate the (complex) Short Term Fourier Transform of the + mix. Separation is done in a sequential way by: + + * Getting an initial estimate. This can be done in two ways: either by + directly using the spectrograms with the mixture phase, or + by using a softmasking strategy. This initial phase is controlled + by the `softmask` flag. + + * If required, adding an additional residual target as the mix minus + all targets. + + * Refinining these initial estimates through a call to + :func:`expectation_maximization` if the number of iterations is nonzero. + + This implementation also allows to specify the epsilon value used for + regularization. It is based on [1]_, [2]_, [3]_, [4]_. + + References + ---------- + .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and + N. Takahashi and Y. Mitsufuji, "Improving music source separation based + on deep neural networks through data augmentation and network + blending." 2017 IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP). IEEE, 2017. + + .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source + separation with deep neural networks." IEEE/ACM Transactions on Audio, + Speech, and Language Processing 24.9 (2016): 1652-1664. + + .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music + separation with deep neural networks." 2016 24th European Signal + Processing Conference (EUSIPCO). IEEE, 2016. + + .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for + source separation." IEEE Transactions on Signal Processing + 62.16 (2014): 4298-4310. + + Args: + targets_spectrograms (Tensor): spectrograms of the sources + [shape=(nb_frames, nb_bins, nb_channels, nb_sources)]. + This is a nonnegative tensor that is + usually the output of the actual separation method of the user. The + spectrograms may be mono, but they need to be 4-dimensional in all + cases. + mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)] + STFT of the mixture signal. + iterations (int): [scalar] + number of iterations for the EM algorithm + softmask (bool): Describes how the initial estimates are obtained. + * if `False`, then the mixture phase will directly be used with the + spectrogram as initial estimates. + * if `True`, initial estimates are obtained by multiplying the + complex mix element-wise with the ratio of each target spectrogram + with the sum of them all. This strategy is better if the model are + not really good, and worse otherwise. + residual (bool): if `True`, an additional target is created, which is + equal to the mixture minus the other targets, before application of + expectation maximization + eps (float): Epsilon value to use for computing the separations. + This is used whenever division with a model energy is + performed, i.e. when softmasking and when iterating the EM. + It can be understood as the energy of the additional white noise + that is taken out when separating. + + Returns: + Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources) + STFT of estimated sources + + Notes: + * Be careful that you need *magnitude spectrogram estimates* for the + case `softmask==False`. + * `softmask=False` is recommended + * The epsilon value will have a huge impact on performance. If it's + large, only the parts of the signal with a significant energy will + be kept in the sources. This epsilon then directly controls the + energy of the reconstruction error. + + Warning: + As in :func:`expectation_maximization`, we recommend converting the + mixture `x` to double precision `torch.float64` *before* calling + :func:`wiener`. + """ + if softmask: + # if we use softmask, we compute the ratio mask for all targets and + # multiply by the mix stft + y = ( + mix_stft[..., None] + * ( + targets_spectrograms + / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)) + )[..., None, :] + ) + else: + # otherwise, we just multiply the targets spectrograms with mix phase + # we tacitly assume that we have magnitude estimates. + angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None] + nb_sources = targets_spectrograms.shape[-1] + y = torch.zeros( + mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device + ) + y[..., 0, :] = targets_spectrograms * torch.cos(angle) + y[..., 1, :] = targets_spectrograms * torch.sin(angle) + + if residual: + # if required, adding an additional target as the mix minus + # available targets + y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1) + + if iterations == 0: + return y + + # we need to refine the estimates. Scales down the estimates for + # numerical stability + max_abs = torch.max( + torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), + torch.sqrt(_norm(mix_stft)).max() / scale_factor, + ) + + mix_stft = mix_stft / max_abs + y = y / max_abs + + # call expectation maximization + y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0] + + # scale estimates up again + y = y * max_abs + return y + + +def _covariance(y_j): + """ + Compute the empirical covariance for a source. + + Args: + y_j (Tensor): complex stft of the source. + [shape=(nb_frames, nb_bins, nb_channels, 2)]. + + Returns: + Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)] + just y_j * conj(y_j.T): empirical covariance for each TF bin. + """ + (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1] + Cj = torch.zeros( + (nb_frames, nb_bins, nb_channels, nb_channels, 2), + dtype=y_j.dtype, + device=y_j.device, + ) + indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels)) + for index in indices: + Cj[:, :, index[0], index[1], :] = _mul_add( + y_j[:, :, index[0], :], + _conj(y_j[:, :, index[1], :]), + Cj[:, :, index[0], index[1], :], + ) + return Cj diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py new file mode 100644 index 0000000000000000000000000000000000000000..d776d556479d6266d0fff45dca8ca0c34cf23c3e --- /dev/null +++ b/demucs/hdemucs.py @@ -0,0 +1,782 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +This code contains the spectrogram and Hybrid version of Demucs. +""" +from copy import deepcopy +import math +import typing as tp +import torch +from torch import nn +from torch.nn import functional as F +from .filtering import wiener +from .demucs import DConv, rescale_module +from .states import capture_init +from .spec import spectro, ispectro + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen.""" + x0 = x + length = x.shape[-1] + padding_left, padding_right = paddings + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + if length <= max_pad: + extra_pad = max_pad - length + 1 + extra_pad_right = min(padding_right, extra_pad) + extra_pad_left = extra_pad - extra_pad_right + paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right) + x = F.pad(x, (extra_pad_left, extra_pad_right)) + out = F.pad(x, paddings, mode, value) + assert out.shape[-1] == length + padding_left + padding_right + assert (out[..., padding_left: padding_left + length] == x0).all() + return out + +class ScaledEmbedding(nn.Module): + """ + Boost learning rate for embeddings (with `scale`). + Also, can make embeddings continuous with `smooth`. + """ + def __init__(self, num_embeddings: int, embedding_dim: int, + scale: float = 10., smooth=False): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + if smooth: + weight = torch.cumsum(self.embedding.weight.data, dim=0) + # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that. + weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None] + self.embedding.weight.data[:] = weight + self.embedding.weight.data /= scale + self.scale = scale + + @property + def weight(self): + return self.embedding.weight * self.scale + + def forward(self, x): + out = self.embedding(x) * self.scale + return out + + +class HEncLayer(nn.Module): + def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, + rewrite=True): + """Encoder layer. This used both by the time and the frequency branch. + + Args: + chin: number of input channels. + chout: number of output channels. + norm_groups: number of groups for group norm. + empty: used to make a layer with just the first conv. this is used + before merging the time and freq. branches. + freq: this is acting on frequencies. + dconv: insert DConv residual branches. + norm: use GroupNorm. + context: context size for the 1x1 conv. + dconv_kw: list of kwargs for the DConv class. + pad: pad the input. Padding is done so that the output size is + always the input size / stride. + rewrite: add 1x1 conv at the end of the layer. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + klass = nn.Conv1d + self.freq = freq + self.kernel_size = kernel_size + self.stride = stride + self.empty = empty + self.norm = norm + self.pad = pad + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + pad = [pad, 0] + klass = nn.Conv2d + self.conv = klass(chin, chout, kernel_size, stride, pad) + if self.empty: + return + self.norm1 = norm_fn(chout) + self.rewrite = None + if rewrite: + self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context) + self.norm2 = norm_fn(2 * chout) + + self.dconv = None + if dconv: + self.dconv = DConv(chout, **dconv_kw) + + def forward(self, x, inject=None): + """ + `inject` is used to inject the result from the time branch into the frequency branch, + when both have the same stride. + """ + if not self.freq and x.dim() == 4: + B, C, Fr, T = x.shape + x = x.view(B, -1, T) + + if not self.freq: + le = x.shape[-1] + if not le % self.stride == 0: + x = F.pad(x, (0, self.stride - (le % self.stride))) + y = self.conv(x) + if self.empty: + return y + if inject is not None: + assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape) + if inject.dim() == 3 and y.dim() == 4: + inject = inject[:, :, None] + y = y + inject + y = F.gelu(self.norm1(y)) + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + if self.rewrite: + z = self.norm2(self.rewrite(y)) + z = F.glu(z, dim=1) + else: + z = y + return z + + +class MultiWrap(nn.Module): + """ + Takes one layer and replicate it N times. each replica will act + on a frequency band. All is done so that if the N replica have the same weights, + then this is exactly equivalent to applying the original module on all frequencies. + + This is a bit over-engineered to avoid edge artifacts when splitting + the frequency bands, but it is possible the naive implementation would work as well... + """ + def __init__(self, layer, split_ratios): + """ + Args: + layer: module to clone, must be either HEncLayer or HDecLayer. + split_ratios: list of float indicating which ratio to keep for each band. + """ + super().__init__() + self.split_ratios = split_ratios + self.layers = nn.ModuleList() + self.conv = isinstance(layer, HEncLayer) + assert not layer.norm + assert layer.freq + assert layer.pad + if not self.conv: + assert not layer.context_freq + for k in range(len(split_ratios) + 1): + lay = deepcopy(layer) + if self.conv: + lay.conv.padding = (0, 0) + else: + lay.pad = False + for m in lay.modules(): + if hasattr(m, 'reset_parameters'): + m.reset_parameters() + self.layers.append(lay) + + def forward(self, x, skip=None, length=None): + B, C, Fr, T = x.shape + + ratios = list(self.split_ratios) + [1] + start = 0 + outs = [] + for ratio, layer in zip(ratios, self.layers): + if self.conv: + pad = layer.kernel_size // 4 + if ratio == 1: + limit = Fr + frames = -1 + else: + limit = int(round(Fr * ratio)) + le = limit - start + if start == 0: + le += pad + frames = round((le - layer.kernel_size) / layer.stride + 1) + limit = start + (frames - 1) * layer.stride + layer.kernel_size + if start == 0: + limit -= pad + assert limit - start > 0, (limit, start) + assert limit <= Fr, (limit, Fr) + y = x[:, :, start:limit, :] + if start == 0: + y = F.pad(y, (0, 0, pad, 0)) + if ratio == 1: + y = F.pad(y, (0, 0, 0, pad)) + outs.append(layer(y)) + start = limit - layer.kernel_size + layer.stride + else: + if ratio == 1: + limit = Fr + else: + limit = int(round(Fr * ratio)) + last = layer.last + layer.last = True + + y = x[:, :, start:limit] + s = skip[:, :, start:limit] + out, _ = layer(y, s, None) + if outs: + outs[-1][:, :, -layer.stride:] += ( + out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)) + out = out[:, :, layer.stride:] + if ratio == 1: + out = out[:, :, :-layer.stride // 2, :] + if start == 0: + out = out[:, :, layer.stride // 2:, :] + outs.append(out) + layer.last = last + start = limit + out = torch.cat(outs, dim=2) + if not self.conv and not last: + out = F.gelu(out) + if self.conv: + return out + else: + return out, None + + +class HDecLayer(nn.Module): + def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, + freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, + context_freq=True, rewrite=True): + """ + Same as HEncLayer but for decoder. See `HEncLayer` for documentation. + """ + super().__init__() + norm_fn = lambda d: nn.Identity() # noqa + if norm: + norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa + if pad: + pad = kernel_size // 4 + else: + pad = 0 + self.pad = pad + self.last = last + self.freq = freq + self.chin = chin + self.empty = empty + self.stride = stride + self.kernel_size = kernel_size + self.norm = norm + self.context_freq = context_freq + klass = nn.Conv1d + klass_tr = nn.ConvTranspose1d + if freq: + kernel_size = [kernel_size, 1] + stride = [stride, 1] + klass = nn.Conv2d + klass_tr = nn.ConvTranspose2d + self.conv_tr = klass_tr(chin, chout, kernel_size, stride) + self.norm2 = norm_fn(chout) + if self.empty: + return + self.rewrite = None + if rewrite: + if context_freq: + self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context) + else: + self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, + [0, context]) + self.norm1 = norm_fn(2 * chin) + + self.dconv = None + if dconv: + self.dconv = DConv(chin, **dconv_kw) + + def forward(self, x, skip, length): + if self.freq and x.dim() == 3: + B, C, T = x.shape + x = x.view(B, self.chin, -1, T) + + if not self.empty: + x = x + skip + + if self.rewrite: + y = F.glu(self.norm1(self.rewrite(x)), dim=1) + else: + y = x + if self.dconv: + if self.freq: + B, C, Fr, T = y.shape + y = y.permute(0, 2, 1, 3).reshape(-1, C, T) + y = self.dconv(y) + if self.freq: + y = y.view(B, Fr, C, T).permute(0, 2, 1, 3) + else: + y = x + assert skip is None + z = self.norm2(self.conv_tr(y)) + if self.freq: + if self.pad: + z = z[..., self.pad:-self.pad, :] + else: + z = z[..., self.pad:self.pad + length] + assert z.shape[-1] == length, (z.shape[-1], length) + if not self.last: + z = F.gelu(z) + return z, y + + +class HDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + @capture_init + def __init__(self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=6, + rewrite=True, + hybrid=True, + hybrid_old=False, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=2, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=4, + dconv_attn=4, + dconv_lstm=4, + dconv_init=1e-4, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=4 * 10): + + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only. + hybrid_old: some models trained for MDX had a padding bug. This replicates + this bug to avoid retraining them. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + rescale: weight recaling trick + + """ + super().__init__() + + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.channels = channels + self.samplerate = samplerate + self.segment = segment + + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + self.hybrid = hybrid + self.hybrid_old = hybrid_old + if hybrid_old: + assert hybrid, "hybrid_old must come with hybrid=True" + if hybrid: + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + if hybrid: + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + lstm = index >= dconv_lstm + attn = index >= dconv_attn + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + 'kernel_size': ker, + 'stride': stri, + 'freq': freq, + 'pad': pad, + 'norm': norm, + 'rewrite': rewrite, + 'norm_groups': norm_groups, + 'dconv_kw': { + 'lstm': lstm, + 'attn': attn, + 'depth': dconv_depth, + 'compress': dconv_comp, + 'init': dconv_init, + 'gelu': True, + } + } + kwt = dict(kw) + kwt['freq'] = 0 + kwt['kernel_size'] = kernel_size + kwt['stride'] = stride + kwt['pad'] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec['context_freq'] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer(chin_z, chout_z, + dconv=dconv_mode & 1, context=context_enc, **kw) + if hybrid and freq: + tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, + empty=last_freq, **kwt) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, + last=index == 0, context=context, **kw_dec) + if multi: + dec = MultiWrap(dec, multi_freqs) + if hybrid and freq: + tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, + last=index == 0, context=context, **kwt) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + if self.hybrid: + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + if not self.hybrid_old: + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect') + else: + x = pad1d(x, (pad, pad + le * hl - x.shape[-1])) + + z = spectro(x, nfft, hl)[..., :-1, :] + if self.hybrid: + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2:2+le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4 ** scale) + z = F.pad(z, (0, 0, 0, 1)) + if self.hybrid: + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + if not self.hybrid_old: + le = hl * int(math.ceil(length / hl)) + 2 * pad + else: + le = hl * int(math.ceil(length / hl)) + x = ispectro(z, hl, length=le) + if not self.hybrid_old: + x = x[..., pad:pad + length] + else: + x = x[..., :length] + else: + x = ispectro(z, hl, length) + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], mix_stft[sample, frame], niters, + residual=residual) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def forward(self, mix): + x = mix + length = x.shape[-1] + + z = self._spec(mix) + mag = self._magnitude(z) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + if self.hybrid: + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if self.hybrid and idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + + x = torch.zeros_like(x) + if self.hybrid: + xt = torch.zeros_like(x) + # initialize everything to zero (signal will go through u-net skips). + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + if self.hybrid: + offset = self.depth - len(self.tdecoder) + if self.hybrid and idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + zout = self._mask(z, x) + x = self._ispec(zout, length) + + if self.hybrid: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + return x + + diff --git a/demucs/htdemucs.py b/demucs/htdemucs.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa466bbe353e9f2aa18c18c66abeebfc2b813dd --- /dev/null +++ b/demucs/htdemucs.py @@ -0,0 +1,648 @@ +# Copyright (c) Meta, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# First author is Simon Rouard. +""" +This code contains the spectrogram and Hybrid version of Demucs. +""" +import math + +from .filtering import wiener +import torch +from torch import nn +from torch.nn import functional as F +from fractions import Fraction +from einops import rearrange + +from .transformer import CrossTransformerEncoder + +from .demucs import rescale_module +from .states import capture_init +from .spec import spectro, ispectro +from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer + + +class HTDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + + @capture_init + def __init__( + self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=4, + rewrite=True, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=3, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=8, + dconv_init=1e-3, + # Before the Transformer + bottom_channels=0, + # Transformer + t_layers=5, + t_emb="sin", + t_hidden_scale=4.0, + t_heads=8, + t_dropout=0.0, + t_max_positions=10000, + t_norm_in=True, + t_norm_in_group=False, + t_group_norm=False, + t_norm_first=True, + t_norm_out=True, + t_max_period=10000.0, + t_weight_decay=0.0, + t_lr=None, + t_layer_scale=True, + t_gelu=True, + t_weight_pos_embed=1.0, + t_sin_random_shift=0, + t_cape_mean_normalize=True, + t_cape_augment=True, + t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], + t_sparse_self_attn=False, + t_sparse_cross_attn=False, + t_mask_type="diag", + t_mask_random_seed=42, + t_sparse_attn_window=500, + t_global_window=100, + t_sparsity=0.95, + t_auto_sparsity=False, + # ------ Particuliar parameters + t_cross_first=False, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=10, + use_train_segment=True, + ): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the + transformer in order to change the number of channels + t_layers: number of layers in each branch (waveform and spec) of the transformer + t_emb: "sin", "cape" or "scaled" + t_hidden_scale: the hidden scale of the Feedforward parts of the transformer + for instance if C = 384 (the number of channels in the transformer) and + t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension + 384 * 4 = 1536 + t_heads: number of heads for the transformer + t_dropout: dropout in the transformer + t_max_positions: max_positions for the "scaled" positional embedding, only + useful if t_emb="scaled" + t_norm_in: (bool) norm before addinf positional embedding and getting into the + transformer layers + t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the + timesteps (GroupNorm with group=1) + t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the + timesteps (GroupNorm with group=1) + t_norm_first: (bool) if True the norm is before the attention and before the FFN + t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer + t_max_period: (float) denominator in the sinusoidal embedding expression + t_weight_decay: (float) weight decay for the transformer + t_lr: (float) specific learning rate for the transformer + t_layer_scale: (bool) Layer Scale for the transformer + t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else + t_weight_pos_embed: (float) weighting of the positional embedding + t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings + see: https://arxiv.org/abs/2106.03143 + t_cape_augment: (bool) if t_emb="cape", must be True during training and False + during the inference, see: https://arxiv.org/abs/2106.03143 + t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters + see: https://arxiv.org/abs/2106.03143 + t_sparse_self_attn: (bool) if True, the self attentions are sparse + t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it + unless you designed really specific masks) + t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination + with '_' between: i.e. "diag_jmask_random" (note that this is permutation + invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag") + t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed + that generated the random part of the mask + t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and + a key (j), the mask is True id |i-j|<=t_sparse_attn_window + t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :] + and mask[:, :t_global_window] will be True + t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity + level of the random part of the mask. + t_cross_first: (bool) if True cross attention is the first layer of the + transformer (False seems to be better) + rescale: weight rescaling trick + use_train_segment: (bool) if True, the actual size that is used during the + training is used during inference. + """ + super().__init__() + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.bottom_channels = bottom_channels + self.channels = channels + self.samplerate = samplerate + self.segment = segment + self.use_train_segment = use_train_segment + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + "kernel_size": ker, + "stride": stri, + "freq": freq, + "pad": pad, + "norm": norm, + "rewrite": rewrite, + "norm_groups": norm_groups, + "dconv_kw": { + "depth": dconv_depth, + "compress": dconv_comp, + "init": dconv_init, + "gelu": True, + }, + } + kwt = dict(kw) + kwt["freq"] = 0 + kwt["kernel_size"] = kernel_size + kwt["stride"] = stride + kwt["pad"] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec["context_freq"] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer( + chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw + ) + if freq: + tenc = HEncLayer( + chin, + chout, + dconv=dconv_mode & 1, + context=context_enc, + empty=last_freq, + **kwt + ) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + dec = HDecLayer( + chout_z, + chin_z, + dconv=dconv_mode & 2, + last=index == 0, + context=context, + **kw_dec + ) + if multi: + dec = MultiWrap(dec, multi_freqs) + if freq: + tdec = HDecLayer( + chout, + chin, + dconv=dconv_mode & 2, + empty=last_freq, + last=index == 0, + context=context, + **kwt + ) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale + ) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + transformer_channels = channels * growth ** (depth - 1) + if bottom_channels: + self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1) + self.channel_downsampler = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + self.channel_upsampler_t = nn.Conv1d( + transformer_channels, bottom_channels, 1 + ) + self.channel_downsampler_t = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + + transformer_channels = bottom_channels + + if t_layers > 0: + self.crosstransformer = CrossTransformerEncoder( + dim=transformer_channels, + emb=t_emb, + hidden_scale=t_hidden_scale, + num_heads=t_heads, + num_layers=t_layers, + cross_first=t_cross_first, + dropout=t_dropout, + max_positions=t_max_positions, + norm_in=t_norm_in, + norm_in_group=t_norm_in_group, + group_norm=t_group_norm, + norm_first=t_norm_first, + norm_out=t_norm_out, + max_period=t_max_period, + weight_decay=t_weight_decay, + lr=t_lr, + layer_scale=t_layer_scale, + gelu=t_gelu, + sin_random_shift=t_sin_random_shift, + weight_pos_embed=t_weight_pos_embed, + cape_mean_normalize=t_cape_mean_normalize, + cape_augment=t_cape_augment, + cape_glob_loc_scale=t_cape_glob_loc_scale, + sparse_self_attn=t_sparse_self_attn, + sparse_cross_attn=t_sparse_cross_attn, + mask_type=t_mask_type, + mask_random_seed=t_mask_random_seed, + sparse_attn_window=t_sparse_attn_window, + global_window=t_global_window, + sparsity=t_sparsity, + auto_sparsity=t_auto_sparsity, + ) + else: + self.crosstransformer = None + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") + + z = spectro(x, nfft, hl)[..., :-1, :] + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2: 2 + le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4**scale) + z = F.pad(z, (0, 0, 0, 1)) + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + le = hl * int(math.ceil(length / hl)) + 2 * pad + x = ispectro(z, hl, length=le) + x = x[..., pad: pad + length] + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], + mix_stft[sample, frame], + niters, + residual=residual, + ) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def valid_length(self, length: int): + """ + Return a length that is appropriate for evaluation. + In our case, always return the training length, unless + it is smaller than the given length, in which case this + raises an error. + """ + if not self.use_train_segment: + return length + training_length = int(self.segment * self.samplerate) + if training_length < length: + raise ValueError( + f"Given length {length} is longer than " + f"training length {training_length}") + return training_length + + def forward(self, mix): + length = mix.shape[-1] + length_pre_pad = None + if self.use_train_segment: + if self.training: + self.segment = Fraction(mix.shape[-1], self.samplerate) + else: + training_length = int(self.segment * self.samplerate) + if mix.shape[-1] < training_length: + length_pre_pad = mix.shape[-1] + mix = F.pad(mix, (0, training_length - length_pre_pad)) + z = self._spec(mix) + mag = self._magnitude(z) + x = mag + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + if self.crosstransformer: + if self.bottom_channels: + b, c, f, t = x.shape + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_upsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_upsampler_t(xt) + + x, xt = self.crosstransformer(x, xt) + + if self.bottom_channels: + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_downsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_downsampler_t(xt) + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + offset = self.depth - len(self.tdecoder) + if idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + x = x.view(B, S, -1, Fq, T) + x = x * std[:, None] + mean[:, None] + + zout = self._mask(z, x) + if self.use_train_segment: + if self.training: + x = self._ispec(zout, length) + else: + x = self._ispec(zout, training_length) + else: + x = self._ispec(zout, length) + + if self.use_train_segment: + if self.training: + xt = xt.view(B, S, -1, length) + else: + xt = xt.view(B, S, -1, training_length) + else: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + if length_pre_pad: + x = x[..., :length_pre_pad] + return x diff --git a/demucs/model.py b/demucs/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e2745b8c75322a6fa58b7c163eda7667381abbe9 --- /dev/null +++ b/demucs/model.py @@ -0,0 +1,218 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch as th +from torch import nn + +from .utils import capture_init, center_trim + + +class BLSTM(nn.Module): + def __init__(self, dim, layers=1): + super().__init__() + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + + def forward(self, x): + x = x.permute(2, 0, 1) + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + return x + + +def rescale_conv(conv, reference): + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): + rescale_conv(sub, reference) + + +def upsample(x, stride): + """ + Linear upsampling, the output will be `stride` times longer. + """ + batch, channels, time = x.size() + weight = th.arange(stride, device=x.device, dtype=th.float) / stride + x = x.view(batch, channels, time, 1) + out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight + return out.reshape(batch, channels, -1) + + +def downsample(x, stride): + """ + Downsample x by decimation. + """ + return x[:, :, ::stride] + + +class Demucs(nn.Module): + @capture_init + def __init__(self, + sources=4, + audio_channels=2, + channels=64, + depth=6, + rewrite=True, + glu=True, + upsample=False, + rescale=0.1, + kernel_size=8, + stride=4, + growth=2., + lstm_layers=2, + context=3, + samplerate=44100): + """ + Args: + sources (int): number of sources to separate + audio_channels (int): stereo or mono + channels (int): first convolution channels + depth (int): number of encoder/decoder layers + rewrite (bool): add 1x1 convolution to each encoder layer + and a convolution to each decoder layer. + For the decoder layer, `context` gives the kernel size. + glu (bool): use glu instead of ReLU + upsample (bool): use linear upsampling with convolutions + Wave-U-Net style, instead of transposed convolutions + rescale (int): rescale initial weights of convolutions + to get their standard deviation closer to `rescale` + kernel_size (int): kernel size for convolutions + stride (int): stride for convolutions + growth (float): multiply (resp divide) number of channels by that + for each layer of the encoder (resp decoder) + lstm_layers (int): number of lstm layers, 0 = no lstm + context (int): kernel size of the convolution in the + decoder before the transposed convolution. If > 1, + will provide some context from neighboring time + steps. + """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.upsample = upsample + self.channels = channels + self.samplerate = samplerate + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.final = None + if upsample: + self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1) + stride = 1 + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + in_channels = audio_channels + for index in range(depth): + encode = [] + encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()] + if rewrite: + encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + if upsample: + out_channels = channels + else: + out_channels = sources * audio_channels + if rewrite: + decode += [nn.Conv1d(channels, ch_scale * channels, context), activation] + if upsample: + decode += [ + nn.Conv1d(channels, out_channels, kernel_size, stride=1), + ] + else: + decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)] + if index > 0: + decode.append(nn.ReLU()) + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolutions, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + If the mixture has a valid length, the estimated sources + will have exactly the same length when context = 1. If context > 1, + the two signals can be center trimmed to match. + + For training, extracts should have a valid length.For evaluation + on full tracks we recommend passing `pad = True` to :method:`forward`. + """ + for _ in range(self.depth): + if self.upsample: + length = math.ceil(length / self.stride) + self.kernel_size - 1 + else: + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + length += self.context - 1 + for _ in range(self.depth): + if self.upsample: + length = length * self.stride + self.kernel_size - 1 + else: + length = (length - 1) * self.stride + self.kernel_size + + return int(length) + + def forward(self, mix): + x = mix + saved = [x] + for encode in self.encoder: + x = encode(x) + saved.append(x) + if self.upsample: + x = downsample(x, self.stride) + if self.lstm: + x = self.lstm(x) + for decode in self.decoder: + if self.upsample: + x = upsample(x, stride=self.stride) + skip = center_trim(saved.pop(-1), x) + x = x + skip + x = decode(x) + if self.final: + skip = center_trim(saved.pop(-1), x) + x = th.cat([x, skip], dim=1) + x = self.final(x) + + x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1)) + return x diff --git a/demucs/model_v2.py b/demucs/model_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..db43fc5ebba683e85d9abfa8433e679ff648e216 --- /dev/null +++ b/demucs/model_v2.py @@ -0,0 +1,218 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import julius +from torch import nn +from .tasnet_v2 import ConvTasNet + +from .utils import capture_init, center_trim + + +class BLSTM(nn.Module): + def __init__(self, dim, layers=1): + super().__init__() + self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) + self.linear = nn.Linear(2 * dim, dim) + + def forward(self, x): + x = x.permute(2, 0, 1) + x = self.lstm(x)[0] + x = self.linear(x) + x = x.permute(1, 2, 0) + return x + + +def rescale_conv(conv, reference): + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): + rescale_conv(sub, reference) + +def auto_load_demucs_model_v2(sources, demucs_model_name): + + if '48' in demucs_model_name: + channels=48 + elif 'unittest' in demucs_model_name: + channels=4 + else: + channels=64 + + if 'tasnet' in demucs_model_name: + init_demucs_model = ConvTasNet(sources, X=10) + else: + init_demucs_model = Demucs(sources, channels=channels) + + return init_demucs_model + +class Demucs(nn.Module): + @capture_init + def __init__(self, + sources, + audio_channels=2, + channels=64, + depth=6, + rewrite=True, + glu=True, + rescale=0.1, + resample=True, + kernel_size=8, + stride=4, + growth=2., + lstm_layers=2, + context=3, + normalize=False, + samplerate=44100, + segment_length=4 * 10 * 44100): + """ + Args: + sources (list[str]): list of source names + audio_channels (int): stereo or mono + channels (int): first convolution channels + depth (int): number of encoder/decoder layers + rewrite (bool): add 1x1 convolution to each encoder layer + and a convolution to each decoder layer. + For the decoder layer, `context` gives the kernel size. + glu (bool): use glu instead of ReLU + resample_input (bool): upsample x2 the input and downsample /2 the output. + rescale (int): rescale initial weights of convolutions + to get their standard deviation closer to `rescale` + kernel_size (int): kernel size for convolutions + stride (int): stride for convolutions + growth (float): multiply (resp divide) number of channels by that + for each layer of the encoder (resp decoder) + lstm_layers (int): number of lstm layers, 0 = no lstm + context (int): kernel size of the convolution in the + decoder before the transposed convolution. If > 1, + will provide some context from neighboring time + steps. + samplerate (int): stored as meta information for easing + future evaluations of the model. + segment_length (int): stored as meta information for easing + future evaluations of the model. Length of the segments on which + the model was trained. + """ + + super().__init__() + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.resample = resample + self.channels = channels + self.normalize = normalize + self.samplerate = samplerate + self.segment_length = segment_length + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + if glu: + activation = nn.GLU(dim=1) + ch_scale = 2 + else: + activation = nn.ReLU() + ch_scale = 1 + in_channels = audio_channels + for index in range(depth): + encode = [] + encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()] + if rewrite: + encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + if index > 0: + out_channels = in_channels + else: + out_channels = len(self.sources) * audio_channels + if rewrite: + decode += [nn.Conv1d(channels, ch_scale * channels, context), activation] + decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)] + if index > 0: + decode.append(nn.ReLU()) + self.decoder.insert(0, nn.Sequential(*decode)) + in_channels = channels + channels = int(growth * channels) + + channels = in_channels + + if lstm_layers: + self.lstm = BLSTM(channels, lstm_layers) + else: + self.lstm = None + + if rescale: + rescale_module(self, reference=rescale) + + def valid_length(self, length): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolutions, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + If the mixture has a valid length, the estimated sources + will have exactly the same length when context = 1. If context > 1, + the two signals can be center trimmed to match. + + For training, extracts should have a valid length.For evaluation + on full tracks we recommend passing `pad = True` to :method:`forward`. + """ + if self.resample: + length *= 2 + for _ in range(self.depth): + length = math.ceil((length - self.kernel_size) / self.stride) + 1 + length = max(1, length) + length += self.context - 1 + for _ in range(self.depth): + length = (length - 1) * self.stride + self.kernel_size + + if self.resample: + length = math.ceil(length / 2) + return int(length) + + def forward(self, mix): + x = mix + + if self.normalize: + mono = mix.mean(dim=1, keepdim=True) + mean = mono.mean(dim=-1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + else: + mean = 0 + std = 1 + + x = (x - mean) / (1e-5 + std) + + if self.resample: + x = julius.resample_frac(x, 1, 2) + + saved = [] + for encode in self.encoder: + x = encode(x) + saved.append(x) + if self.lstm: + x = self.lstm(x) + for decode in self.decoder: + skip = center_trim(saved.pop(-1), x) + x = x + skip + x = decode(x) + + if self.resample: + x = julius.resample_frac(x, 2, 1) + x = x * std + mean + x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) + return x diff --git a/demucs/pretrained.py b/demucs/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..25aa6856eab03e17b6addb6d2ae5e367f49355f9 --- /dev/null +++ b/demucs/pretrained.py @@ -0,0 +1,180 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Loading pretrained models. +""" + +import logging +from pathlib import Path +import typing as tp + +#from dora.log import fatal + +import logging + +from diffq import DiffQuantizer +import torch.hub + +from .model import Demucs +from .tasnet_v2 import ConvTasNet +from .utils import set_state + +from .hdemucs import HDemucs +from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa + +logger = logging.getLogger(__name__) +ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" +REMOTE_ROOT = Path(__file__).parent / 'remote' + +SOURCES = ["drums", "bass", "other", "vocals"] + + +def demucs_unittest(): + model = HDemucs(channels=4, sources=SOURCES) + return model + + +def add_model_flags(parser): + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument("-s", "--sig", help="Locally trained XP signature.") + group.add_argument("-n", "--name", default="mdx_extra_q", + help="Pretrained model name or signature. Default is mdx_extra_q.") + parser.add_argument("--repo", type=Path, + help="Folder containing all pre-trained models for use with -n.") + + +def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]: + root: str = '' + models: tp.Dict[str, str] = {} + for line in remote_file_list.read_text().split('\n'): + line = line.strip() + if line.startswith('#'): + continue + elif line.startswith('root:'): + root = line.split(':', 1)[1].strip() + else: + sig = line.split('-', 1)[0] + assert sig not in models + models[sig] = ROOT_URL + root + line + return models + +def get_model(name: str, + repo: tp.Optional[Path] = None): + """`name` must be a bag of models name or a pretrained signature + from the remote AWS model repo or the specified local repo if `repo` is not None. + """ + if name == 'demucs_unittest': + return demucs_unittest() + model_repo: ModelOnlyRepo + if repo is None: + models = _parse_remote_files(REMOTE_ROOT / 'files.txt') + model_repo = RemoteRepo(models) + bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) + else: + if not repo.is_dir(): + fatal(f"{repo} must exist and be a directory.") + model_repo = LocalRepo(repo) + bag_repo = BagOnlyRepo(repo, model_repo) + any_repo = AnyModelRepo(model_repo, bag_repo) + model = any_repo.get_model(name) + model.eval() + return model + +def get_model_from_args(args): + """ + Load local model package or pre-trained model. + """ + return get_model(name=args.name, repo=args.repo) + +logger = logging.getLogger(__name__) +ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/" + +PRETRAINED_MODELS = { + 'demucs': 'e07c671f', + 'demucs48_hq': '28a1282c', + 'demucs_extra': '3646af93', + 'demucs_quantized': '07afea75', + 'tasnet': 'beb46fac', + 'tasnet_extra': 'df3777b2', + 'demucs_unittest': '09ebc15f', +} + +SOURCES = ["drums", "bass", "other", "vocals"] + + +def get_url(name): + sig = PRETRAINED_MODELS[name] + return ROOT + name + "-" + sig[:8] + ".th" + +def is_pretrained(name): + return name in PRETRAINED_MODELS + + +def load_pretrained(name): + if name == "demucs": + return demucs(pretrained=True) + elif name == "demucs48_hq": + return demucs(pretrained=True, hq=True, channels=48) + elif name == "demucs_extra": + return demucs(pretrained=True, extra=True) + elif name == "demucs_quantized": + return demucs(pretrained=True, quantized=True) + elif name == "demucs_unittest": + return demucs_unittest(pretrained=True) + elif name == "tasnet": + return tasnet(pretrained=True) + elif name == "tasnet_extra": + return tasnet(pretrained=True, extra=True) + else: + raise ValueError(f"Invalid pretrained name {name}") + + +def _load_state(name, model, quantizer=None): + url = get_url(name) + state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) + set_state(model, quantizer, state) + if quantizer: + quantizer.detach() + + +def demucs_unittest(pretrained=True): + model = Demucs(channels=4, sources=SOURCES) + if pretrained: + _load_state('demucs_unittest', model) + return model + + +def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64): + if not pretrained and (extra or quantized or hq): + raise ValueError("if extra or quantized is True, pretrained must be True.") + model = Demucs(sources=SOURCES, channels=channels) + if pretrained: + name = 'demucs' + if channels != 64: + name += str(channels) + quantizer = None + if sum([extra, quantized, hq]) > 1: + raise ValueError("Only one of extra, quantized, hq, can be True.") + if quantized: + quantizer = DiffQuantizer(model, group_size=8, min_size=1) + name += '_quantized' + if extra: + name += '_extra' + if hq: + name += '_hq' + _load_state(name, model, quantizer) + return model + + +def tasnet(pretrained=True, extra=False): + if not pretrained and extra: + raise ValueError("if extra is True, pretrained must be True.") + model = ConvTasNet(X=10, sources=SOURCES) + if pretrained: + name = 'tasnet' + if extra: + name = 'tasnet_extra' + _load_state(name, model) + return model \ No newline at end of file diff --git a/demucs/repo.py b/demucs/repo.py new file mode 100644 index 0000000000000000000000000000000000000000..65ff6b33c7771b7743659d52151da67dc18082a8 --- /dev/null +++ b/demucs/repo.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Represents a model repository, including pre-trained models and bags of models. +A repo can either be the main remote repository stored in AWS, or a local repository +with your own models. +""" + +from hashlib import sha256 +from pathlib import Path +import typing as tp + +import torch +import yaml + +from .apply import BagOfModels, Model +from .states import load_model + + +AnyModel = tp.Union[Model, BagOfModels] + + +class ModelLoadingError(RuntimeError): + pass + + +def check_checksum(path: Path, checksum: str): + sha = sha256() + with open(path, 'rb') as file: + while True: + buf = file.read(2**20) + if not buf: + break + sha.update(buf) + actual_checksum = sha.hexdigest()[:len(checksum)] + if actual_checksum != checksum: + raise ModelLoadingError(f'Invalid checksum for file {path}, ' + f'expected {checksum} but got {actual_checksum}') + +class ModelOnlyRepo: + """Base class for all model only repos. + """ + def has_model(self, sig: str) -> bool: + raise NotImplementedError() + + def get_model(self, sig: str) -> Model: + raise NotImplementedError() + + +class RemoteRepo(ModelOnlyRepo): + def __init__(self, models: tp.Dict[str, str]): + self._models = models + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + url = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.') + pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) + return load_model(pkg) + + +class LocalRepo(ModelOnlyRepo): + def __init__(self, root: Path): + self.root = root + self.scan() + + def scan(self): + self._models = {} + self._checksums = {} + for file in self.root.iterdir(): + if file.suffix == '.th': + if '-' in file.stem: + xp_sig, checksum = file.stem.split('-') + self._checksums[xp_sig] = checksum + else: + xp_sig = file.stem + if xp_sig in self._models: + print('Whats xp? ', xp_sig) + raise ModelLoadingError( + f'Duplicate pre-trained model exist for signature {xp_sig}. ' + 'Please delete all but one.') + self._models[xp_sig] = file + + def has_model(self, sig: str) -> bool: + return sig in self._models + + def get_model(self, sig: str) -> Model: + try: + file = self._models[sig] + except KeyError: + raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.') + if sig in self._checksums: + check_checksum(file, self._checksums[sig]) + return load_model(file) + + +class BagOnlyRepo: + """Handles only YAML files containing bag of models, leaving the actual + model loading to some Repo. + """ + def __init__(self, root: Path, model_repo: ModelOnlyRepo): + self.root = root + self.model_repo = model_repo + self.scan() + + def scan(self): + self._bags = {} + for file in self.root.iterdir(): + if file.suffix == '.yaml': + self._bags[file.stem] = file + + def has_model(self, name: str) -> bool: + return name in self._bags + + def get_model(self, name: str) -> BagOfModels: + try: + yaml_file = self._bags[name] + except KeyError: + raise ModelLoadingError(f'{name} is neither a single pre-trained model or ' + 'a bag of models.') + bag = yaml.safe_load(open(yaml_file)) + signatures = bag['models'] + models = [self.model_repo.get_model(sig) for sig in signatures] + weights = bag.get('weights') + segment = bag.get('segment') + return BagOfModels(models, weights, segment) + + +class AnyModelRepo: + def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): + self.model_repo = model_repo + self.bag_repo = bag_repo + + def has_model(self, name_or_sig: str) -> bool: + return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig) + + def get_model(self, name_or_sig: str) -> AnyModel: + print('name_or_sig: ', name_or_sig) + if self.model_repo.has_model(name_or_sig): + return self.model_repo.get_model(name_or_sig) + else: + return self.bag_repo.get_model(name_or_sig) diff --git a/demucs/spec.py b/demucs/spec.py new file mode 100644 index 0000000000000000000000000000000000000000..85e5dc9397df4466f4191f440bed8edf5bf01663 --- /dev/null +++ b/demucs/spec.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Conveniance wrapper to perform STFT and iSTFT""" + +import torch as th + + +def spectro(x, n_fft=512, hop_length=None, pad=0): + *other, length = x.shape + x = x.reshape(-1, length) + z = th.stft(x, + n_fft * (1 + pad), + hop_length or n_fft // 4, + window=th.hann_window(n_fft).to(x), + win_length=n_fft, + normalized=True, + center=True, + return_complex=True, + pad_mode='reflect') + _, freqs, frame = z.shape + return z.view(*other, freqs, frame) + + +def ispectro(z, hop_length=None, length=None, pad=0): + *other, freqs, frames = z.shape + n_fft = 2 * freqs - 2 + z = z.view(-1, freqs, frames) + win_length = n_fft // (1 + pad) + x = th.istft(z, + n_fft, + hop_length, + window=th.hann_window(win_length).to(z.real), + win_length=win_length, + normalized=True, + length=length, + center=True) + _, length = x.shape + return x.view(*other, length) diff --git a/demucs/states.py b/demucs/states.py new file mode 100644 index 0000000000000000000000000000000000000000..db17a182dc1b19f7934a1eccbe2f67437c2a78f1 --- /dev/null +++ b/demucs/states.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Utilities to save and load models. +""" +from contextlib import contextmanager + +import functools +import hashlib +import inspect +import io +from pathlib import Path +import warnings + +from omegaconf import OmegaConf +from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state +import torch + + +def get_quantizer(model, args, optimizer=None): + """Return the quantizer given the XP quantization args.""" + quantizer = None + if args.diffq: + quantizer = DiffQuantizer( + model, min_size=args.min_size, group_size=args.group_size) + if optimizer is not None: + quantizer.setup_optimizer(optimizer) + elif args.qat: + quantizer = UniformQuantizer( + model, bits=args.qat, min_size=args.min_size) + return quantizer + + +def load_model(path_or_package, strict=False): + """Load a model from the given serialized model, either given as a dict (already loaded) + or a path to a file on disk.""" + if isinstance(path_or_package, dict): + package = path_or_package + elif isinstance(path_or_package, (str, Path)): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + path = path_or_package + package = torch.load(path, 'cpu') + else: + raise ValueError(f"Invalid type for {path_or_package}.") + + klass = package["klass"] + args = package["args"] + kwargs = package["kwargs"] + + if strict: + model = klass(*args, **kwargs) + else: + sig = inspect.signature(klass) + for key in list(kwargs): + if key not in sig.parameters: + warnings.warn("Dropping inexistant parameter " + key) + del kwargs[key] + model = klass(*args, **kwargs) + + state = package["state"] + + set_state(model, state) + return model + + +def get_state(model, quantizer, half=False): + """Get the state from a model, potentially with quantization applied. + If `half` is True, model are stored as half precision, which shouldn't impact performance + but half the state size.""" + if quantizer is None: + dtype = torch.half if half else None + state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()} + else: + state = quantizer.get_quantized_state() + state['__quantized'] = True + return state + + +def set_state(model, state, quantizer=None): + """Set the state on a given model.""" + if state.get('__quantized'): + if quantizer is not None: + quantizer.restore_quantized_state(model, state['quantized']) + else: + restore_quantized_state(model, state) + else: + model.load_state_dict(state) + return state + + +def save_with_checksum(content, path): + """Save the given value on disk, along with a sha256 hash. + Should be used with the output of either `serialize_model` or `get_state`.""" + buf = io.BytesIO() + torch.save(content, buf) + sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] + + path = path.parent / (path.stem + "-" + sig + path.suffix) + path.write_bytes(buf.getvalue()) + + +def serialize_model(model, training_args, quantizer=None, half=True): + args, kwargs = model._init_args_kwargs + klass = model.__class__ + + state = get_state(model, quantizer, half) + return { + 'klass': klass, + 'args': args, + 'kwargs': kwargs, + 'state': state, + 'training_args': OmegaConf.to_container(training_args, resolve=True), + } + + +def copy_state(state): + return {k: v.cpu().clone() for k, v in state.items()} + + +@contextmanager +def swap_state(model, state): + """ + Context manager that swaps the state of a model, e.g: + + # model is in old state + with swap_state(model, new_state): + # model in new state + # model back to old state + """ + old_state = copy_state(model.state_dict()) + model.load_state_dict(state, strict=False) + try: + yield + finally: + model.load_state_dict(old_state) + + +def capture_init(init): + @functools.wraps(init) + def __init__(self, *args, **kwargs): + self._init_args_kwargs = (args, kwargs) + init(self, *args, **kwargs) + + return __init__ diff --git a/demucs/tasnet.py b/demucs/tasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb7a957b44a21ca64cb580a02e2c33468f9cb8d --- /dev/null +++ b/demucs/tasnet.py @@ -0,0 +1,447 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Created on 2018/12 +# Author: Kaituo XU +# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels +# Here is the original license: +# The MIT License (MIT) +# +# Copyright (c) 2018 Kaituo XU +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import capture_init + +EPS = 1e-8 + + +def overlap_and_add(signal, frame_step): + outer_dimensions = signal.size()[:-2] + frames, frame_length = signal.size()[-2:] + + subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor + subframe_step = frame_step // subframe_length + subframes_per_frame = frame_length // subframe_length + output_size = frame_step * (frames - 1) + frame_length + output_subframes = output_size // subframe_length + + subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) + + frame = torch.arange(0, output_subframes, + device=signal.device).unfold(0, subframes_per_frame, subframe_step) + frame = frame.long() # signal may in GPU or CPU + frame = frame.contiguous().view(-1) + + result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) + result.index_add_(-2, frame, subframe_signal) + result = result.view(*outer_dimensions, -1) + return result + + +class ConvTasNet(nn.Module): + @capture_init + def __init__(self, + N=256, + L=20, + B=256, + H=512, + P=3, + X=8, + R=4, + C=4, + audio_channels=1, + samplerate=44100, + norm_type="gLN", + causal=False, + mask_nonlinear='relu'): + """ + Args: + N: Number of filters in autoencoder + L: Length of the filters (in samples) + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + C: Number of speakers + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(ConvTasNet, self).__init__() + # Hyper-parameter + self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C + self.norm_type = norm_type + self.causal = causal + self.mask_nonlinear = mask_nonlinear + self.audio_channels = audio_channels + self.samplerate = samplerate + # Components + self.encoder = Encoder(L, N, audio_channels) + self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear) + self.decoder = Decoder(N, L, audio_channels) + # init + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def valid_length(self, length): + return length + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + est_source: [M, C, T] + """ + mixture_w = self.encoder(mixture) + est_mask = self.separator(mixture_w) + est_source = self.decoder(mixture_w, est_mask) + + # T changed after conv1d in encoder, fix it here + T_origin = mixture.size(-1) + T_conv = est_source.size(-1) + est_source = F.pad(est_source, (0, T_origin - T_conv)) + return est_source + + +class Encoder(nn.Module): + """Estimation of the nonnegative mixture weight by a 1-D conv layer. + """ + def __init__(self, L, N, audio_channels): + super(Encoder, self).__init__() + # Hyper-parameter + self.L, self.N = L, N + # Components + # 50% overlap + self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False) + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 + """ + mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] + return mixture_w + + +class Decoder(nn.Module): + def __init__(self, N, L, audio_channels): + super(Decoder, self).__init__() + # Hyper-parameter + self.N, self.L = N, L + self.audio_channels = audio_channels + # Components + self.basis_signals = nn.Linear(N, audio_channels * L, bias=False) + + def forward(self, mixture_w, est_mask): + """ + Args: + mixture_w: [M, N, K] + est_mask: [M, C, N, K] + Returns: + est_source: [M, C, T] + """ + # D = W * M + source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K] + source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N] + # S = DV + est_source = self.basis_signals(source_w) # [M, C, K, ac * L] + m, c, k, _ = est_source.size() + est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous() + est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T + return est_source + + +class TemporalConvNet(nn.Module): + def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'): + """ + Args: + N: Number of filters in autoencoder + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + C: Number of speakers + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(TemporalConvNet, self).__init__() + # Hyper-parameter + self.C = C + self.mask_nonlinear = mask_nonlinear + # Components + # [M, N, K] -> [M, N, K] + layer_norm = ChannelwiseLayerNorm(N) + # [M, N, K] -> [M, B, K] + bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) + # [M, B, K] -> [M, B, K] + repeats = [] + for r in range(R): + blocks = [] + for x in range(X): + dilation = 2**x + padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 + blocks += [ + TemporalBlock(B, + H, + P, + stride=1, + padding=padding, + dilation=dilation, + norm_type=norm_type, + causal=causal) + ] + repeats += [nn.Sequential(*blocks)] + temporal_conv_net = nn.Sequential(*repeats) + # [M, B, K] -> [M, C*N, K] + mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False) + # Put together + self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, + mask_conv1x1) + + def forward(self, mixture_w): + """ + Keep this API same with TasNet + Args: + mixture_w: [M, N, K], M is batch size + returns: + est_mask: [M, C, N, K] + """ + M, N, K = mixture_w.size() + score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K] + score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] + if self.mask_nonlinear == 'softmax': + est_mask = F.softmax(score, dim=1) + elif self.mask_nonlinear == 'relu': + est_mask = F.relu(score) + else: + raise ValueError("Unsupported mask non-linear function") + return est_mask + + +class TemporalBlock(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(TemporalBlock, self).__init__() + # [M, B, K] -> [M, H, K] + conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) + prelu = nn.PReLU() + norm = chose_norm(norm_type, out_channels) + # [M, H, K] -> [M, B, K] + dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, + dilation, norm_type, causal) + # Put together + self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) + + def forward(self, x): + """ + Args: + x: [M, B, K] + Returns: + [M, B, K] + """ + residual = x + out = self.net(x) + # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad? + return out + residual # look like w/o F.relu is better than w/ F.relu + # return F.relu(out + residual) + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(DepthwiseSeparableConv, self).__init__() + # Use `groups` option to implement depthwise convolution + # [M, H, K] -> [M, H, K] + depthwise_conv = nn.Conv1d(in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False) + if causal: + chomp = Chomp1d(padding) + prelu = nn.PReLU() + norm = chose_norm(norm_type, in_channels) + # [M, H, K] -> [M, B, K] + pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) + # Put together + if causal: + self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) + else: + self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) + + def forward(self, x): + """ + Args: + x: [M, H, K] + Returns: + result: [M, B, K] + """ + return self.net(x) + + +class Chomp1d(nn.Module): + """To ensure the output length is the same as the input. + """ + def __init__(self, chomp_size): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + """ + Args: + x: [M, H, Kpad] + Returns: + [M, H, K] + """ + return x[:, :, :-self.chomp_size].contiguous() + + +def chose_norm(norm_type, channel_size): + """The input of normlization will be (M, C, K), where M is batch size, + C is channel size and K is sequence length. + """ + if norm_type == "gLN": + return GlobalLayerNorm(channel_size) + elif norm_type == "cLN": + return ChannelwiseLayerNorm(channel_size) + elif norm_type == "id": + return nn.Identity() + else: # norm_type == "BN": + # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics + # along M and K, so this BN usage is right. + return nn.BatchNorm1d(channel_size) + + +# TODO: Use nn.LayerNorm to impl cLN to speed up +class ChannelwiseLayerNorm(nn.Module): + """Channel-wise Layer Normalization (cLN)""" + def __init__(self, channel_size): + super(ChannelwiseLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + Returns: + cLN_y: [M, N, K] + """ + mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] + var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] + cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return cLN_y + + +class GlobalLayerNorm(nn.Module): + """Global Layer Normalization (gLN)""" + def __init__(self, channel_size): + super(GlobalLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + Returns: + gLN_y: [M, N, K] + """ + # TODO: in torch 1.0, torch.mean() support dim list + mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1] + var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) + gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return gLN_y + + +if __name__ == "__main__": + torch.manual_seed(123) + M, N, L, T = 2, 3, 4, 12 + K = 2 * T // L - 1 + B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False + mixture = torch.randint(3, (M, T)) + # test Encoder + encoder = Encoder(L, N) + encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()) + mixture_w = encoder(mixture) + print('mixture', mixture) + print('U', encoder.conv1d_U.weight) + print('mixture_w', mixture_w) + print('mixture_w size', mixture_w.size()) + + # test TemporalConvNet + separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal) + est_mask = separator(mixture_w) + print('est_mask', est_mask) + + # test Decoder + decoder = Decoder(N, L) + est_mask = torch.randint(2, (B, K, C, N)) + est_source = decoder(mixture_w, est_mask) + print('est_source', est_source) + + # test Conv-TasNet + conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type) + est_source = conv_tasnet(mixture) + print('est_source', est_source) + print('est_source size', est_source.size()) diff --git a/demucs/tasnet_v2.py b/demucs/tasnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc1257925ea8f4fbe389ddd6d73ce9fdf45f6d4 --- /dev/null +++ b/demucs/tasnet_v2.py @@ -0,0 +1,452 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# Created on 2018/12 +# Author: Kaituo XU +# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels +# Here is the original license: +# The MIT License (MIT) +# +# Copyright (c) 2018 Kaituo XU +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import capture_init + +EPS = 1e-8 + + +def overlap_and_add(signal, frame_step): + outer_dimensions = signal.size()[:-2] + frames, frame_length = signal.size()[-2:] + + subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor + subframe_step = frame_step // subframe_length + subframes_per_frame = frame_length // subframe_length + output_size = frame_step * (frames - 1) + frame_length + output_subframes = output_size // subframe_length + + subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) + + frame = torch.arange(0, output_subframes, + device=signal.device).unfold(0, subframes_per_frame, subframe_step) + frame = frame.long() # signal may in GPU or CPU + frame = frame.contiguous().view(-1) + + result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) + result.index_add_(-2, frame, subframe_signal) + result = result.view(*outer_dimensions, -1) + return result + + +class ConvTasNet(nn.Module): + @capture_init + def __init__(self, + sources, + N=256, + L=20, + B=256, + H=512, + P=3, + X=8, + R=4, + audio_channels=2, + norm_type="gLN", + causal=False, + mask_nonlinear='relu', + samplerate=44100, + segment_length=44100 * 2 * 4): + """ + Args: + sources: list of sources + N: Number of filters in autoencoder + L: Length of the filters (in samples) + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(ConvTasNet, self).__init__() + # Hyper-parameter + self.sources = sources + self.C = len(sources) + self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R + self.norm_type = norm_type + self.causal = causal + self.mask_nonlinear = mask_nonlinear + self.audio_channels = audio_channels + self.samplerate = samplerate + self.segment_length = segment_length + # Components + self.encoder = Encoder(L, N, audio_channels) + self.separator = TemporalConvNet( + N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear) + self.decoder = Decoder(N, L, audio_channels) + # init + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def valid_length(self, length): + return length + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + est_source: [M, C, T] + """ + mixture_w = self.encoder(mixture) + est_mask = self.separator(mixture_w) + est_source = self.decoder(mixture_w, est_mask) + + # T changed after conv1d in encoder, fix it here + T_origin = mixture.size(-1) + T_conv = est_source.size(-1) + est_source = F.pad(est_source, (0, T_origin - T_conv)) + return est_source + + +class Encoder(nn.Module): + """Estimation of the nonnegative mixture weight by a 1-D conv layer. + """ + def __init__(self, L, N, audio_channels): + super(Encoder, self).__init__() + # Hyper-parameter + self.L, self.N = L, N + # Components + # 50% overlap + self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False) + + def forward(self, mixture): + """ + Args: + mixture: [M, T], M is batch size, T is #samples + Returns: + mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 + """ + mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] + return mixture_w + + +class Decoder(nn.Module): + def __init__(self, N, L, audio_channels): + super(Decoder, self).__init__() + # Hyper-parameter + self.N, self.L = N, L + self.audio_channels = audio_channels + # Components + self.basis_signals = nn.Linear(N, audio_channels * L, bias=False) + + def forward(self, mixture_w, est_mask): + """ + Args: + mixture_w: [M, N, K] + est_mask: [M, C, N, K] + Returns: + est_source: [M, C, T] + """ + # D = W * M + source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K] + source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N] + # S = DV + est_source = self.basis_signals(source_w) # [M, C, K, ac * L] + m, c, k, _ = est_source.size() + est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous() + est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T + return est_source + + +class TemporalConvNet(nn.Module): + def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'): + """ + Args: + N: Number of filters in autoencoder + B: Number of channels in bottleneck 1 × 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + C: Number of speakers + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super(TemporalConvNet, self).__init__() + # Hyper-parameter + self.C = C + self.mask_nonlinear = mask_nonlinear + # Components + # [M, N, K] -> [M, N, K] + layer_norm = ChannelwiseLayerNorm(N) + # [M, N, K] -> [M, B, K] + bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) + # [M, B, K] -> [M, B, K] + repeats = [] + for r in range(R): + blocks = [] + for x in range(X): + dilation = 2**x + padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 + blocks += [ + TemporalBlock(B, + H, + P, + stride=1, + padding=padding, + dilation=dilation, + norm_type=norm_type, + causal=causal) + ] + repeats += [nn.Sequential(*blocks)] + temporal_conv_net = nn.Sequential(*repeats) + # [M, B, K] -> [M, C*N, K] + mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False) + # Put together + self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, + mask_conv1x1) + + def forward(self, mixture_w): + """ + Keep this API same with TasNet + Args: + mixture_w: [M, N, K], M is batch size + returns: + est_mask: [M, C, N, K] + """ + M, N, K = mixture_w.size() + score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K] + score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] + if self.mask_nonlinear == 'softmax': + est_mask = F.softmax(score, dim=1) + elif self.mask_nonlinear == 'relu': + est_mask = F.relu(score) + else: + raise ValueError("Unsupported mask non-linear function") + return est_mask + + +class TemporalBlock(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(TemporalBlock, self).__init__() + # [M, B, K] -> [M, H, K] + conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) + prelu = nn.PReLU() + norm = chose_norm(norm_type, out_channels) + # [M, H, K] -> [M, B, K] + dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, + dilation, norm_type, causal) + # Put together + self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) + + def forward(self, x): + """ + Args: + x: [M, B, K] + Returns: + [M, B, K] + """ + residual = x + out = self.net(x) + # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad? + return out + residual # look like w/o F.relu is better than w/ F.relu + # return F.relu(out + residual) + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False): + super(DepthwiseSeparableConv, self).__init__() + # Use `groups` option to implement depthwise convolution + # [M, H, K] -> [M, H, K] + depthwise_conv = nn.Conv1d(in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False) + if causal: + chomp = Chomp1d(padding) + prelu = nn.PReLU() + norm = chose_norm(norm_type, in_channels) + # [M, H, K] -> [M, B, K] + pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) + # Put together + if causal: + self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) + else: + self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) + + def forward(self, x): + """ + Args: + x: [M, H, K] + Returns: + result: [M, B, K] + """ + return self.net(x) + + +class Chomp1d(nn.Module): + """To ensure the output length is the same as the input. + """ + def __init__(self, chomp_size): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x): + """ + Args: + x: [M, H, Kpad] + Returns: + [M, H, K] + """ + return x[:, :, :-self.chomp_size].contiguous() + + +def chose_norm(norm_type, channel_size): + """The input of normlization will be (M, C, K), where M is batch size, + C is channel size and K is sequence length. + """ + if norm_type == "gLN": + return GlobalLayerNorm(channel_size) + elif norm_type == "cLN": + return ChannelwiseLayerNorm(channel_size) + elif norm_type == "id": + return nn.Identity() + else: # norm_type == "BN": + # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics + # along M and K, so this BN usage is right. + return nn.BatchNorm1d(channel_size) + + +# TODO: Use nn.LayerNorm to impl cLN to speed up +class ChannelwiseLayerNorm(nn.Module): + """Channel-wise Layer Normalization (cLN)""" + def __init__(self, channel_size): + super(ChannelwiseLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + Returns: + cLN_y: [M, N, K] + """ + mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] + var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] + cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return cLN_y + + +class GlobalLayerNorm(nn.Module): + """Global Layer Normalization (gLN)""" + def __init__(self, channel_size): + super(GlobalLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + Returns: + gLN_y: [M, N, K] + """ + # TODO: in torch 1.0, torch.mean() support dim list + mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1] + var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) + gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return gLN_y + + +if __name__ == "__main__": + torch.manual_seed(123) + M, N, L, T = 2, 3, 4, 12 + K = 2 * T // L - 1 + B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False + mixture = torch.randint(3, (M, T)) + # test Encoder + encoder = Encoder(L, N) + encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()) + mixture_w = encoder(mixture) + print('mixture', mixture) + print('U', encoder.conv1d_U.weight) + print('mixture_w', mixture_w) + print('mixture_w size', mixture_w.size()) + + # test TemporalConvNet + separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal) + est_mask = separator(mixture_w) + print('est_mask', est_mask) + + # test Decoder + decoder = Decoder(N, L) + est_mask = torch.randint(2, (B, K, C, N)) + est_source = decoder(mixture_w, est_mask) + print('est_source', est_source) + + # test Conv-TasNet + conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type) + est_source = conv_tasnet(mixture) + print('est_source', est_source) + print('est_source size', est_source.size()) diff --git a/demucs/transformer.py b/demucs/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..56a465b861d7d018d0eca2779bbd392f07e411a9 --- /dev/null +++ b/demucs/transformer.py @@ -0,0 +1,839 @@ +# Copyright (c) 2019-present, Meta, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# First author is Simon Rouard. + +import random +import typing as tp + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +from einops import rearrange + + +def create_sin_embedding( + length: int, dim: int, shift: int = 0, device="cpu", max_period=10000 +): + # We aim for TBC format + assert dim % 2 == 0 + pos = shift + torch.arange(length, device=device).view(-1, 1, 1) + half_dim = dim // 2 + adim = torch.arange(dim // 2, device=device).view(1, 1, -1) + phase = pos / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ) + + +def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000): + """ + :param d_model: dimension of the model + :param height: height of the positions + :param width: width of the positions + :return: d_model*height*width position matrix + """ + if d_model % 4 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " + "odd dimension (got dim={:d})".format(d_model) + ) + pe = torch.zeros(d_model, height, width) + # Each dimension use half of d_model + d_model = int(d_model / 2) + div_term = torch.exp( + torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model) + ) + pos_w = torch.arange(0.0, width).unsqueeze(1) + pos_h = torch.arange(0.0, height).unsqueeze(1) + pe[0:d_model:2, :, :] = ( + torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[1:d_model:2, :, :] = ( + torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) + ) + pe[d_model::2, :, :] = ( + torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + pe[d_model + 1:: 2, :, :] = ( + torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) + ) + + return pe[None, :].to(device) + + +def create_sin_embedding_cape( + length: int, + dim: int, + batch_size: int, + mean_normalize: bool, + augment: bool, # True during training + max_global_shift: float = 0.0, # delta max + max_local_shift: float = 0.0, # epsilon max + max_scale: float = 1.0, + device: str = "cpu", + max_period: float = 10000.0, +): + # We aim for TBC format + assert dim % 2 == 0 + pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1) + pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1) + if mean_normalize: + pos -= torch.nanmean(pos, dim=0, keepdim=True) + + if augment: + delta = np.random.uniform( + -max_global_shift, +max_global_shift, size=[1, batch_size, 1] + ) + delta_local = np.random.uniform( + -max_local_shift, +max_local_shift, size=[length, batch_size, 1] + ) + log_lambdas = np.random.uniform( + -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1] + ) + pos = (pos + delta + delta_local) * np.exp(log_lambdas) + + pos = pos.to(device) + + half_dim = dim // 2 + adim = torch.arange(dim // 2, device=device).view(1, 1, -1) + phase = pos / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ).float() + + +def get_causal_mask(length): + pos = torch.arange(length) + return pos > pos[:, None] + + +def get_elementary_mask( + T1, + T2, + mask_type, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, +): + """ + When the input of the Decoder has length T1 and the output T2 + The mask matrix has shape (T2, T1) + """ + assert mask_type in ["diag", "jmask", "random", "global"] + + if mask_type == "global": + mask = torch.zeros(T2, T1, dtype=torch.bool) + mask[:, :global_window] = True + line_window = int(global_window * T2 / T1) + mask[:line_window, :] = True + + if mask_type == "diag": + + mask = torch.zeros(T2, T1, dtype=torch.bool) + rows = torch.arange(T2)[:, None] + cols = ( + (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1)) + .long() + .clamp(0, T1 - 1) + ) + mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) + + elif mask_type == "jmask": + mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool) + rows = torch.arange(T2 + 2)[:, None] + t = torch.arange(0, int((2 * T1) ** 0.5 + 1)) + t = (t * (t + 1) / 2).int() + t = torch.cat([-t.flip(0)[:-1], t]) + cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1) + mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) + mask = mask[1:-1, 1:-1] + + elif mask_type == "random": + gene = torch.Generator(device=device) + gene.manual_seed(mask_random_seed) + mask = ( + torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1) + > sparsity + ) + + mask = mask.to(device) + return mask + + +def get_mask( + T1, + T2, + mask_type, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, +): + """ + Return a SparseCSRTensor mask that is a combination of elementary masks + mask_type can be a combination of multiple masks: for instance "diag_jmask_random" + """ + from xformers.sparse import SparseCSRTensor + # create a list + mask_types = mask_type.split("_") + + all_masks = [ + get_elementary_mask( + T1, + T2, + mask, + sparse_attn_window, + global_window, + mask_random_seed, + sparsity, + device, + ) + for mask in mask_types + ] + + final_mask = torch.stack(all_masks).sum(axis=0) > 0 + + return SparseCSRTensor.from_dense(final_mask[None]) + + +class ScaledEmbedding(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + scale: float = 1.0, + boost: float = 3.0, + ): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + self.embedding.weight.data *= scale / boost + self.boost = boost + + @property + def weight(self): + return self.embedding.weight * self.boost + + def forward(self, x): + return self.embedding(x) * self.boost + + +class LayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonaly residual outputs close to 0 initially, then learnt. + """ + + def __init__(self, channels: int, init: float = 0, channel_last=False): + """ + channel_last = False corresponds to (B, C, T) tensors + channel_last = True corresponds to (T, B, C) tensors + """ + super().__init__() + self.channel_last = channel_last + self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) + self.scale.data[:] = init + + def forward(self, x): + if self.channel_last: + return self.scale * x + else: + return self.scale[:, None] * x + + +class MyGroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + """ + x: (B, T, C) + if num_groups=1: Normalisation on all T and C together for each B + """ + x = x.transpose(1, 2) + return super().forward(x).transpose(1, 2) + + +class MyTransformerEncoderLayer(nn.TransformerEncoderLayer): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation=F.relu, + group_norm=0, + norm_first=False, + norm_out=False, + layer_norm_eps=1e-5, + layer_scale=False, + init_values=1e-4, + device=None, + dtype=None, + sparse=False, + mask_type="diag", + mask_random_seed=42, + sparse_attn_window=500, + global_window=50, + auto_sparsity=False, + sparsity=0.95, + batch_first=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + batch_first=batch_first, + norm_first=norm_first, + device=device, + dtype=dtype, + ) + self.sparse = sparse + self.auto_sparsity = auto_sparsity + if sparse: + if not auto_sparsity: + self.mask_type = mask_type + self.sparse_attn_window = sparse_attn_window + self.global_window = global_window + self.sparsity = sparsity + if group_norm: + self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + + self.norm_out = None + if self.norm_first & norm_out: + self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) + self.gamma_1 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + self.gamma_2 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + + if sparse: + self.self_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, + auto_sparsity=sparsity if auto_sparsity else 0, + ) + self.__setattr__("src_mask", torch.zeros(1, 1)) + self.mask_random_seed = mask_random_seed + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """ + if batch_first = False, src shape is (T, B, C) + the case where batch_first=True is not covered + """ + device = src.device + x = src + T, B, C = x.shape + if self.sparse and not self.auto_sparsity: + assert src_mask is None + src_mask = self.src_mask + if src_mask.shape[-1] != T: + src_mask = get_mask( + T, + T, + self.mask_type, + self.sparse_attn_window, + self.global_window, + self.mask_random_seed, + self.sparsity, + device, + ) + self.__setattr__("src_mask", src_mask) + + if self.norm_first: + x = x + self.gamma_1( + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + ) + x = x + self.gamma_2(self._ff_block(self.norm2(x))) + + if self.norm_out: + x = self.norm_out(x) + else: + x = self.norm1( + x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)) + ) + x = self.norm2(x + self.gamma_2(self._ff_block(x))) + + return x + + +class CrossTransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation=F.relu, + layer_norm_eps: float = 1e-5, + layer_scale: bool = False, + init_values: float = 1e-4, + norm_first: bool = False, + group_norm: bool = False, + norm_out: bool = False, + sparse=False, + mask_type="diag", + mask_random_seed=42, + sparse_attn_window=500, + global_window=50, + sparsity=0.95, + auto_sparsity=None, + device=None, + dtype=None, + batch_first=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.sparse = sparse + self.auto_sparsity = auto_sparsity + if sparse: + if not auto_sparsity: + self.mask_type = mask_type + self.sparse_attn_window = sparse_attn_window + self.global_window = global_window + self.sparsity = sparsity + + self.cross_attn: nn.Module + self.cross_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm_first = norm_first + self.norm1: nn.Module + self.norm2: nn.Module + self.norm3: nn.Module + if group_norm: + self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) + else: + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + + self.norm_out = None + if self.norm_first & norm_out: + self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) + + self.gamma_1 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + self.gamma_2 = ( + LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = self._get_activation_fn(activation) + else: + self.activation = activation + + if sparse: + self.cross_attn = MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, + auto_sparsity=sparsity if auto_sparsity else 0) + if not auto_sparsity: + self.__setattr__("mask", torch.zeros(1, 1)) + self.mask_random_seed = mask_random_seed + + def forward(self, q, k, mask=None): + """ + Args: + q: tensor of shape (T, B, C) + k: tensor of shape (S, B, C) + mask: tensor of shape (T, S) + + """ + device = q.device + T, B, C = q.shape + S, B, C = k.shape + if self.sparse and not self.auto_sparsity: + assert mask is None + mask = self.mask + if mask.shape[-1] != S or mask.shape[-2] != T: + mask = get_mask( + S, + T, + self.mask_type, + self.sparse_attn_window, + self.global_window, + self.mask_random_seed, + self.sparsity, + device, + ) + self.__setattr__("mask", mask) + + if self.norm_first: + x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask)) + x = x + self.gamma_2(self._ff_block(self.norm3(x))) + if self.norm_out: + x = self.norm_out(x) + else: + x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask))) + x = self.norm2(x + self.gamma_2(self._ff_block(x))) + + return x + + # self-attention block + def _ca_block(self, q, k, attn_mask=None): + x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x): + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + def _get_activation_fn(self, activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +# ----------------- MULTI-BLOCKS MODELS: ----------------------- + + +class CrossTransformerEncoder(nn.Module): + def __init__( + self, + dim: int, + emb: str = "sin", + hidden_scale: float = 4.0, + num_heads: int = 8, + num_layers: int = 6, + cross_first: bool = False, + dropout: float = 0.0, + max_positions: int = 1000, + norm_in: bool = True, + norm_in_group: bool = False, + group_norm: int = False, + norm_first: bool = False, + norm_out: bool = False, + max_period: float = 10000.0, + weight_decay: float = 0.0, + lr: tp.Optional[float] = None, + layer_scale: bool = False, + gelu: bool = True, + sin_random_shift: int = 0, + weight_pos_embed: float = 1.0, + cape_mean_normalize: bool = True, + cape_augment: bool = True, + cape_glob_loc_scale: list = [5000.0, 1.0, 1.4], + sparse_self_attn: bool = False, + sparse_cross_attn: bool = False, + mask_type: str = "diag", + mask_random_seed: int = 42, + sparse_attn_window: int = 500, + global_window: int = 50, + auto_sparsity: bool = False, + sparsity: float = 0.95, + ): + super().__init__() + """ + """ + assert dim % num_heads == 0 + + hidden_dim = int(dim * hidden_scale) + + self.num_layers = num_layers + # classic parity = 1 means that if idx%2 == 1 there is a + # classical encoder else there is a cross encoder + self.classic_parity = 1 if cross_first else 0 + self.emb = emb + self.max_period = max_period + self.weight_decay = weight_decay + self.weight_pos_embed = weight_pos_embed + self.sin_random_shift = sin_random_shift + if emb == "cape": + self.cape_mean_normalize = cape_mean_normalize + self.cape_augment = cape_augment + self.cape_glob_loc_scale = cape_glob_loc_scale + if emb == "scaled": + self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2) + + self.lr = lr + + activation: tp.Any = F.gelu if gelu else F.relu + + self.norm_in: nn.Module + self.norm_in_t: nn.Module + if norm_in: + self.norm_in = nn.LayerNorm(dim) + self.norm_in_t = nn.LayerNorm(dim) + elif norm_in_group: + self.norm_in = MyGroupNorm(int(norm_in_group), dim) + self.norm_in_t = MyGroupNorm(int(norm_in_group), dim) + else: + self.norm_in = nn.Identity() + self.norm_in_t = nn.Identity() + + # spectrogram layers + self.layers = nn.ModuleList() + # temporal layers + self.layers_t = nn.ModuleList() + + kwargs_common = { + "d_model": dim, + "nhead": num_heads, + "dim_feedforward": hidden_dim, + "dropout": dropout, + "activation": activation, + "group_norm": group_norm, + "norm_first": norm_first, + "norm_out": norm_out, + "layer_scale": layer_scale, + "mask_type": mask_type, + "mask_random_seed": mask_random_seed, + "sparse_attn_window": sparse_attn_window, + "global_window": global_window, + "sparsity": sparsity, + "auto_sparsity": auto_sparsity, + "batch_first": True, + } + + kwargs_classic_encoder = dict(kwargs_common) + kwargs_classic_encoder.update({ + "sparse": sparse_self_attn, + }) + kwargs_cross_encoder = dict(kwargs_common) + kwargs_cross_encoder.update({ + "sparse": sparse_cross_attn, + }) + + for idx in range(num_layers): + if idx % 2 == self.classic_parity: + + self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder)) + self.layers_t.append( + MyTransformerEncoderLayer(**kwargs_classic_encoder) + ) + + else: + self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder)) + + self.layers_t.append( + CrossTransformerEncoderLayer(**kwargs_cross_encoder) + ) + + def forward(self, x, xt): + B, C, Fr, T1 = x.shape + pos_emb_2d = create_2d_sin_embedding( + C, Fr, T1, x.device, self.max_period + ) # (1, C, Fr, T1) + pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c") + x = rearrange(x, "b c fr t1 -> b (t1 fr) c") + x = self.norm_in(x) + x = x + self.weight_pos_embed * pos_emb_2d + + B, C, T2 = xt.shape + xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C + pos_emb = self._get_pos_embedding(T2, B, C, x.device) + pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c") + xt = self.norm_in_t(xt) + xt = xt + self.weight_pos_embed * pos_emb + + for idx in range(self.num_layers): + if idx % 2 == self.classic_parity: + x = self.layers[idx](x) + xt = self.layers_t[idx](xt) + else: + old_x = x + x = self.layers[idx](x, xt) + xt = self.layers_t[idx](xt, old_x) + + x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1) + xt = rearrange(xt, "b t2 c -> b c t2") + return x, xt + + def _get_pos_embedding(self, T, B, C, device): + if self.emb == "sin": + shift = random.randrange(self.sin_random_shift + 1) + pos_emb = create_sin_embedding( + T, C, shift=shift, device=device, max_period=self.max_period + ) + elif self.emb == "cape": + if self.training: + pos_emb = create_sin_embedding_cape( + T, + C, + B, + device=device, + max_period=self.max_period, + mean_normalize=self.cape_mean_normalize, + augment=self.cape_augment, + max_global_shift=self.cape_glob_loc_scale[0], + max_local_shift=self.cape_glob_loc_scale[1], + max_scale=self.cape_glob_loc_scale[2], + ) + else: + pos_emb = create_sin_embedding_cape( + T, + C, + B, + device=device, + max_period=self.max_period, + mean_normalize=self.cape_mean_normalize, + augment=False, + ) + + elif self.emb == "scaled": + pos = torch.arange(T, device=device) + pos_emb = self.position_embeddings(pos)[:, None] + + return pos_emb + + def make_optim_group(self): + group = {"params": list(self.parameters()), "weight_decay": self.weight_decay} + if self.lr is not None: + group["lr"] = self.lr + return group + + +# Attention Modules + + +class MultiheadAttention(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + auto_sparsity=None, + ): + super().__init__() + assert auto_sparsity is not None, "sanity check" + self.num_heads = num_heads + self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias) + self.attn_drop = torch.nn.Dropout(dropout) + self.proj = torch.nn.Linear(embed_dim, embed_dim, bias) + self.proj_drop = torch.nn.Dropout(dropout) + self.batch_first = batch_first + self.auto_sparsity = auto_sparsity + + def forward( + self, + query, + key, + value, + key_padding_mask=None, + need_weights=True, + attn_mask=None, + average_attn_weights=True, + ): + + if not self.batch_first: # N, B, C + query = query.permute(1, 0, 2) # B, N_q, C + key = key.permute(1, 0, 2) # B, N_k, C + value = value.permute(1, 0, 2) # B, N_k, C + B, N_q, C = query.shape + B, N_k, C = key.shape + + q = ( + self.q(query) + .reshape(B, N_q, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + q = q.flatten(0, 1) + k = ( + self.k(key) + .reshape(B, N_k, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = k.flatten(0, 1) + v = ( + self.v(value) + .reshape(B, N_k, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + v = v.flatten(0, 1) + + if self.auto_sparsity: + assert attn_mask is None + x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity) + else: + x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop) + x = x.reshape(B, self.num_heads, N_q, C // self.num_heads) + + x = x.transpose(1, 2).reshape(B, N_q, C) + x = self.proj(x) + x = self.proj_drop(x) + if not self.batch_first: + x = x.permute(1, 0, 2) + return x, None + + +def scaled_query_key_softmax(q, k, att_mask): + from xformers.ops import masked_matmul + q = q / (k.size(-1)) ** 0.5 + att = masked_matmul(q, k.transpose(-2, -1), att_mask) + att = torch.nn.functional.softmax(att, -1) + return att + + +def scaled_dot_product_attention(q, k, v, att_mask, dropout): + att = scaled_query_key_softmax(q, k, att_mask=att_mask) + att = dropout(att) + y = att @ v + return y + + +def _compute_buckets(x, R): + qq = torch.einsum('btf,bfhi->bhti', x, R) + qq = torch.cat([qq, -qq], dim=-1) + buckets = qq.argmax(dim=-1) + + return buckets.permute(0, 2, 1).byte().contiguous() + + +def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None): + # assert False, "The code for the custom sparse kernel is not ready for release yet." + from xformers.ops import find_locations, sparse_memory_efficient_attention + n_hashes = 32 + proj_size = 4 + query, key, value = [x.contiguous() for x in [query, key, value]] + with torch.no_grad(): + R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device) + bucket_query = _compute_buckets(query, R) + bucket_key = _compute_buckets(key, R) + row_offsets, column_indices = find_locations( + bucket_query, bucket_key, sparsity, infer_sparsity) + return sparse_memory_efficient_attention( + query, key, value, row_offsets, column_indices, attn_bias) diff --git a/demucs/utils.py b/demucs/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..94bd323dcd4b170b4d71ff27bade1bb6ad28738f --- /dev/null +++ b/demucs/utils.py @@ -0,0 +1,502 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from contextlib import contextmanager +import math +import os +import tempfile +import typing as tp + +import errno +import functools +import hashlib +import inspect +import io +import os +import random +import socket +import tempfile +import warnings +import zlib +import tkinter as tk + +from diffq import UniformQuantizer, DiffQuantizer +import torch as th +import tqdm +from torch import distributed +from torch.nn import functional as F + +import torch + +def unfold(a, kernel_size, stride): + """Given input of size [*OT, T], output Tensor of size [*OT, F, K] + with K the kernel size, by extracting frames with the given stride. + + This will pad the input so that `F = ceil(T / K)`. + + see https://github.com/pytorch/pytorch/issues/60466 + """ + *shape, length = a.shape + n_frames = math.ceil(length / stride) + tgt_length = (n_frames - 1) * stride + kernel_size + a = F.pad(a, (0, tgt_length - length)) + strides = list(a.stride()) + assert strides[-1] == 1, 'data should be contiguous' + strides = strides[:-1] + [stride, 1] + return a.as_strided([*shape, n_frames, kernel_size], strides) + + +def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]): + """ + Center trim `tensor` with respect to `reference`, along the last dimension. + `reference` can also be a number, representing the length to trim to. + If the size difference != 0 mod 2, the extra sample is removed on the right side. + """ + ref_size: int + if isinstance(reference, torch.Tensor): + ref_size = reference.size(-1) + else: + ref_size = reference + delta = tensor.size(-1) - ref_size + if delta < 0: + raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.") + if delta: + tensor = tensor[..., delta // 2:-(delta - delta // 2)] + return tensor + + +def pull_metric(history: tp.List[dict], name: str): + out = [] + for metrics in history: + metric = metrics + for part in name.split("."): + metric = metric[part] + out.append(metric) + return out + + +def EMA(beta: float = 1): + """ + Exponential Moving Average callback. + Returns a single function that can be called to repeatidly update the EMA + with a dict of metrics. The callback will return + the new averaged dict of metrics. + + Note that for `beta=1`, this is just plain averaging. + """ + fix: tp.Dict[str, float] = defaultdict(float) + total: tp.Dict[str, float] = defaultdict(float) + + def _update(metrics: dict, weight: float = 1) -> dict: + nonlocal total, fix + for key, value in metrics.items(): + total[key] = total[key] * beta + weight * float(value) + fix[key] = fix[key] * beta + weight + return {key: tot / fix[key] for key, tot in total.items()} + return _update + + +def sizeof_fmt(num: float, suffix: str = 'B'): + """ + Given `num` bytes, return human readable size. + Taken from https://stackoverflow.com/a/1094933 + """ + for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: + if abs(num) < 1024.0: + return "%3.1f%s%s" % (num, unit, suffix) + num /= 1024.0 + return "%.1f%s%s" % (num, 'Yi', suffix) + + +@contextmanager +def temp_filenames(count: int, delete=True): + names = [] + try: + for _ in range(count): + names.append(tempfile.NamedTemporaryFile(delete=False).name) + yield names + finally: + if delete: + for name in names: + os.unlink(name) + +def average_metric(metric, count=1.): + """ + Average `metric` which should be a float across all hosts. `count` should be + the weight for this particular host (i.e. number of examples). + """ + metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda') + distributed.all_reduce(metric, op=distributed.ReduceOp.SUM) + return metric[1].item() / metric[0].item() + + +def free_port(host='', low=20000, high=40000): + """ + Return a port number that is most likely free. + This could suffer from a race condition although + it should be quite rare. + """ + sock = socket.socket() + while True: + port = random.randint(low, high) + try: + sock.bind((host, port)) + except OSError as error: + if error.errno == errno.EADDRINUSE: + continue + raise + return port + + +def sizeof_fmt(num, suffix='B'): + """ + Given `num` bytes, return human readable size. + Taken from https://stackoverflow.com/a/1094933 + """ + for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: + if abs(num) < 1024.0: + return "%3.1f%s%s" % (num, unit, suffix) + num /= 1024.0 + return "%.1f%s%s" % (num, 'Yi', suffix) + + +def human_seconds(seconds, display='.2f'): + """ + Given `seconds` seconds, return human readable duration. + """ + value = seconds * 1e6 + ratios = [1e3, 1e3, 60, 60, 24] + names = ['us', 'ms', 's', 'min', 'hrs', 'days'] + last = names.pop(0) + for name, ratio in zip(names, ratios): + if value / ratio < 0.3: + break + value /= ratio + last = name + return f"{format(value, display)} {last}" + + +class TensorChunk: + def __init__(self, tensor, offset=0, length=None): + total_length = tensor.shape[-1] + assert offset >= 0 + assert offset < total_length + + if length is None: + length = total_length - offset + else: + length = min(total_length - offset, length) + + self.tensor = tensor + self.offset = offset + self.length = length + self.device = tensor.device + + @property + def shape(self): + shape = list(self.tensor.shape) + shape[-1] = self.length + return shape + + def padded(self, target_length): + delta = target_length - self.length + total_length = self.tensor.shape[-1] + assert delta >= 0 + + start = self.offset - delta // 2 + end = start + target_length + + correct_start = max(0, start) + correct_end = min(total_length, end) + + pad_left = correct_start - start + pad_right = end - correct_end + + out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) + assert out.shape[-1] == target_length + return out + + +def tensor_chunk(tensor_or_chunk): + if isinstance(tensor_or_chunk, TensorChunk): + return tensor_or_chunk + else: + assert isinstance(tensor_or_chunk, th.Tensor) + return TensorChunk(tensor_or_chunk) + + +def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None): + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. + Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + """ + + channels, length = mix.size() + device = mix.device + progress_value = 0 + + if split: + out = th.zeros(4, channels, length, device=device) + shift = model.samplerate * 10 + offsets = range(0, length, shift) + scale = 10 + if progress: + offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds') + for offset in offsets: + chunk = mix[..., offset:offset + shift] + if set_progress_bar: + progress_value += 1 + set_progress_bar(0.1, (0.8/len(offsets)*progress_value)) + chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar) + else: + chunk_out = apply_model_v1(model, chunk, shifts=shifts) + out[..., offset:offset + shift] = chunk_out + offset += shift + return out + elif shifts: + max_shift = int(model.samplerate / 2) + mix = F.pad(mix, (max_shift, max_shift)) + offsets = list(range(max_shift)) + random.shuffle(offsets) + out = 0 + for offset in offsets[:shifts]: + shifted = mix[..., offset:offset + length + max_shift] + if set_progress_bar: + shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar) + else: + shifted_out = apply_model_v1(model, shifted) + out += shifted_out[..., max_shift - offset:max_shift - offset + length] + out /= shifts + return out + else: + valid_length = model.valid_length(length) + delta = valid_length - length + padded = F.pad(mix, (delta // 2, delta - delta // 2)) + with th.no_grad(): + out = model(padded.unsqueeze(0))[0] + return center_trim(out, mix) + +def apply_model_v2(model, mix, shifts=None, split=False, + overlap=0.25, transition_power=1., progress=False, set_progress_bar=None): + """ + Apply model to a given mixture. + + Args: + shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec + and apply the oppositve shift to the output. This is repeated `shifts` time and + all predictions are averaged. This effectively makes the model time equivariant + and improves SDR by up to 0.2 points. + split (bool): if True, the input will be broken down in 8 seconds extracts + and predictions will be performed individually on each and concatenated. + Useful for model with large memory footprint like Tasnet. + progress (bool): if True, show a progress bar (requires split=True) + """ + + assert transition_power >= 1, "transition_power < 1 leads to weird behavior." + device = mix.device + channels, length = mix.shape + progress_value = 0 + + if split: + out = th.zeros(len(model.sources), channels, length, device=device) + sum_weight = th.zeros(length, device=device) + segment = model.segment_length + stride = int((1 - overlap) * segment) + offsets = range(0, length, stride) + scale = stride / model.samplerate + if progress: + offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds') + # We start from a triangle shaped weight, with maximal weight in the middle + # of the segment. Then we normalize and take to the power `transition_power`. + # Large values of transition power will lead to sharper transitions. + weight = th.cat([th.arange(1, segment // 2 + 1), + th.arange(segment - segment // 2, 0, -1)]).to(device) + assert len(weight) == segment + # If the overlap < 50%, this will translate to linear transition when + # transition_power is 1. + weight = (weight / weight.max())**transition_power + for offset in offsets: + chunk = TensorChunk(mix, offset, segment) + if set_progress_bar: + progress_value += 1 + set_progress_bar(0.1, (0.8/len(offsets)*progress_value)) + chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar) + else: + chunk_out = apply_model_v2(model, chunk, shifts=shifts) + chunk_length = chunk_out.shape[-1] + out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out + sum_weight[offset:offset + segment] += weight[:chunk_length] + offset += segment + assert sum_weight.min() > 0 + out /= sum_weight + return out + elif shifts: + max_shift = int(0.5 * model.samplerate) + mix = tensor_chunk(mix) + padded_mix = mix.padded(length + 2 * max_shift) + out = 0 + for _ in range(shifts): + offset = random.randint(0, max_shift) + shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) + + if set_progress_bar: + progress_value += 1 + shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar) + else: + shifted_out = apply_model_v2(model, shifted) + out += shifted_out[..., max_shift - offset:] + out /= shifts + return out + else: + valid_length = model.valid_length(length) + mix = tensor_chunk(mix) + padded_mix = mix.padded(valid_length) + with th.no_grad(): + out = model(padded_mix.unsqueeze(0))[0] + return center_trim(out, length) + + +@contextmanager +def temp_filenames(count, delete=True): + names = [] + try: + for _ in range(count): + names.append(tempfile.NamedTemporaryFile(delete=False).name) + yield names + finally: + if delete: + for name in names: + os.unlink(name) + + +def get_quantizer(model, args, optimizer=None): + quantizer = None + if args.diffq: + quantizer = DiffQuantizer( + model, min_size=args.q_min_size, group_size=8) + if optimizer is not None: + quantizer.setup_optimizer(optimizer) + elif args.qat: + quantizer = UniformQuantizer( + model, bits=args.qat, min_size=args.q_min_size) + return quantizer + + +def load_model(path, strict=False): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + load_from = path + package = th.load(load_from, 'cpu') + + klass = package["klass"] + args = package["args"] + kwargs = package["kwargs"] + + if strict: + model = klass(*args, **kwargs) + else: + sig = inspect.signature(klass) + for key in list(kwargs): + if key not in sig.parameters: + warnings.warn("Dropping inexistant parameter " + key) + del kwargs[key] + model = klass(*args, **kwargs) + + state = package["state"] + training_args = package["training_args"] + quantizer = get_quantizer(model, training_args) + + set_state(model, quantizer, state) + return model + + +def get_state(model, quantizer): + if quantizer is None: + state = {k: p.data.to('cpu') for k, p in model.state_dict().items()} + else: + state = quantizer.get_quantized_state() + buf = io.BytesIO() + th.save(state, buf) + state = {'compressed': zlib.compress(buf.getvalue())} + return state + + +def set_state(model, quantizer, state): + if quantizer is None: + model.load_state_dict(state) + else: + buf = io.BytesIO(zlib.decompress(state["compressed"])) + state = th.load(buf, "cpu") + quantizer.restore_quantized_state(state) + + return state + + +def save_state(state, path): + buf = io.BytesIO() + th.save(state, buf) + sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] + + path = path.parent / (path.stem + "-" + sig + path.suffix) + path.write_bytes(buf.getvalue()) + + +def save_model(model, quantizer, training_args, path): + args, kwargs = model._init_args_kwargs + klass = model.__class__ + + state = get_state(model, quantizer) + + save_to = path + package = { + 'klass': klass, + 'args': args, + 'kwargs': kwargs, + 'state': state, + 'training_args': training_args, + } + th.save(package, save_to) + + +def capture_init(init): + @functools.wraps(init) + def __init__(self, *args, **kwargs): + self._init_args_kwargs = (args, kwargs) + init(self, *args, **kwargs) + + return __init__ + +class DummyPoolExecutor: + class DummyResult: + def __init__(self, func, *args, **kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + + def result(self): + return self.func(*self.args, **self.kwargs) + + def __init__(self, workers=0): + pass + + def submit(self, func, *args, **kwargs): + return DummyPoolExecutor.DummyResult(func, *args, **kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + return diff --git a/gui_data/constants.py b/gui_data/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9aedaaca56a99597052bcbcf22b27192cb608b --- /dev/null +++ b/gui_data/constants.py @@ -0,0 +1,1147 @@ +import platform + +#Platform Details +OPERATING_SYSTEM = platform.system() +SYSTEM_ARCH = platform.platform() +SYSTEM_PROC = platform.processor() +ARM = 'arm' + +#Main Font +MAIN_FONT_NAME = "Century Gothic" + +#Model Types +VR_ARCH_TYPE = 'VR Arc' +MDX_ARCH_TYPE = 'MDX-Net' +DEMUCS_ARCH_TYPE = 'Demucs' +VR_ARCH_PM = 'VR Architecture' +ENSEMBLE_MODE = 'Ensemble Mode' +ENSEMBLE_STEM_CHECK = 'Ensemble Stem' +SECONDARY_MODEL = 'Secondary Model' +DEMUCS_6_STEM_MODEL = 'htdemucs_6s' + +DEMUCS_V3_ARCH_TYPE = 'Demucs v3' +DEMUCS_V4_ARCH_TYPE = 'Demucs v4' +DEMUCS_NEWER_ARCH_TYPES = [DEMUCS_V3_ARCH_TYPE, DEMUCS_V4_ARCH_TYPE] + +DEMUCS_V1 = 'v1' +DEMUCS_V2 = 'v2' +DEMUCS_V3 = 'v3' +DEMUCS_V4 = 'v4' + +DEMUCS_V1_TAG = 'v1 | ' +DEMUCS_V2_TAG = 'v2 | ' +DEMUCS_V3_TAG = 'v3 | ' +DEMUCS_V4_TAG = 'v4 | ' +DEMUCS_NEWER_TAGS = [DEMUCS_V3_TAG, DEMUCS_V4_TAG] + +DEMUCS_VERSION_MAPPER = { + DEMUCS_V1:DEMUCS_V1_TAG, + DEMUCS_V2:DEMUCS_V2_TAG, + DEMUCS_V3:DEMUCS_V3_TAG, + DEMUCS_V4:DEMUCS_V4_TAG} + +#Download Center +DOWNLOAD_FAILED = 'Download Failed' +DOWNLOAD_STOPPED = 'Download Stopped' +DOWNLOAD_COMPLETE = 'Download Complete' +DOWNLOAD_UPDATE_COMPLETE = 'Update Download Complete' +SETTINGS_MENU_EXIT = 'exit' +NO_CONNECTION = 'No Internet Connection' +VIP_SELECTION = 'VIP:' +DEVELOPER_SELECTION = 'VIP:' +NO_NEW_MODELS = 'All Available Models Downloaded' +ENSEMBLE_PARTITION = ': ' +NO_MODEL = 'No Model Selected' +CHOOSE_MODEL = 'Choose Model' +SINGLE_DOWNLOAD = 'Downloading Item 1/1...' +DOWNLOADING_ITEM = 'Downloading Item' +FILE_EXISTS = 'File already exists!' +DOWNLOADING_UPDATE = 'Downloading Update...' +DOWNLOAD_MORE = 'Download More Models' + +#Menu Options + +AUTO_SELECT = 'Auto' + +#LINKS +DOWNLOAD_CHECKS = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json" +MDX_MODEL_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data.json" +VR_MODEL_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/vr_model_data/model_data.json" + +DEMUCS_MODEL_NAME_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/demucs_model_data/model_name_mapper.json" +MDX_MODEL_NAME_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_name_mapper.json" + +DONATE_LINK_BMAC = "https://www.buymeacoffee.com/uvr5" +DONATE_LINK_PATREON = "https://www.patreon.com/uvr" + +#DOWNLOAD REPOS +NORMAL_REPO = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/" +UPDATE_REPO = "https://github.com/TRvlvr/model_repo/releases/download/uvr_update_patches/" + +UPDATE_MAC_ARM_REPO = "https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.5.0/Ultimate_Vocal_Remover_v5_5_MacOS_arm64.dmg" +UPDATE_MAC_X86_64_REPO = "https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.5.0/Ultimate_Vocal_Remover_v5_5_MacOS_x86_64.dmg" +UPDATE_LINUX_REPO = "https://github.com/Anjok07/ultimatevocalremovergui#linux-installation" +UPDATE_REPO = "https://github.com/TRvlvr/model_repo/releases/download/uvr_update_patches/" + +ISSUE_LINK = 'https://github.com/Anjok07/ultimatevocalremovergui/issues/new' +VIP_REPO = b'\xf3\xc2W\x19\x1foI)\xc2\xa9\xcc\xb67(Z\xf5',\ + b'gAAAAABjQAIQ-NpNMMxMedpKHHb7ze_nqB05hw0YhbOy3pFzuzDrfqumn8_qvraxEoUpZC5ZXC0gGvfDxFMqyq9VWbYKlA67SUFI_wZB6QoVyGI581vs7kaGfUqlXHIdDS6tQ_U-BfjbEAK9EU_74-R2zXjz8Xzekw==' +NO_CODE = 'incorrect_code' + +#Extensions + +ONNX = '.onnx' +CKPT = '.ckpt' +YAML = '.yaml' +PTH = '.pth' +TH_EXT = '.th' +JSON = '.json' + +#GUI Buttons + +START_PROCESSING = 'Start Processing' +WAIT_PROCESSING = 'Please wait...' +STOP_PROCESSING = 'Halting process, please wait...' +LOADING_MODELS = 'Loading models...' + +#---Messages and Logs---- + +MISSING_MODEL = 'missing' +MODEL_PRESENT = 'present' + +UNRECOGNIZED_MODEL = 'Unrecognized Model Detected', ' is an unrecognized model.\n\n' + \ + 'Would you like to select the correct parameters before continuing?' + +STOP_PROCESS_CONFIRM = 'Confirmation', 'You are about to stop all active processes.\n\nAre you sure you wish to continue?' +NO_ENSEMBLE_SELECTED = 'No Models Selected', 'Please select ensemble and try again.' +PICKLE_CORRU = 'File Corrupted', 'Unable to load this ensemble.\n\n' + \ + 'Would you like to remove this ensemble from your list?' +DELETE_ENS_ENTRY = 'Confirm Removal', 'Are you sure you want to remove this entry?' + +ALL_STEMS = 'All Stems' +VOCAL_STEM = 'Vocals' +INST_STEM = 'Instrumental' +OTHER_STEM = 'Other' +BASS_STEM = 'Bass' +DRUM_STEM = 'Drums' +GUITAR_STEM = 'Guitar' +PIANO_STEM = 'Piano' +SYNTH_STEM = 'Synthesizer' +STRINGS_STEM = 'Strings' +WOODWINDS_STEM = 'Woodwinds' +BRASS_STEM = 'Brass' +WIND_INST_STEM = 'Wind Inst' +NO_OTHER_STEM = 'No Other' +NO_BASS_STEM = 'No Bass' +NO_DRUM_STEM = 'No Drums' +NO_GUITAR_STEM = 'No Guitar' +NO_PIANO_STEM = 'No Piano' +NO_SYNTH_STEM = 'No Synthesizer' +NO_STRINGS_STEM = 'No Strings' +NO_WOODWINDS_STEM = 'No Woodwinds' +NO_WIND_INST_STEM = 'No Wind Inst' +NO_BRASS_STEM = 'No Brass' +PRIMARY_STEM = 'Primary Stem' +SECONDARY_STEM = 'Secondary Stem' + +#Other Constants +DEMUCS_2_SOURCE = ["instrumental", "vocals"] +DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"] + +DEMUCS_2_SOURCE_MAPPER = { + INST_STEM: 0, + VOCAL_STEM: 1} + +DEMUCS_4_SOURCE_MAPPER = { + BASS_STEM: 0, + DRUM_STEM: 1, + OTHER_STEM: 2, + VOCAL_STEM: 3} + +DEMUCS_6_SOURCE_MAPPER = { + BASS_STEM: 0, + DRUM_STEM: 1, + OTHER_STEM: 2, + VOCAL_STEM: 3, + GUITAR_STEM:4, + PIANO_STEM:5} + +DEMUCS_4_SOURCE_LIST = [BASS_STEM, DRUM_STEM, OTHER_STEM, VOCAL_STEM] +DEMUCS_6_SOURCE_LIST = [BASS_STEM, DRUM_STEM, OTHER_STEM, VOCAL_STEM, GUITAR_STEM, PIANO_STEM] + +DEMUCS_UVR_MODEL = 'UVR_Model' + +CHOOSE_STEM_PAIR = 'Choose Stem Pair' + +STEM_SET_MENU = (VOCAL_STEM, + INST_STEM, + OTHER_STEM, + BASS_STEM, + DRUM_STEM, + GUITAR_STEM, + PIANO_STEM, + SYNTH_STEM, + STRINGS_STEM, + WOODWINDS_STEM, + BRASS_STEM, + WIND_INST_STEM, + NO_OTHER_STEM, + NO_BASS_STEM, + NO_DRUM_STEM, + NO_GUITAR_STEM, + NO_PIANO_STEM, + NO_SYNTH_STEM, + NO_STRINGS_STEM, + NO_WOODWINDS_STEM, + NO_BRASS_STEM, + NO_WIND_INST_STEM) + +STEM_PAIR_MAPPER = { + VOCAL_STEM: INST_STEM, + INST_STEM: VOCAL_STEM, + OTHER_STEM: NO_OTHER_STEM, + BASS_STEM: NO_BASS_STEM, + DRUM_STEM: NO_DRUM_STEM, + GUITAR_STEM: NO_GUITAR_STEM, + PIANO_STEM: NO_PIANO_STEM, + SYNTH_STEM: NO_SYNTH_STEM, + STRINGS_STEM: NO_STRINGS_STEM, + WOODWINDS_STEM: NO_WOODWINDS_STEM, + BRASS_STEM: NO_BRASS_STEM, + WIND_INST_STEM: NO_WIND_INST_STEM, + NO_OTHER_STEM: OTHER_STEM, + NO_BASS_STEM: BASS_STEM, + NO_DRUM_STEM: DRUM_STEM, + NO_GUITAR_STEM: GUITAR_STEM, + NO_PIANO_STEM: PIANO_STEM, + NO_SYNTH_STEM: SYNTH_STEM, + NO_STRINGS_STEM: STRINGS_STEM, + NO_WOODWINDS_STEM: WOODWINDS_STEM, + NO_BRASS_STEM: BRASS_STEM, + NO_WIND_INST_STEM: WIND_INST_STEM, + PRIMARY_STEM: SECONDARY_STEM} + +NON_ACCOM_STEMS = ( + VOCAL_STEM, + OTHER_STEM, + BASS_STEM, + DRUM_STEM, + GUITAR_STEM, + PIANO_STEM, + SYNTH_STEM, + STRINGS_STEM, + WOODWINDS_STEM, + BRASS_STEM, + WIND_INST_STEM) + +MDX_NET_FREQ_CUT = [VOCAL_STEM, INST_STEM] + +DEMUCS_4_STEM_OPTIONS = (ALL_STEMS, VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM) +DEMUCS_6_STEM_OPTIONS = (ALL_STEMS, VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM) +DEMUCS_2_STEM_OPTIONS = (VOCAL_STEM, INST_STEM) +DEMUCS_4_STEM_CHECK = (OTHER_STEM, BASS_STEM, DRUM_STEM) + +#Menu Dropdowns + +VOCAL_PAIR = f'{VOCAL_STEM}/{INST_STEM}' +INST_PAIR = f'{INST_STEM}/{VOCAL_STEM}' +OTHER_PAIR = f'{OTHER_STEM}/{NO_OTHER_STEM}' +DRUM_PAIR = f'{DRUM_STEM}/{NO_DRUM_STEM}' +BASS_PAIR = f'{BASS_STEM}/{NO_BASS_STEM}' +FOUR_STEM_ENSEMBLE = '4 Stem Ensemble' + +ENSEMBLE_MAIN_STEM = (CHOOSE_STEM_PAIR, VOCAL_PAIR, OTHER_PAIR, DRUM_PAIR, BASS_PAIR, FOUR_STEM_ENSEMBLE) + +MIN_SPEC = 'Min Spec' +MAX_SPEC = 'Max Spec' +AUDIO_AVERAGE = 'Average' + +MAX_MIN = f'{MAX_SPEC}/{MIN_SPEC}' +MAX_MAX = f'{MAX_SPEC}/{MAX_SPEC}' +MAX_AVE = f'{MAX_SPEC}/{AUDIO_AVERAGE}' +MIN_MAX = f'{MIN_SPEC}/{MAX_SPEC}' +MIN_MIX = f'{MIN_SPEC}/{MIN_SPEC}' +MIN_AVE = f'{MIN_SPEC}/{AUDIO_AVERAGE}' +AVE_MAX = f'{AUDIO_AVERAGE}/{MAX_SPEC}' +AVE_MIN = f'{AUDIO_AVERAGE}/{MIN_SPEC}' +AVE_AVE = f'{AUDIO_AVERAGE}/{AUDIO_AVERAGE}' + +ENSEMBLE_TYPE = (MAX_MIN, MAX_MAX, MAX_AVE, MIN_MAX, MIN_MIX, MIN_AVE, AVE_MAX, AVE_MIN, AVE_AVE) +ENSEMBLE_TYPE_4_STEM = (MAX_SPEC, MIN_SPEC, AUDIO_AVERAGE) + +BATCH_MODE = 'Batch Mode' +BETA_VERSION = 'BETA' +DEF_OPT = 'Default' + +CHUNKS = (AUTO_SELECT, '1', '5', '10', '15', '20', + '25', '30', '35', '40', '45', '50', + '55', '60', '65', '70', '75', '80', + '85', '90', '95', 'Full') + +BATCH_SIZE = (DEF_OPT, '2', '3', '4', '5', + '6', '7', '8', '9', '10') + +VOL_COMPENSATION = (AUTO_SELECT, '1.035', '1.08') + +MARGIN_SIZE = ('44100', '22050', '11025') + +AUDIO_TOOLS = 'Audio Tools' + +MANUAL_ENSEMBLE = 'Manual Ensemble' +TIME_STRETCH = 'Time Stretch' +CHANGE_PITCH = 'Change Pitch' +ALIGN_INPUTS = 'Align Inputs' + +if OPERATING_SYSTEM == 'Windows' or OPERATING_SYSTEM == 'Darwin': + AUDIO_TOOL_OPTIONS = (MANUAL_ENSEMBLE, TIME_STRETCH, CHANGE_PITCH, ALIGN_INPUTS) +else: + AUDIO_TOOL_OPTIONS = (MANUAL_ENSEMBLE, ALIGN_INPUTS) + +MANUAL_ENSEMBLE_OPTIONS = (MIN_SPEC, MAX_SPEC, AUDIO_AVERAGE) + +PROCESS_METHODS = (VR_ARCH_PM, MDX_ARCH_TYPE, DEMUCS_ARCH_TYPE, ENSEMBLE_MODE, AUDIO_TOOLS) + +DEMUCS_SEGMENTS = ('Default', '1', '5', '10', '15', '20', + '25', '30', '35', '40', '45', '50', + '55', '60', '65', '70', '75', '80', + '85', '90', '95', '100') + +DEMUCS_SHIFTS = (0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, + 18, 19, 20) + +DEMUCS_OVERLAP = (0.25, 0.50, 0.75, 0.99) + +VR_AGGRESSION = (1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, + 18, 19, 20) + +VR_WINDOW = ('320', '512','1024') +VR_CROP = ('256', '512', '1024') +POST_PROCESSES_THREASHOLD_VALUES = ('0.1', '0.2', '0.3') + +MDX_POP_PRO = ('MDX-NET_Noise_Profile_14_kHz', 'MDX-NET_Noise_Profile_17_kHz', 'MDX-NET_Noise_Profile_Full_Band') +MDX_POP_STEMS = ('Vocals', 'Instrumental', 'Other', 'Drums', 'Bass') +MDX_POP_NFFT = ('4096', '5120', '6144', '7680', '8192', '16384') +MDX_POP_DIMF = ('2048', '3072', '4096') + +SAVE_ENSEMBLE = 'Save Ensemble' +CLEAR_ENSEMBLE = 'Clear Selection(s)' +MENU_SEPARATOR = 35*'•' +CHOOSE_ENSEMBLE_OPTION = 'Choose Option' + +INVALID_ENTRY = 'Invalid Input, Please Try Again' +ENSEMBLE_INPUT_RULE = '1. Only letters, numbers, spaces, and dashes allowed.\n2. No dashes or spaces at the start or end of input.' + +ENSEMBLE_OPTIONS = (SAVE_ENSEMBLE, CLEAR_ENSEMBLE) +ENSEMBLE_CHECK = 'ensemble check' + +SELECT_SAVED_ENSEMBLE = 'Select Saved Ensemble' +SELECT_SAVED_SETTING = 'Select Saved Setting' +ENSEMBLE_OPTION = "Ensemble Customization Options" +MDX_OPTION = "Advanced MDX-Net Options" +DEMUCS_OPTION = "Advanced Demucs Options" +VR_OPTION = "Advanced VR Options" +HELP_OPTION = "Open Information Guide" +ERROR_OPTION = "Open Error Log" +VERIFY_BEGIN = 'Verifying file ' +SAMPLE_BEGIN = 'Creating Sample ' +MODEL_MISSING_CHECK = 'Model Missing:' + +# Audio Player + +PLAYING_SONG = ": Playing" +PAUSE_SONG = ": Paused" +STOP_SONG = ": Stopped" + +SELECTED_VER = 'Selected' +DETECTED_VER = 'Detected' + +SAMPLE_MODE_CHECKBOX = lambda v:f'Sample Mode ({v}s)' +REMOVED_FILES = lambda r, e:f'Audio Input Verification Report:\n\nRemoved Files:\n\n{r}\n\nError Details:\n\n{e}' +ADVANCED_SETTINGS = (ENSEMBLE_OPTION, MDX_OPTION, DEMUCS_OPTION, VR_OPTION, HELP_OPTION, ERROR_OPTION) + +WAV = 'WAV' +FLAC = 'FLAC' +MP3 = 'MP3' + +MP3_BIT_RATES = ('96k', '128k', '160k', '224k', '256k', '320k') +WAV_TYPE = ('PCM_U8', 'PCM_16', 'PCM_24', 'PCM_32', '32-bit Float', '64-bit Float') + +SELECT_SAVED_SET = 'Choose Option' +SAVE_SETTINGS = 'Save Current Settings' +RESET_TO_DEFAULT = 'Reset to Default' +RESET_FULL_TO_DEFAULT = 'Reset to Default' +RESET_PM_TO_DEFAULT = 'Reset All Application Settings to Default' + +SAVE_SET_OPTIONS = (SAVE_SETTINGS, RESET_TO_DEFAULT) + +TIME_PITCH = ('1.0', '2.0', '3.0', '4.0') +TIME_TEXT = '_time_stretched' +PITCH_TEXT = '_pitch_shifted' + +#RegEx Input Validation + +REG_PITCH = r'^[-+]?(1[0]|[0-9]([.][0-9]*)?)$' +REG_TIME = r'^[+]?(1[0]|[0-9]([.][0-9]*)?)$' +REG_COMPENSATION = r'\b^(1[0]|[0-9]([.][0-9]*)?|Auto|None)$\b' +REG_THES_POSTPORCESS = r'\b^([0]([.][0-9]{0,6})?)$\b' +REG_CHUNKS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Auto|Full)$\b' +REG_CHUNKS_DEMUCS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Auto|Full)$\b' +REG_MARGIN = r'\b^[0-9]*$\b' +REG_SEGMENTS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Default)$\b' +REG_SAVE_INPUT = r'\b^([a-zA-Z0-9 -]{0,25})$\b' +REG_AGGRESSION = r'^[-+]?[0-9]\d*?$' +REG_WINDOW = r'\b^[0-9]{0,4}$\b' +REG_SHIFTS = r'\b^[0-9]*$\b' +REG_BATCHES = r'\b^([0-9]*?|Default)$\b' +REG_OVERLAP = r'\b^([0]([.][0-9]{0,6})?|None)$\b' + +# Sub Menu + +VR_ARCH_SETTING_LOAD = 'Load for VR Arch' +MDX_SETTING_LOAD = 'Load for MDX-Net' +DEMUCS_SETTING_LOAD = 'Load for Demucs' +ALL_ARCH_SETTING_LOAD = 'Load for Full Application' + +# Mappers + +DEFAULT_DATA = { + + 'chosen_process_method': MDX_ARCH_TYPE, + 'vr_model': CHOOSE_MODEL, + 'aggression_setting': 10, + 'window_size': 512, + 'batch_size': 4, + 'crop_size': 256, + 'is_tta': False, + 'is_output_image': False, + 'is_post_process': False, + 'is_high_end_process': False, + 'post_process_threshold': 0.2, + 'vr_voc_inst_secondary_model': NO_MODEL, + 'vr_other_secondary_model': NO_MODEL, + 'vr_bass_secondary_model': NO_MODEL, + 'vr_drums_secondary_model': NO_MODEL, + 'vr_is_secondary_model_activate': False, + 'vr_voc_inst_secondary_model_scale': 0.9, + 'vr_other_secondary_model_scale': 0.7, + 'vr_bass_secondary_model_scale': 0.5, + 'vr_drums_secondary_model_scale': 0.5, + 'demucs_model': CHOOSE_MODEL, + 'demucs_stems': ALL_STEMS, + 'segment': DEMUCS_SEGMENTS[0], + 'overlap': DEMUCS_OVERLAP[0], + 'shifts': 2, + 'chunks_demucs': CHUNKS[0], + 'margin_demucs': 44100, + 'is_chunk_demucs': False, + 'is_chunk_mdxnet': False, + 'is_primary_stem_only_Demucs': False, + 'is_secondary_stem_only_Demucs': False, + 'is_split_mode': True, + 'is_demucs_combine_stems': True, + 'demucs_voc_inst_secondary_model': NO_MODEL, + 'demucs_other_secondary_model': NO_MODEL, + 'demucs_bass_secondary_model': NO_MODEL, + 'demucs_drums_secondary_model': NO_MODEL, + 'demucs_is_secondary_model_activate': False, + 'demucs_voc_inst_secondary_model_scale': 0.9, + 'demucs_other_secondary_model_scale': 0.7, + 'demucs_bass_secondary_model_scale': 0.5, + 'demucs_drums_secondary_model_scale': 0.5, + 'demucs_stems': ALL_STEMS, + 'demucs_pre_proc_model': NO_MODEL, + 'is_demucs_pre_proc_model_activate': False, + 'is_demucs_pre_proc_model_inst_mix': False, + 'mdx_net_model': CHOOSE_MODEL, + 'chunks': CHUNKS[0], + 'margin': 44100, + 'compensate': AUTO_SELECT, + 'is_denoise': False, + 'is_invert_spec': False, + 'is_mixer_mode': False, + 'mdx_batch_size': DEF_OPT, + 'mdx_voc_inst_secondary_model': NO_MODEL, + 'mdx_other_secondary_model': NO_MODEL, + 'mdx_bass_secondary_model': NO_MODEL, + 'mdx_drums_secondary_model': NO_MODEL, + 'mdx_is_secondary_model_activate': False, + 'mdx_voc_inst_secondary_model_scale': 0.9, + 'mdx_other_secondary_model_scale': 0.7, + 'mdx_bass_secondary_model_scale': 0.5, + 'mdx_drums_secondary_model_scale': 0.5, + 'is_save_all_outputs_ensemble': True, + 'is_append_ensemble_name': False, + 'chosen_audio_tool': AUDIO_TOOL_OPTIONS[0], + 'choose_algorithm': MANUAL_ENSEMBLE_OPTIONS[0], + 'time_stretch_rate': 2.0, + 'pitch_rate': 2.0, + 'is_gpu_conversion': False, + 'is_primary_stem_only': False, + 'is_secondary_stem_only': False, + 'is_testing_audio': False, + 'is_add_model_name': False, + 'is_accept_any_input': False, + 'is_task_complete': False, + 'is_normalization': False, + 'is_create_model_folder': False, + 'mp3_bit_set': '320k', + 'save_format': WAV, + 'wav_type_set': 'PCM_16', + 'user_code': '', + 'export_path': '', + 'input_paths': [], + 'lastDir': None, + 'export_path': '', + 'model_hash_table': None, + 'help_hints_var': False, + 'model_sample_mode': False, + 'model_sample_mode_duration': 30 +} + +SETTING_CHECK = ('vr_model', + 'aggression_setting', + 'window_size', + 'batch_size', + 'crop_size', + 'is_tta', + 'is_output_image', + 'is_post_process', + 'is_high_end_process', + 'post_process_threshold', + 'vr_voc_inst_secondary_model', + 'vr_other_secondary_model', + 'vr_bass_secondary_model', + 'vr_drums_secondary_model', + 'vr_is_secondary_model_activate', + 'vr_voc_inst_secondary_model_scale', + 'vr_other_secondary_model_scale', + 'vr_bass_secondary_model_scale', + 'vr_drums_secondary_model_scale', + 'demucs_model', + 'segment', + 'overlap', + 'shifts', + 'chunks_demucs', + 'margin_demucs', + 'is_chunk_demucs', + 'is_primary_stem_only_Demucs', + 'is_secondary_stem_only_Demucs', + 'is_split_mode', + 'is_demucs_combine_stems', + 'demucs_voc_inst_secondary_model', + 'demucs_other_secondary_model', + 'demucs_bass_secondary_model', + 'demucs_drums_secondary_model', + 'demucs_is_secondary_model_activate', + 'demucs_voc_inst_secondary_model_scale', + 'demucs_other_secondary_model_scale', + 'demucs_bass_secondary_model_scale', + 'demucs_drums_secondary_model_scale', + 'demucs_stems', + 'mdx_net_model', + 'chunks', + 'margin', + 'compensate', + 'is_denoise', + 'is_invert_spec', + 'mdx_batch_size', + 'mdx_voc_inst_secondary_model', + 'mdx_other_secondary_model', + 'mdx_bass_secondary_model', + 'mdx_drums_secondary_model', + 'mdx_is_secondary_model_activate', + 'mdx_voc_inst_secondary_model_scale', + 'mdx_other_secondary_model_scale', + 'mdx_bass_secondary_model_scale', + 'mdx_drums_secondary_model_scale', + 'is_save_all_outputs_ensemble', + 'is_append_ensemble_name', + 'chosen_audio_tool', + 'choose_algorithm', + 'time_stretch_rate', + 'pitch_rate', + 'is_primary_stem_only', + 'is_secondary_stem_only', + 'is_testing_audio', + 'is_add_model_name', + "is_accept_any_input", + 'is_task_complete', + 'is_create_model_folder', + 'mp3_bit_set', + 'save_format', + 'wav_type_set', + 'user_code', + 'is_gpu_conversion', + 'is_normalization', + 'help_hints_var', + 'model_sample_mode', + 'model_sample_mode_duration') + +# Message Box Text + +INVALID_INPUT = 'Invalid Input', 'The input is invalid.\n\nPlease verify the input still exists or is valid and try again.' +INVALID_EXPORT = 'Invalid Export Directory', 'You have selected an invalid export directory.\n\nPlease make sure the selected directory still exists.' +INVALID_ENSEMBLE = 'Not Enough Models', 'You must select 2 or more models to run ensemble.' +INVALID_MODEL = 'No Model Chosen', 'You must select an model to continue.' +MISSING_MODEL = 'Model Missing', 'The selected model is missing or not valid.' +ERROR_OCCURED = 'Error Occured', '\n\nWould you like to open the error log for more details?\n' + +# GUI Text Constants + +BACK_TO_MAIN_MENU = 'Back to Main Menu' + +# Help Hint Text + +INTERNAL_MODEL_ATT = 'Internal model attribute. \n\n ***Do not change this setting if you are unsure!***' +STOP_HELP = 'Halts any running processes. \n A pop-up window will ask the user to confirm the action.' +SETTINGS_HELP = 'Opens the main settings guide. This window includes the \"Download Center\"' +COMMAND_TEXT_HELP = 'Provides information on the progress of the current process.' +SAVE_CURRENT_SETTINGS_HELP = 'Allows the user to open any saved settings or save the current application settings.' +CHUNKS_HELP = ('For MDX-Net, all values use the same amount of resources. Using chunks is no longer recommended.\n\n' + \ + '• This option is now only for output quality.\n' + \ + '• Some tracks may fare better depending on the value.\n' + \ + '• Some tracks may fare worse depending on the value.\n' + \ + '• Larger chunk sizes use will take less time to process.\n' +\ + '• Smaller chunk sizes use will take more time to process.\n') +CHUNKS_DEMUCS_HELP = ('This option allows the user to reduce (or increase) RAM or V-RAM usage.\n\n' + \ + '• Smaller chunk sizes use less RAM or V-RAM but can also increase processing times.\n' + \ + '• Larger chunk sizes use more RAM or V-RAM but can also reduce processing times.\n' + \ + '• Selecting \"Auto\" calculates an appropriate chuck size based on how much RAM or V-RAM your system has.\n' + \ + '• Selecting \"Full\" will process the track as one whole chunk. (not recommended)\n' + \ + '• The default selection is \"Auto\".') +MARGIN_HELP = 'Selects the frequency margins to slice the chunks from.\n\n• The recommended margin size is 44100.\n• Other values can give unpredictable results.' +AGGRESSION_SETTING_HELP = ('This option allows you to set how strong the primary stem extraction will be.\n\n' + \ + '• The range is 0-100.\n' + \ + '• Higher values perform deeper extractions.\n' + \ + '• The default is 10 for instrumental & vocal models.\n' + \ + '• Values over 10 can result in muddy-sounding instrumentals for the non-vocal models') +WINDOW_SIZE_HELP = ('The smaller your window size, the better your conversions will be. \nHowever, a smaller window means longer conversion times and heavier resource usage.\n\n' + \ + 'Breakdown of the selectable window size values:\n' + \ + '• 1024 - Low conversion quality, shortest conversion time, low resource usage.\n' + \ + '• 512 - Average conversion quality, average conversion time, normal resource usage.\n' + \ + '• 320 - Better conversion quality.') +DEMUCS_STEMS_HELP = ('Here, you can choose which stem to extract using the selected model.\n\n' +\ + 'Stem Selections:\n\n' +\ + '• All Stems - Saves all of the stems the model is able to extract.\n' +\ + '• Vocals - Pulls vocal stem only.\n' +\ + '• Other - Pulls other stem only.\n' +\ + '• Bass - Pulls bass stem only.\n' +\ + '• Drums - Pulls drum stem only.\n') +SEGMENT_HELP = ('This option allows the user to reduce (or increase) RAM or V-RAM usage.\n\n' + \ + '• Smaller segment sizes use less RAM or V-RAM but can also increase processing times.\n' + \ + '• Larger segment sizes use more RAM or V-RAM but can also reduce processing times.\n' + \ + '• Selecting \"Default\" uses the recommended segment size.\n' + \ + '• It is recommended that you not use segments with \"Chunking\".') +ENSEMBLE_MAIN_STEM_HELP = 'Allows the user to select the type of stems they wish to ensemble.\n\nOptions:\n\n' +\ + f'• {VOCAL_PAIR} - The primary stem will be the vocals and the secondary stem will be the the instrumental\n' +\ + f'• {OTHER_PAIR} - The primary stem will be other and the secondary stem will be no other (the mixture without the \'other\' stem)\n' +\ + f'• {BASS_PAIR} - The primary stem will be bass and the secondary stem will be no bass (the mixture without the \'bass\' stem)\n' +\ + f'• {DRUM_PAIR} - The primary stem will be drums and the secondary stem will be no drums (the mixture without the \'drums\' stem)\n' +\ + f'• {FOUR_STEM_ENSEMBLE} - This option will gather all the 4 stem Demucs models and ensemble all of the outputs.\n' +ENSEMBLE_TYPE_HELP = 'Allows the user to select the ensemble algorithm to be used to generate the final output.\n\nExample & Other Note:\n\n' +\ + f'• {MAX_MIN} - If this option is chosen, the primary stem outputs will be processed through \nthe \'Max Spec\' algorithm, and the secondary stem will be processed through the \'Min Spec\' algorithm.\n' +\ + f'• Only a single algorithm will be shown when the \'4 Stem Ensemble\' option is chosen.\n\nAlgorithm Details:\n\n' +\ + f'• {MAX_SPEC} - This algorithm combines the final results and generates the highest possible output from them.\nFor example, if this algorithm were processing vocal stems, you would get the fullest possible \n' +\ + 'result making the ensembled vocal stem sound cleaner. However, it might result in more unwanted artifacts.\n' +\ + f'• {MIN_SPEC} - This algorithm combines the results and generates the lowest possible output from them.\nFor example, if this algorithm were processing instrumental stems, you would get the cleanest possible result \n' +\ + 'result, eliminating more unwanted artifacts. However, the result might also sound \'muddy\' and lack a fuller sound.\n' +\ + f'• {AUDIO_AVERAGE} - This algorithm simply combines the results and averages all of them together. \n' +ENSEMBLE_LISTBOX_HELP = 'List of the all the models available for the main stem pair selected.' +IS_GPU_CONVERSION_HELP = ('When checked, the application will attempt to use your GPU (if you have one).\n' +\ + 'If you do not have a GPU but have this checked, the application will default to your CPU.\n\n' +\ + 'Note: CPU conversions are much slower than those processed through the GPU.') +SAVE_STEM_ONLY_HELP = 'Allows the user to save only the selected stem.' +IS_NORMALIZATION_HELP = 'Normalizes output to prevent clipping.' +CROP_SIZE_HELP = '**Only compatible with select models only!**\n\n Setting should match training crop-size value. Leave as is if unsure.' +IS_TTA_HELP = ('This option performs Test-Time-Augmentation to improve the separation quality.\n\n' +\ + 'Note: Having this selected will increase the time it takes to complete a conversion') +IS_POST_PROCESS_HELP = ('This option can potentially identify leftover instrumental artifacts within the vocal outputs. \nThis option may improve the separation of some songs.\n\n' +\ + 'Note: Selecting this option can adversely affect the conversion process, depending on the track. Because of this, it is only recommended as a last resort.') +IS_HIGH_END_PROCESS_HELP = 'The application will mirror the missing frequency range of the output.' +SHIFTS_HELP = ('Performs multiple predictions with random shifts of the input and averages them.\n\n' +\ + '• The higher number of shifts, the longer the prediction will take. \n- Not recommended unless you have a GPU.') +OVERLAP_HELP = 'This option controls the amount of overlap between prediction windows (for Demucs one window is 10 seconds)' +IS_CHUNK_DEMUCS_HELP = '• Enables \"Chunks\".\n• We recommend you not enable this option with \"Split Mode\" enabled or with the Demucs v4 Models.' +IS_CHUNK_MDX_NET_HELP = '• Enables \"Chunks\".\n• Using this option for MDX-Net no longer effects RAM usage.\n• Having this enabled will effect output quality, for better or worse depending on the set value.' +IS_SPLIT_MODE_HELP = ('• Enables \"Segments\". \n• We recommend you not enable this option with \"Enable Chunks\".\n' +\ + '• Deselecting this option is only recommended for those with powerful PCs or if using \"Chunk\" mode instead.') +IS_DEMUCS_COMBINE_STEMS_HELP = 'The application will create the secondary stem by combining the remaining stems \ninstead of inverting the primary stem with the mixture.' +COMPENSATE_HELP = 'Compensates the audio of the primary stems to allow for a better secondary stem.' +IS_DENOISE_HELP = '• This option removes a majority of the noise generated by the MDX-Net models.\n• The conversion will take nearly twice as long with this enabled.' +CLEAR_CACHE_HELP = 'Clears any user selected model settings for previously unrecognized models.' +IS_SAVE_ALL_OUTPUTS_ENSEMBLE_HELP = 'Enabling this option will keep all indivudual outputs generated by an ensemble.' +IS_APPEND_ENSEMBLE_NAME_HELP = 'The application will append the ensemble name to the final output \nwhen this option is enabled.' +DONATE_HELP = 'Takes the user to an external web-site to donate to this project!' +IS_INVERT_SPEC_HELP = '• This option may produce a better secondary stem.\n• Inverts primary stem with mixture using spectragrams instead of wavforms.\n• This inversion method is slightly slower.' +IS_MIXER_MODE_HELP = '• This option may improve separations for outputs from 4-stem models.\n• Might produce more noise.\n• This option might slow down separation time.' +IS_TESTING_AUDIO_HELP = 'Appends a unique 10 digit number to output files so the user \ncan compare results with different settings.' +IS_MODEL_TESTING_AUDIO_HELP = 'Appends the model name to output files so the user \ncan compare results with different settings.' +IS_ACCEPT_ANY_INPUT_HELP = 'The application will accept any input when enabled, even if it does not have an audio format extension.\n\nThis is for experimental purposes, and having it enabled is not recommended.' +IS_TASK_COMPLETE_HELP = 'When enabled, chimes will be heard when a process completes or fails.' +IS_CREATE_MODEL_FOLDER_HELP = 'Two new directories will be generated for the outputs in \nthe export directory after each conversion.\n\n' +\ + '• First directory - Named after the model.\n' +\ + '• Second directory - Named after the track.\n\n' +\ + '• Example: \n\n' +\ + '─ Export Directory\n' +\ + ' └── First Directory\n' +\ + ' └── Second Directory\n' +\ + ' └── Output File(s)' +DELETE_YOUR_SETTINGS_HELP = 'This menu contains your saved settings. You will be asked to \nconfirm if you wish to delete the selected setting.' +SET_STEM_NAME_HELP = 'Choose the primary stem for the selected model.' +MDX_DIM_T_SET_HELP = INTERNAL_MODEL_ATT +MDX_DIM_F_SET_HELP = INTERNAL_MODEL_ATT +MDX_N_FFT_SCALE_SET_HELP = 'Set the N_FFT size the model was trained with.' +POPUP_COMPENSATE_HELP = f'Choose the appropriate voluem compensattion for the selected model\n\nReminder: {COMPENSATE_HELP}' +VR_MODEL_PARAM_HELP = 'Choose the parameters needed to run the selected model.' +CHOSEN_ENSEMBLE_HELP = 'Select saved enselble or save current ensemble.\n\nDefault Selections:\n\n• Save the current ensemble.\n• Clears all current model selections.' +CHOSEN_PROCESS_METHOD_HELP = 'Here, you choose between different Al networks and algorithms to process your track.\n\n' +\ + 'There are five options:\n\n' +\ + '• VR Architecture - These models use magnitude spectrograms for Source Separation.\n' +\ + '• MDX-Net - These models use Hybrid Spectrogram/Waveform for Source Separation.\n' +\ + '• Demucs v3 - These models use Hybrid Spectrogram/Waveform for Source Separation.\n' +\ + '• Ensemble Mode - Here, you can get the best results from multiple models and networks.\n' +\ + '• Audio Tools - These are additional tools for added convenience.' +INPUT_FOLDER_ENTRY_HELP = 'Select Input:\n\nHere is where you select the audio files(s) you wish to process.' +INPUT_FOLDER_ENTRY_HELP_2 = 'Input Option Menu:\n\nClick here to access the input option menu.' +OUTPUT_FOLDER_ENTRY_HELP = 'Select Output:\n\nHere is where you select the directory where your processed files are to be saved.' +INPUT_FOLDER_BUTTON_HELP = 'Open Input Folder Button: \n\nOpens the directory containing the selected input audio file(s).' +OUTPUT_FOLDER_BUTTON_HELP = 'Open Output Folder Button: \n\nOpens the selected output folder.' +CHOOSE_MODEL_HELP = 'Each process method comes with its own set of options and models.\n\nHere is where you choose the model associated with the selected process method.' +FORMAT_SETTING_HELP = 'Save outputs as ' +SECONDARY_MODEL_ACTIVATE_HELP = 'When enabled, the application will run an additional inference with the selected model(s) above.' +SECONDARY_MODEL_HELP = 'Choose the secondary model associated with this stem you wish to run with the current process method.' +SECONDARY_MODEL_SCALE_HELP = 'The scale determines how the final audio outputs will be averaged between the primary and secondary models.\n\nFor example:\n\n' +\ + '• 10% - 10 percent of the main model result will be factored into the final result.\n' +\ + '• 50% - The results from the main and secondary models will be averaged evenly.\n' +\ + '• 90% - 90 percent of the main model result will be factored into the final result.' +PRE_PROC_MODEL_ACTIVATE_HELP = 'The application will run an inference with the selected model above, pulling only the instrumental stem when enabled. \nFrom there, all of the non-vocal stems will be pulled from the generated instrumental.\n\nNotes:\n\n' +\ + '• This option can significantly reduce vocal bleed within the non-vocal stems.\n' +\ + '• It is only available in Demucs.\n' +\ + '• It is only compatible with non-vocal and non-instrumental stem outputs.\n' +\ + '• This will increase thetotal processing time.\n' +\ + '• Only VR and MDX-Net Vocal or Instrumental models are selectable above.' + +AUDIO_TOOLS_HELP = 'Here, you choose between different audio tools to process your track.\n\n' +\ + '• Manual Ensemble - You must have 2 or more files selected as your inputs. Allows the user to run their tracks through \nthe same algorithms used in Ensemble Mode.\n' +\ + '• Align Inputs - You must have exactly 2 files selected as your inputs. The second input will be aligned with the first input.\n' +\ + '• Time Stretch - The user can speed up or slow down the selected inputs.\n' +\ + '• Change Pitch - The user can change the pitch for the selected inputs.\n' +PRE_PROC_MODEL_INST_MIX_HELP = 'When enabled, the application will generate a third output without the selected stem and vocals.' +MODEL_SAMPLE_MODE_HELP = 'Allows the user to process only part of a track to sample settings or a model without \nrunning a full conversion.\n\nNotes:\n\n' +\ + '• The number in the parentheses is the current number of seconds the generated sample will be.\n' +\ + '• You can choose the number of seconds to extract from the track in the \"Additional Settings\" menu.' + +POST_PROCESS_THREASHOLD_HELP = 'Allows the user to control the intensity of the Post_process option.\n\nNotes:\n\n' +\ + '• Higher values potentially remove more artifacts. However, bleed might increase.\n' +\ + '• Lower values limit artifact removal.' + +BATCH_SIZE_HELP = 'Specify the number of batches to be processed at a time.\n\nNotes:\n\n' +\ + '• Higher values mean more RAM usage but slightly faster processing times.\n' +\ + '• Lower values mean less RAM usage but slightly longer processing times.\n' +\ + '• Batch size value has no effect on output quality.' + +# Warning Messages + +STORAGE_ERROR = 'Insufficient Storage', 'There is not enough storage on main drive to continue. Your main drive must have at least 3 GB\'s of storage in order for this application function properly. \n\nPlease ensure your main drive has at least 3 GB\'s of storage and try again.\n\n' +STORAGE_WARNING = 'Available Storage Low', 'Your main drive is running low on storage. Your main drive must have at least 3 GB\'s of storage in order for this application function properly.\n\n' +CONFIRM_WARNING = '\nAre you sure you wish to continue?' +PROCESS_FAILED = 'Process failed, please see error log\n' +EXIT_PROCESS_ERROR = 'Active Process', 'Please stop the active process or wait for it to complete before you exit.' +EXIT_HALTED_PROCESS_ERROR = 'Halting Process', 'Please wait for the application to finish halting the process before exiting.' +EXIT_DOWNLOAD_ERROR = 'Active Download', 'Please stop the download or wait for it to complete before you exit.' +SET_TO_DEFAULT_PROCESS_ERROR = 'Active Process', 'You cannot reset all of the application settings during an active process.' +SET_TO_ANY_PROCESS_ERROR = 'Active Process', 'You cannot reset the application settings during an active process.' +RESET_ALL_TO_DEFAULT_WARNING = 'Reset Settings Confirmation', 'All application settings will be set to factory default.\n\nAre you sure you wish to continue?' +AUDIO_VERIFICATION_CHECK = lambda i, e:f'++++++++++++++++++++++++++++++++++++++++++++++++++++\n\nBroken File Removed: \n\n{i}\n\nError Details:\n\n{e}\n++++++++++++++++++++++++++++++++++++++++++++++++++++' +INVALID_ONNX_MODEL_ERROR = 'Invalid Model', 'The file selected is not a valid MDX-Net model. Please see the error log for more information.' + + +# Separation Text + +LOADING_MODEL = 'Loading model...' +INFERENCE_STEP_1 = 'Running inference...' +INFERENCE_STEP_1_SEC = 'Running inference (secondary model)...' +INFERENCE_STEP_1_4_STEM = lambda stem:f'Running inference (secondary model for {stem})...' +INFERENCE_STEP_1_PRE = 'Running inference (pre-process model)...' +INFERENCE_STEP_2_PRE = lambda pm, m:f'Loading pre-process model ({pm}: {m})...' +INFERENCE_STEP_2_SEC = lambda pm, m:f'Loading secondary model ({pm}: {m})...' +INFERENCE_STEP_2_SEC_CACHED_MODOEL = lambda pm, m:f'Secondary model ({pm}: {m}) cache loaded.\n' +INFERENCE_STEP_2_PRE_CACHED_MODOEL = lambda pm, m:f'Pre-process model ({pm}: {m}) cache loaded.\n' +INFERENCE_STEP_2_SEC_CACHED = 'Loading cached secondary model source(s)... Done!\n' +INFERENCE_STEP_2_PRIMARY_CACHED = 'Model cache loaded.\n' +INFERENCE_STEP_2 = 'Inference complete.' +SAVING_STEM = 'Saving ', ' stem...' +SAVING_ALL_STEMS = 'Saving all stems...' +ENSEMBLING_OUTPUTS = 'Ensembling outputs...' +DONE = ' Done!\n' +ENSEMBLES_SAVED = 'Ensembled outputs saved!\n\n' +NEW_LINES = "\n\n" +NEW_LINE = "\n" +NO_LINE = '' + +# Widget Placements + +MAIN_ROW_Y = -15, -17 +MAIN_ROW_X = -4, 21 +MAIN_ROW_WIDTH = -53 +MAIN_ROW_2_Y = -15, -17 +MAIN_ROW_2_X = -28, 1 +CHECK_BOX_Y = 0 +CHECK_BOX_X = 20 +CHECK_BOX_WIDTH = -50 +CHECK_BOX_HEIGHT = 2 +LEFT_ROW_WIDTH = -10 +LABEL_HEIGHT = -5 +OPTION_HEIGHT = 7 +LOW_MENU_Y = 18, 16 +FFMPEG_EXT = (".aac", ".aiff", ".alac" ,".flac", ".FLAC", ".mov", ".mp4", ".MP4", + ".m4a", ".M4A", ".mp2", ".mp3", "MP3", ".mpc", ".mpc8", + ".mpeg", ".ogg", ".OGG", ".tta", ".wav", ".wave", ".WAV", ".WAVE", ".wma", ".webm", ".eac3", ".mkv") + +FFMPEG_MORE_EXT = (".aa", ".aac", ".ac3", ".aiff", ".alac", ".avi", ".f4v",".flac", ".flic", ".flv", + ".m4v",".mlv", ".mov", ".mp4", ".m4a", ".mp2", ".mp3", ".mp4", ".mpc", ".mpc8", + ".mpeg", ".ogg", ".tta", ".tty", ".vcd", ".wav", ".wma") +ANY_EXT = "" + +# Secondary Menu Constants + +VOCAL_PAIR_PLACEMENT = 1, 2, 3, 4 +OTHER_PAIR_PLACEMENT = 5, 6, 7, 8 +BASS_PAIR_PLACEMENT = 9, 10, 11, 12 +DRUMS_PAIR_PLACEMENT = 13, 14, 15, 16 + +# Drag n Drop String Checks + +DOUBLE_BRACKET = "} {" +RIGHT_BRACKET = "}" +LEFT_BRACKET = "{" + +# Manual Downloads + +VR_PLACEMENT_TEXT = 'Place models in \"models/VR_Models\" directory.' +MDX_PLACEMENT_TEXT = 'Place models in \"models/MDX_Net_Models\" directory.' +DEMUCS_PLACEMENT_TEXT = 'Place models in \"models/Demucs_Models\" directory.' +DEMUCS_V3_V4_PLACEMENT_TEXT = 'Place items in \"models/Demucs_Models/v3_v4_repo\" directory.' + +FULL_DOWNLOAD_LIST_VR = { + "VR Arch Single Model v5: 1_HP-UVR": "1_HP-UVR.pth", + "VR Arch Single Model v5: 2_HP-UVR": "2_HP-UVR.pth", + "VR Arch Single Model v5: 3_HP-Vocal-UVR": "3_HP-Vocal-UVR.pth", + "VR Arch Single Model v5: 4_HP-Vocal-UVR": "4_HP-Vocal-UVR.pth", + "VR Arch Single Model v5: 5_HP-Karaoke-UVR": "5_HP-Karaoke-UVR.pth", + "VR Arch Single Model v5: 6_HP-Karaoke-UVR": "6_HP-Karaoke-UVR.pth", + "VR Arch Single Model v5: 7_HP2-UVR": "7_HP2-UVR.pth", + "VR Arch Single Model v5: 8_HP2-UVR": "8_HP2-UVR.pth", + "VR Arch Single Model v5: 9_HP2-UVR": "9_HP2-UVR.pth", + "VR Arch Single Model v5: 10_SP-UVR-2B-32000-1": "10_SP-UVR-2B-32000-1.pth", + "VR Arch Single Model v5: 11_SP-UVR-2B-32000-2": "11_SP-UVR-2B-32000-2.pth", + "VR Arch Single Model v5: 12_SP-UVR-3B-44100": "12_SP-UVR-3B-44100.pth", + "VR Arch Single Model v5: 13_SP-UVR-4B-44100-1": "13_SP-UVR-4B-44100-1.pth", + "VR Arch Single Model v5: 14_SP-UVR-4B-44100-2": "14_SP-UVR-4B-44100-2.pth", + "VR Arch Single Model v5: 15_SP-UVR-MID-44100-1": "15_SP-UVR-MID-44100-1.pth", + "VR Arch Single Model v5: 16_SP-UVR-MID-44100-2": "16_SP-UVR-MID-44100-2.pth", + "VR Arch Single Model v4: MGM_HIGHEND_v4": "MGM_HIGHEND_v4.pth", + "VR Arch Single Model v4: MGM_LOWEND_A_v4": "MGM_LOWEND_A_v4.pth", + "VR Arch Single Model v4: MGM_LOWEND_B_v4": "MGM_LOWEND_B_v4.pth", + "VR Arch Single Model v4: MGM_MAIN_v4": "MGM_MAIN_v4.pth" + } + +FULL_DOWNLOAD_LIST_MDX = { + "MDX-Net Model: UVR-MDX-NET Main": "UVR_MDXNET_Main.onnx", + "MDX-Net Model: UVR-MDX-NET Inst Main": "UVR-MDX-NET-Inst_Main.onnx", + "MDX-Net Model: UVR-MDX-NET 1": "UVR_MDXNET_1_9703.onnx", + "MDX-Net Model: UVR-MDX-NET 2": "UVR_MDXNET_2_9682.onnx", + "MDX-Net Model: UVR-MDX-NET 3": "UVR_MDXNET_3_9662.onnx", + "MDX-Net Model: UVR-MDX-NET Inst 1": "UVR-MDX-NET-Inst_1.onnx", + "MDX-Net Model: UVR-MDX-NET Inst 2": "UVR-MDX-NET-Inst_2.onnx", + "MDX-Net Model: UVR-MDX-NET Inst 3": "UVR-MDX-NET-Inst_3.onnx", + "MDX-Net Model: UVR-MDX-NET Karaoke": "UVR_MDXNET_KARA.onnx", + "MDX-Net Model: UVR_MDXNET_9482": "UVR_MDXNET_9482.onnx", + "MDX-Net Model: Kim_Vocal_1": "Kim_Vocal_1.onnx", + "MDX-Net Model: kuielab_a_vocals": "kuielab_a_vocals.onnx", + "MDX-Net Model: kuielab_a_other": "kuielab_a_other.onnx", + "MDX-Net Model: kuielab_a_bass": "kuielab_a_bass.onnx", + "MDX-Net Model: kuielab_a_drums": "kuielab_a_drums.onnx", + "MDX-Net Model: kuielab_b_vocals": "kuielab_b_vocals.onnx", + "MDX-Net Model: kuielab_b_other": "kuielab_b_other.onnx", + "MDX-Net Model: kuielab_b_bass": "kuielab_b_bass.onnx", + "MDX-Net Model: kuielab_b_drums": "kuielab_b_drums.onnx"} + +FULL_DOWNLOAD_LIST_DEMUCS = { + + "Demucs v4: htdemucs_ft":{ + "f7e0c4bc-ba3fe64a.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th", + "d12395a8-e57c48e6.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th", + "92cfc3b6-ef3bcb9c.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th", + "04573f0d-f3cf25b2.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th", + "htdemucs_ft.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_ft.yaml" + }, + + "Demucs v4: htdemucs":{ + "955717e8-8726e21a.th": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/955717e8-8726e21a.th", + "htdemucs.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs.yaml" + }, + + "Demucs v4: hdemucs_mmi":{ + "75fc33f5-1941ce65.th": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/75fc33f5-1941ce65.th", + "hdemucs_mmi.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/hdemucs_mmi.yaml" + }, + "Demucs v4: htdemucs_6s":{ + "5c90dfd2-34c22ccb.th": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/5c90dfd2-34c22ccb.th", + "htdemucs_6s.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_6s.yaml" + }, + "Demucs v3: mdx":{ + "0d19c1c6-0f06f20e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/0d19c1c6-0f06f20e.th", + "7ecf8ec1-70f50cc9.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/7ecf8ec1-70f50cc9.th", + "c511e2ab-fe698775.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/c511e2ab-fe698775.th", + "7d865c68-3d5dd56b.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/7d865c68-3d5dd56b.th", + "mdx.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/mdx.yaml" + }, + + "Demucs v3: mdx_q":{ + "6b9c2ca1-3fd82607.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/6b9c2ca1-3fd82607.th", + "b72baf4e-8778635e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/b72baf4e-8778635e.th", + "42e558d4-196e0e1b.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/42e558d4-196e0e1b.th", + "305bc58f-18378783.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/305bc58f-18378783.th", + "mdx_q.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/mdx_q.yaml" + }, + + "Demucs v3: mdx_extra":{ + "e51eebcc-c1b80bdd.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/e51eebcc-c1b80bdd.th", + "a1d90b5c-ae9d2452.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/a1d90b5c-ae9d2452.th", + "5d2d6c55-db83574e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/5d2d6c55-db83574e.th", + "cfa93e08-61801ae1.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/cfa93e08-61801ae1.th", + "mdx_extra.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/mdx_extra.yaml" + }, + + "Demucs v3: mdx_extra_q": { + "83fc094f-4a16d450.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/83fc094f-4a16d450.th", + "464b36d7-e5a9386e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/464b36d7-e5a9386e.th", + "14fc6a69-a89dd0ee.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/14fc6a69-a89dd0ee.th", + "7fd6ef75-a905dd85.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/7fd6ef75-a905dd85.th", + "mdx_extra_q.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/mdx_extra_q.yaml" + }, + + "Demucs v3: UVR Model":{ + "ebf34a2db.th": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/ebf34a2db.th", + "UVR_Demucs_Model_1.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR_Demucs_Model_1.yaml" + }, + + "Demucs v3: repro_mdx_a":{ + "9a6b4851-03af0aa6.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/9a6b4851-03af0aa6.th", + "1ef250f1-592467ce.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/1ef250f1-592467ce.th", + "fa0cb7f9-100d8bf4.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/fa0cb7f9-100d8bf4.th", + "902315c2-b39ce9c9.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/902315c2-b39ce9c9.th", + "repro_mdx_a.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/repro_mdx_a.yaml" + }, + + "Demucs v3: repro_mdx_a_time_only":{ + "9a6b4851-03af0aa6.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/9a6b4851-03af0aa6.th", + "1ef250f1-592467ce.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/1ef250f1-592467ce.th", + "repro_mdx_a_time_only.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/repro_mdx_a_time_only.yaml" + }, + + "Demucs v3: repro_mdx_a_hybrid_only":{ + "fa0cb7f9-100d8bf4.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/fa0cb7f9-100d8bf4.th", + "902315c2-b39ce9c9.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/902315c2-b39ce9c9.th", + "repro_mdx_a_hybrid_only.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/repro_mdx_a_hybrid_only.yaml" + }, + + "Demucs v2: demucs": { + "demucs-e07c671f.th": "https://dl.fbaipublicfiles.com/demucs/v3.0/demucs-e07c671f.th" + }, + + "Demucs v2: demucs_extra": { + "demucs_extra-3646af93.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/demucs_extra-3646af93.th" + }, + + "Demucs v2: demucs48_hq": { + "demucs48_hq-28a1282c.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/demucs48_hq-28a1282c.th" + }, + + "Demucs v2: tasnet": { + "tasnet-beb46fac.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/tasnet-beb46fac.th" + }, + + "Demucs v2: tasnet_extra": { + "tasnet_extra-df3777b2.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/tasnet_extra-df3777b2.th" + }, + + "Demucs v2: demucs_unittest": { + "demucs_unittest-09ebc15f.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/demucs_unittest-09ebc15f.th" + }, + + "Demucs v1: demucs": { + "demucs.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/demucs.th" + }, + + "Demucs v1: demucs_extra": { + "demucs_extra.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/demucs_extra.th" + }, + + "Demucs v1: light": { + "light.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/light.th" + }, + + "Demucs v1: light_extra": { + "light_extra.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/light_extra.th" + }, + + "Demucs v1: tasnet": { + "tasnet.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/tasnet.th" + }, + + "Demucs v1: tasnet_extra": { + "tasnet_extra.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/tasnet_extra.th" + } + } + +# Main Menu Labels + +CHOOSE_PROC_METHOD_MAIN_LABEL = 'CHOOSE PROCESS METHOD (เลือก "MDX-Net")' +SELECT_SAVED_SETTINGS_MAIN_LABEL = 'SELECT SAVED SETTINGS' +CHOOSE_MDX_MODEL_MAIN_LABEL = 'CHOOSE MDX-NET MODEL (เลือก "UVR_MDXNET_Main" - แนะนำ)' +BATCHES_MDX_MAIN_LABEL = 'BATCH SIZE' +VOL_COMP_MDX_MAIN_LABEL = 'VOLUME COMPENSATION' +SELECT_VR_MODEL_MAIN_LABEL = 'CHOOSE VR MODEL' +AGGRESSION_SETTING_MAIN_LABEL = 'AGGRESSION SETTING' +WINDOW_SIZE_MAIN_LABEL = 'WINDOW SIZE' +CHOOSE_DEMUCS_MODEL_MAIN_LABEL = 'CHOOSE DEMUCS MODEL' +CHOOSE_DEMUCS_STEMS_MAIN_LABEL = 'CHOOSE STEM(S)' +CHOOSE_SEGMENT_MAIN_LABEL = 'SEGMENT' +ENSEMBLE_OPTIONS_MAIN_LABEL = 'ENSEMBLE OPTIONS' +CHOOSE_MAIN_PAIR_MAIN_LABEL = 'MAIN STEM PAIR' +CHOOSE_ENSEMBLE_ALGORITHM_MAIN_LABEL = 'ENSEMBLE ALGORITHM' +AVAILABLE_MODELS_MAIN_LABEL = 'AVAILABLE MODELS' +CHOOSE_AUDIO_TOOLS_MAIN_LABEL = 'CHOOSE AUDIO TOOL' +CHOOSE_MANUAL_ALGORITHM_MAIN_LABEL = 'CHOOSE ALGORITHM' +CHOOSE_RATE_MAIN_LABEL = 'RATE' +CHOOSE_SEMITONES_MAIN_LABEL = 'SEMITONES' +GPU_CONVERSION_MAIN_LABEL = 'GPU Conversion' + +if OPERATING_SYSTEM=="Darwin": + LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running macOS Catalina and above.\n' +\ + '• Application functionality for systems running macOS Mojave or lower is not guaranteed.\n' +\ + '• Application functionality for older or budget Mac systems is not guaranteed.\n\n' + FONT_SIZE_F1 = 13 + FONT_SIZE_F2 = 11 + FONT_SIZE_F3 = 12 + FONT_SIZE_0 = 9 + FONT_SIZE_1 = 11 + FONT_SIZE_2 = 12 + FONT_SIZE_3 = 13 + FONT_SIZE_4 = 14 + FONT_SIZE_5 = 15 + FONT_SIZE_6 = 17 + HELP_HINT_CHECKBOX_WIDTH = 13 + MDX_CHECKBOXS_WIDTH = 14 + VR_CHECKBOXS_WIDTH = 14 + ENSEMBLE_CHECKBOXS_WIDTH = 18 + DEMUCS_CHECKBOXS_WIDTH = 14 + DEMUCS_PRE_CHECKBOXS_WIDTH = 20 + GEN_SETTINGS_WIDTH = 17 + MENU_COMBOBOX_WIDTH = 16 + +elif OPERATING_SYSTEM=="Linux": + LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running Linux Ubuntu 18.04+.\n' +\ + '• Application functionality for systems running other Linux platforms is not guaranteed.\n' +\ + '• Application functionality for older or budget systems is not guaranteed.\n\n' + FONT_SIZE_F1 = 10 + FONT_SIZE_F2 = 8 + FONT_SIZE_F3 = 9 + FONT_SIZE_0 = 7 + FONT_SIZE_1 = 8 + FONT_SIZE_2 = 9 + FONT_SIZE_3 = 10 + FONT_SIZE_4 = 11 + FONT_SIZE_5 = 12 + FONT_SIZE_6 = 15 + HELP_HINT_CHECKBOX_WIDTH = 13 + MDX_CHECKBOXS_WIDTH = 14 + VR_CHECKBOXS_WIDTH = 16 + ENSEMBLE_CHECKBOXS_WIDTH = 25 + DEMUCS_CHECKBOXS_WIDTH = 18 + DEMUCS_PRE_CHECKBOXS_WIDTH = 27 + GEN_SETTINGS_WIDTH = 17 + MENU_COMBOBOX_WIDTH = 19 + +elif OPERATING_SYSTEM=="Windows": + LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running Windows 10 or higher.\n' +\ + '• Application functionality for systems running Windows 7 or lower is not guaranteed.\n' +\ + '• Application functionality for Intel Pentium & Celeron CPUs systems is not guaranteed.\n\n' + FONT_SIZE_F1 = 10 + FONT_SIZE_F2 = 8 + FONT_SIZE_F3 = 9 + FONT_SIZE_0 = 7 + FONT_SIZE_1 = 8 + FONT_SIZE_2 = 9 + FONT_SIZE_3 = 10 + FONT_SIZE_4 = 11 + FONT_SIZE_5 = 12 + FONT_SIZE_6 = 15 + HELP_HINT_CHECKBOX_WIDTH = 16 + MDX_CHECKBOXS_WIDTH = 16 + VR_CHECKBOXS_WIDTH = 16 + ENSEMBLE_CHECKBOXS_WIDTH = 25 + DEMUCS_CHECKBOXS_WIDTH = 18 + DEMUCS_PRE_CHECKBOXS_WIDTH = 27 + GEN_SETTINGS_WIDTH = 23 + MENU_COMBOBOX_WIDTH = 19 + + +LICENSE_TEXT = lambda a, p:f'Current Application Version: Ultimate Vocal Remover {a}\n' +\ + f'Current Patch Version: {p}\n\n' +\ + 'Copyright (c) 2022 Ultimate Vocal Remover\n\n' +\ + 'UVR is free and open-source, but MIT licensed. Please credit us if you use our\n' +\ + f'models or code for projects unrelated to UVR.\n\n{LICENSE_OS_SPECIFIC_TEXT}' +\ + 'This bundle contains the UVR interface, Python, PyTorch, and other\n' +\ + 'dependencies needed to run the application effectively.\n\n' +\ + 'Website Links: This application, System or Service(s) may contain links to\n' +\ + 'other websites and downloads, and they are solely provided to you as an\n' +\ + 'additional convenience. You understand and acknowledge that by clicking\n' +\ + 'or activating such links you are accessing a site or service outside of\n' +\ + 'this application, and that we do not screen, review, approve, or otherwise\n' +\ + 'endorse any content or information contained in these linked websites.\n' +\ + 'You acknowledge and agree that we, our affiliates and partners are not\n' +\ + 'responsible for the contents of any of these linked websites, including\n' +\ + 'the accuracy or availability of information provided by the linked websites,\n' +\ + 'and we make no representations or warranties regarding your use of\n' +\ + 'the linked websites.\n\n' +\ + 'This application is MIT Licensed\n\n' +\ + 'Permission is hereby granted, free of charge, to any person obtaining a copy\n' +\ + 'of this software and associated documentation files (the "Software"), to deal\n' +\ + 'in the Software without restriction, including without limitation the rights\n' +\ + 'to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n' +\ + 'copies of the Software, and to permit persons to whom the Software is\n' +\ + 'furnished to do so, subject to the following conditions:\n\n' +\ + 'The above copyright notice and this permission notice shall be included in all\n' +\ + 'copies or substantial portions of the Software.\n\n' +\ + 'THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n' +\ + 'IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n' +\ + 'FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n' +\ + 'AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n' +\ + 'LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n' +\ + 'OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n' +\ + 'SOFTWARE.' + +CHANGE_LOG_HEADER = lambda patch:f"Patch Version:\n\n{patch}" + +#DND CONSTS + +MAC_DND_CHECK = ('/Users/', + '/Applications/', + '/Library/', + '/System/') +LINUX_DND_CHECK = ('/home/', + '/usr/') +WINDOWS_DND_CHECK = ('A:', 'B:', 'C:', 'D:', 'E:', 'F:', 'G:', 'H:', 'I:', 'J:', 'K:', 'L:', 'M:', 'N:', 'O:', 'P:', 'Q:', 'R:', 'S:', 'T:', 'U:', 'V:', 'W:', 'X:', 'Y:', 'Z:') + +WOOD_INST_MODEL_HASH = '0ec76fd9e65f81d8b4fbd13af4826ed8' +WOOD_INST_PARAMS = { + "vr_model_param": "4band_v3", + "primary_stem": NO_WIND_INST_STEM + } \ No newline at end of file diff --git a/gui_data/error_handling.py b/gui_data/error_handling.py new file mode 100644 index 0000000000000000000000000000000000000000..366d77d44420afe0cfa292dc879361b742bc6417 --- /dev/null +++ b/gui_data/error_handling.py @@ -0,0 +1,106 @@ +from datetime import datetime +import traceback + +CUDA_MEMORY_ERROR = "CUDA out of memory" +CUDA_RUNTIME_ERROR = "CUDNN error executing cudnnSetTensorNdDescriptor" +DEMUCS_MODEL_MISSING_ERROR = "is neither a single pre-trained model or a bag of models." +ENSEMBLE_MISSING_MODEL_ERROR = "local variable \'enseExport\' referenced before assignment" +FFMPEG_MISSING_ERROR = """audioread\__init__.py", line 116, in audio_open""" +FILE_MISSING_ERROR = "FileNotFoundError" +MDX_MEMORY_ERROR = "onnxruntime::CudaCall CUDA failure 2: out of memory" +MDX_MODEL_MISSING = "[ONNXRuntimeError] : 3 : NO_SUCHFILE" +MDX_MODEL_SETTINGS_ERROR = "Got invalid dimensions for input" +MDX_RUNTIME_ERROR = "onnxruntime::BFCArena::AllocateRawInternal" +MODULE_ERROR = "ModuleNotFoundError" +WINDOW_SIZE_ERROR = "h1_shape[3] must be greater than h2_shape[3]" +SF_WRITE_ERROR = "sf.write" +SYSTEM_MEMORY_ERROR = "DefaultCPUAllocator: not enough memory" +MISSING_MODEL_ERROR = "'NoneType\' object has no attribute \'model_basename\'" +ARRAY_SIZE_ERROR = "ValueError: \"array is too big; `arr.size * arr.dtype.itemsize` is larger than the maximum possible size.\"" +GPU_INCOMPATIBLE_ERROR = "no kernel image is available for execution on the device" + +CONTACT_DEV = 'If this error persists, please contact the developers with the error details.' + +ERROR_MAPPER = { + CUDA_MEMORY_ERROR: + ('The application was unable to allocate enough GPU memory to use this model. ' + + 'Please close any GPU intensive applications and try again.\n' + + 'If the error persists, your GPU might not be supported.') , + CUDA_RUNTIME_ERROR: + (f'Your PC cannot process this audio file with the chunk size selected. Please lower the chunk size and try again.\n\n{CONTACT_DEV}'), + DEMUCS_MODEL_MISSING_ERROR: + ('The selected Demucs model is missing. ' + + 'Please download the model or make sure it is in the correct directory.'), + ENSEMBLE_MISSING_MODEL_ERROR: + ('The application was unable to locate a model you selected for this ensemble.\n\n' + + 'Please do the following to use all compatible models:\n\n1. Navigate to the \"Updates\" tab in the Help Guide.\n2. Download and install the model expansion pack.\n3. Then try again.\n\n' + + 'If the error persists, please verify all models are present.'), + FFMPEG_MISSING_ERROR: + ('The input file type is not supported or FFmpeg is missing. Please select a file type supported by FFmpeg and try again. ' + + 'If FFmpeg is missing or not installed, you will only be able to process \".wav\" files until it is available on this system. ' + + f'See the \"More Info\" tab in the Help Guide.\n\n{CONTACT_DEV}'), + FILE_MISSING_ERROR: + (f'Missing file error raised. Please address the error and try again.\n\n{CONTACT_DEV}'), + MDX_MEMORY_ERROR: + ('The application was unable to allocate enough GPU memory to use this model.\n\n' + + 'Please do the following:\n\n1. Close any GPU intensive applications.\n2. Lower the set chunk size.\n3. Then try again.\n\n' + + 'If the error persists, your GPU might not be supported.'), + MDX_MODEL_MISSING: + ('The application could not detect this MDX-Net model on your system. ' + + 'Please make sure all the models are present in the correct directory.\n\n' + + 'If the error persists, please reinstall application or contact the developers.'), + MDX_RUNTIME_ERROR: + ('The application was unable to allocate enough GPU memory to use this model.\n\n' + + 'Please do the following:\n\n1. Close any GPU intensive applications.\n2. Lower the set chunk size.\n3. Then try again.\n\n' + + 'If the error persists, your GPU might not be supported.'), + WINDOW_SIZE_ERROR: + ('Invalid window size.\n\n' + + 'The chosen window size is likely not compatible with this model. Please select a different size and try again.'), + SF_WRITE_ERROR: + ('Could not write audio file.\n\n' + + 'This could be due to one of the following:\n\n1. Low storage on target device.\n2. The export directory no longer exists.\n3. A system permissions issue.'), + SYSTEM_MEMORY_ERROR: + ('The application was unable to allocate enough system memory to use this model.\n\n' + + 'Please do the following:\n\n1. Restart this application.\n2. Ensure any CPU intensive applications are closed.\n3. Then try again.\n\n' + + 'Please Note: Intel Pentium and Intel Celeron processors do not work well with this application.\n\n' + + 'If the error persists, the system may not have enough RAM, or your CPU might not be supported.'), + MISSING_MODEL_ERROR: + ('Model Missing: The application was unable to locate the chosen model.\n\n' + + 'If the error persists, please verify any selected models are present.'), + GPU_INCOMPATIBLE_ERROR: + ('This process is not compatible with your GPU.\n\n' + + 'Please uncheck \"GPU Conversion\" and try again'), + ARRAY_SIZE_ERROR: + ('The application was not able to process the given audiofile. Please convert the audiofile to another format and try again.'), +} + +def error_text(process_method, exception): + + traceback_text = ''.join(traceback.format_tb(exception.__traceback__)) + message = f'{type(exception).__name__}: "{exception}"\nTraceback Error: "\n{traceback_text}"\n' + error_message = f'\n\nRaw Error Details:\n\n{message}\nError Time Stamp [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]\n' + process = f'Last Error Received:\n\nProcess: {process_method}\n\n' + + for error_type, full_text in ERROR_MAPPER.items(): + if error_type in message: + final_message = full_text + break + else: + final_message = (CONTACT_DEV) + + return f"{process}{final_message}{error_message}" + +def error_dialouge(exception): + + error_name = f'{type(exception).__name__}' + traceback_text = ''.join(traceback.format_tb(exception.__traceback__)) + message = f'{error_name}: "{exception}"\n{traceback_text}"' + + for error_type, full_text in ERROR_MAPPER.items(): + if error_type in message: + final_message = full_text + break + else: + final_message = (f'{error_name}: {exception}\n\n{CONTACT_DEV}') + + return final_message diff --git a/gui_data/old_data_check.py b/gui_data/old_data_check.py new file mode 100644 index 0000000000000000000000000000000000000000..1694baedc64de4b240bd3afb4d2e35b4e2dd65e3 --- /dev/null +++ b/gui_data/old_data_check.py @@ -0,0 +1,27 @@ +import os +import shutil + +def file_check(original_dir, new_dir): + + if os.path.isdir(original_dir): + for file in os.listdir(original_dir): + shutil.move(os.path.join(original_dir, file), os.path.join(new_dir, file)) + + if len(os.listdir(original_dir)) == 0: + shutil.rmtree(original_dir) + +def remove_unneeded_yamls(demucs_dir): + + for file in os.listdir(demucs_dir): + if file.endswith('.yaml'): + if os.path.isfile(os.path.join(demucs_dir, file)): + os.remove(os.path.join(demucs_dir, file)) + +def remove_temps(remove_dir): + + if os.path.isdir(remove_dir): + try: + shutil.rmtree(remove_dir) + except Exception as e: + print(e) + \ No newline at end of file diff --git a/lib_v5/mdxnet.py b/lib_v5/mdxnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a61fee819286cf26257a9a5f91038a0097d1b2 --- /dev/null +++ b/lib_v5/mdxnet.py @@ -0,0 +1,140 @@ +from abc import ABCMeta + +import torch +import torch.nn as nn +from pytorch_lightning import LightningModule +from .modules import TFC_TDF + +dim_s = 4 + +class AbstractMDXNet(LightningModule): + __metaclass__ = ABCMeta + + def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap): + super().__init__() + self.target_name = target_name + self.lr = lr + self.optimizer = optimizer + self.dim_c = dim_c + self.dim_f = dim_f + self.dim_t = dim_t + self.n_fft = n_fft + self.n_bins = n_fft // 2 + 1 + self.hop_length = hop_length + self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False) + self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False) + + def configure_optimizers(self): + if self.optimizer == 'rmsprop': + return torch.optim.RMSprop(self.parameters(), self.lr) + + if self.optimizer == 'adamw': + return torch.optim.AdamW(self.parameters(), self.lr) + +class ConvTDFNet(AbstractMDXNet): + def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, + num_blocks, l, g, k, bn, bias, overlap): + + super(ConvTDFNet, self).__init__( + target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap) + self.save_hyperparameters() + + self.num_blocks = num_blocks + self.l = l + self.g = g + self.k = k + self.bn = bn + self.bias = bias + + if optimizer == 'rmsprop': + norm = nn.BatchNorm2d + + if optimizer == 'adamw': + norm = lambda input:nn.GroupNorm(2, input) + + self.n = num_blocks // 2 + scale = (2, 2) + + self.first_conv = nn.Sequential( + nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)), + norm(g), + nn.ReLU(), + ) + + f = self.dim_f + c = g + self.encoding_blocks = nn.ModuleList() + self.ds = nn.ModuleList() + for i in range(self.n): + self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)) + self.ds.append( + nn.Sequential( + nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale), + norm(c + g), + nn.ReLU() + ) + ) + f = f // 2 + c += g + + self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm) + + self.decoding_blocks = nn.ModuleList() + self.us = nn.ModuleList() + for i in range(self.n): + self.us.append( + nn.Sequential( + nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale), + norm(c - g), + nn.ReLU() + ) + ) + f = f * 2 + c -= g + + self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)) + + self.final_conv = nn.Sequential( + nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)), + ) + + def forward(self, x): + + x = self.first_conv(x) + + x = x.transpose(-1, -2) + + ds_outputs = [] + for i in range(self.n): + x = self.encoding_blocks[i](x) + ds_outputs.append(x) + x = self.ds[i](x) + + x = self.bottleneck_block(x) + + for i in range(self.n): + x = self.us[i](x) + x *= ds_outputs[-i - 1] + x = self.decoding_blocks[i](x) + + x = x.transpose(-1, -2) + + x = self.final_conv(x) + + return x + +class Mixer(nn.Module): + def __init__(self, device, mixer_path): + + super(Mixer, self).__init__() + + self.linear = nn.Linear((dim_s+1)*2, dim_s*2, bias=False) + + self.load_state_dict( + torch.load(mixer_path, map_location=device) + ) + + def forward(self, x): + x = x.reshape(1,(dim_s+1)*2,-1).transpose(-1,-2) + x = self.linear(x) + return x.transpose(-1,-2).reshape(dim_s,2,-1) \ No newline at end of file diff --git a/lib_v5/mixer.ckpt b/lib_v5/mixer.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..309325e0e0c3d309a7876812b1fa5edd835600ff --- /dev/null +++ b/lib_v5/mixer.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35121bf49b892b112b073b605c86efa1812bc37c9603c0fbc1726f4320c2aa91 +size 132 diff --git a/lib_v5/modules.py b/lib_v5/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4e77d2fb5b97c4ca0e6f6011e012f43e03a70b14 --- /dev/null +++ b/lib_v5/modules.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn + + +class TFC(nn.Module): + def __init__(self, c, l, k, norm): + super(TFC, self).__init__() + + self.H = nn.ModuleList() + for i in range(l): + self.H.append( + nn.Sequential( + nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2), + norm(c), + nn.ReLU(), + ) + ) + + def forward(self, x): + for h in self.H: + x = h(x) + return x + + +class DenseTFC(nn.Module): + def __init__(self, c, l, k, norm): + super(DenseTFC, self).__init__() + + self.conv = nn.ModuleList() + for i in range(l): + self.conv.append( + nn.Sequential( + nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2), + norm(c), + nn.ReLU(), + ) + ) + + def forward(self, x): + for layer in self.conv[:-1]: + x = torch.cat([layer(x), x], 1) + return self.conv[-1](x) + + +class TFC_TDF(nn.Module): + def __init__(self, c, l, f, k, bn, dense=False, bias=True, norm=nn.BatchNorm2d): + + super(TFC_TDF, self).__init__() + + self.use_tdf = bn is not None + + self.tfc = DenseTFC(c, l, k, norm) if dense else TFC(c, l, k, norm) + + if self.use_tdf: + if bn == 0: + self.tdf = nn.Sequential( + nn.Linear(f, f, bias=bias), + norm(c), + nn.ReLU() + ) + else: + self.tdf = nn.Sequential( + nn.Linear(f, f // bn, bias=bias), + norm(c), + nn.ReLU(), + nn.Linear(f // bn, f, bias=bias), + norm(c), + nn.ReLU() + ) + + def forward(self, x): + x = self.tfc(x) + return x + self.tdf(x) if self.use_tdf else x + diff --git a/lib_v5/pyrb.py b/lib_v5/pyrb.py new file mode 100644 index 0000000000000000000000000000000000000000..883a525ee28351b2d99f674dcc721af3f852f87f --- /dev/null +++ b/lib_v5/pyrb.py @@ -0,0 +1,92 @@ +import os +import subprocess +import tempfile +import six +import numpy as np +import soundfile as sf +import sys + +if getattr(sys, 'frozen', False): + BASE_PATH_RUB = sys._MEIPASS +else: + BASE_PATH_RUB = os.path.dirname(os.path.abspath(__file__)) + +__all__ = ['time_stretch', 'pitch_shift'] + +__RUBBERBAND_UTIL = os.path.join(BASE_PATH_RUB, 'rubberband') + +if six.PY2: + DEVNULL = open(os.devnull, 'w') +else: + DEVNULL = subprocess.DEVNULL + +def __rubberband(y, sr, **kwargs): + + assert sr > 0 + + # Get the input and output tempfile + fd, infile = tempfile.mkstemp(suffix='.wav') + os.close(fd) + fd, outfile = tempfile.mkstemp(suffix='.wav') + os.close(fd) + + # dump the audio + sf.write(infile, y, sr) + + try: + # Execute rubberband + arguments = [__RUBBERBAND_UTIL, '-q'] + + for key, value in six.iteritems(kwargs): + arguments.append(str(key)) + arguments.append(str(value)) + + arguments.extend([infile, outfile]) + + subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL) + + # Load the processed audio. + y_out, _ = sf.read(outfile, always_2d=True) + + # make sure that output dimensions matches input + if y.ndim == 1: + y_out = np.squeeze(y_out) + + except OSError as exc: + six.raise_from(RuntimeError('Failed to execute rubberband. ' + 'Please verify that rubberband-cli ' + 'is installed.'), + exc) + + finally: + # Remove temp files + os.unlink(infile) + os.unlink(outfile) + + return y_out + +def time_stretch(y, sr, rate, rbargs=None): + if rate <= 0: + raise ValueError('rate must be strictly positive') + + if rate == 1.0: + return y + + if rbargs is None: + rbargs = dict() + + rbargs.setdefault('--tempo', rate) + + return __rubberband(y, sr, **rbargs) + +def pitch_shift(y, sr, n_steps, rbargs=None): + + if n_steps == 0: + return y + + if rbargs is None: + rbargs = dict() + + rbargs.setdefault('--pitch', n_steps) + + return __rubberband(y, sr, **rbargs) diff --git a/lib_v5/spec_utils.py b/lib_v5/spec_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a89a050e459f879e2a0fe9c34e544362772cc3b1 --- /dev/null +++ b/lib_v5/spec_utils.py @@ -0,0 +1,692 @@ +import librosa +import numpy as np +import soundfile as sf +import math +import random +import math +import platform +import traceback +from . import pyrb +#cur +OPERATING_SYSTEM = platform.system() +SYSTEM_ARCH = platform.platform() +SYSTEM_PROC = platform.processor() +ARM = 'arm' + +if OPERATING_SYSTEM == 'Windows': + from pyrubberband import pyrb +else: + from . import pyrb + +if OPERATING_SYSTEM == 'Darwin': + wav_resolution = "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else "sinc_fastest" +else: + wav_resolution = "sinc_fastest" + +MAX_SPEC = 'Max Spec' +MIN_SPEC = 'Min Spec' +AVERAGE = 'Average' + +def crop_center(h1, h2): + h1_shape = h1.size() + h2_shape = h2.size() + + if h1_shape[3] == h2_shape[3]: + return h1 + elif h1_shape[3] < h2_shape[3]: + raise ValueError('h1_shape[3] must be greater than h2_shape[3]') + + s_time = (h1_shape[3] - h2_shape[3]) // 2 + e_time = s_time + h2_shape[3] + h1 = h1[:, :, :, s_time:e_time] + + return h1 + +def preprocess(X_spec): + X_mag = np.abs(X_spec) + X_phase = np.angle(X_spec) + + return X_mag, X_phase + +def make_padding(width, cropsize, offset): + left = offset + roi_size = cropsize - offset * 2 + if roi_size == 0: + roi_size = cropsize + right = roi_size - (width % roi_size) + left + + return left, right, roi_size + +def wave_to_spectrogram(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False): + if reverse: + wave_left = np.flip(np.asfortranarray(wave[0])) + wave_right = np.flip(np.asfortranarray(wave[1])) + elif mid_side: + wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1])) + elif mid_side_b2: + wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5)) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5)) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + spec_left = librosa.stft(wave_left, n_fft, hop_length=hop_length) + spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length) + + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + +def wave_to_spectrogram_mt(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False): + import threading + + if reverse: + wave_left = np.flip(np.asfortranarray(wave[0])) + wave_right = np.flip(np.asfortranarray(wave[1])) + elif mid_side: + wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1])) + elif mid_side_b2: + wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5)) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5)) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + def run_thread(**kwargs): + global spec_left + spec_left = librosa.stft(**kwargs) + + thread = threading.Thread(target=run_thread, kwargs={'y': wave_left, 'n_fft': n_fft, 'hop_length': hop_length}) + thread.start() + spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length) + thread.join() + + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + +def normalize(wave, is_normalize=False): + """Save output music files""" + maxv = np.abs(wave).max() + if maxv > 1.0: + print(f"\nNormalization Set {is_normalize}: Input above threshold for clipping. Max:{maxv}") + if is_normalize: + print(f"The result was normalized.") + wave /= maxv + else: + print(f"The result was not normalized.") + else: + print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}") + + return wave + +def normalize_two_stem(wave, mix, is_normalize=False): + """Save output music files""" + + maxv = np.abs(wave).max() + max_mix = np.abs(mix).max() + + if maxv > 1.0: + print(f"\nNormalization Set {is_normalize}: Primary source above threshold for clipping. Max:{maxv}") + print(f"\nNormalization Set {is_normalize}: Mixture above threshold for clipping. Max:{max_mix}") + if is_normalize: + print(f"The result was normalized.") + wave /= maxv + mix /= maxv + else: + print(f"The result was not normalized.") + else: + print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}") + + + print(f"\nNormalization Set {is_normalize}: Primary source - Max:{np.abs(wave).max()}") + print(f"\nNormalization Set {is_normalize}: Mixture - Max:{np.abs(mix).max()}") + + return wave, mix + +def combine_spectrograms(specs, mp): + l = min([specs[i].shape[2] for i in specs]) + spec_c = np.zeros(shape=(2, mp.param['bins'] + 1, l), dtype=np.complex64) + offset = 0 + bands_n = len(mp.param['band']) + + for d in range(1, bands_n + 1): + h = mp.param['band'][d]['crop_stop'] - mp.param['band'][d]['crop_start'] + spec_c[:, offset:offset+h, :l] = specs[d][:, mp.param['band'][d]['crop_start']:mp.param['band'][d]['crop_stop'], :l] + offset += h + + if offset > mp.param['bins']: + raise ValueError('Too much bins') + + # lowpass fiter + if mp.param['pre_filter_start'] > 0: # and mp.param['band'][bands_n]['res_type'] in ['scipy', 'polyphase']: + if bands_n == 1: + spec_c = fft_lp_filter(spec_c, mp.param['pre_filter_start'], mp.param['pre_filter_stop']) + else: + gp = 1 + for b in range(mp.param['pre_filter_start'] + 1, mp.param['pre_filter_stop']): + g = math.pow(10, -(b - mp.param['pre_filter_start']) * (3.5 - gp) / 20.0) + gp = g + spec_c[:, b, :] *= g + + return np.asfortranarray(spec_c) + +def spectrogram_to_image(spec, mode='magnitude'): + if mode == 'magnitude': + if np.iscomplexobj(spec): + y = np.abs(spec) + else: + y = spec + y = np.log10(y ** 2 + 1e-8) + elif mode == 'phase': + if np.iscomplexobj(spec): + y = np.angle(spec) + else: + y = spec + + y -= y.min() + y *= 255 / y.max() + img = np.uint8(y) + + if y.ndim == 3: + img = img.transpose(1, 2, 0) + img = np.concatenate([ + np.max(img, axis=2, keepdims=True), img + ], axis=2) + + return img + +def reduce_vocal_aggressively(X, y, softmask): + v = X - y + y_mag_tmp = np.abs(y) + v_mag_tmp = np.abs(v) + + v_mask = v_mag_tmp > y_mag_tmp + y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf) + + return y_mag * np.exp(1.j * np.angle(y)) + +def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32): + mask = y_mask + + try: + if min_range < fade_size * 2: + raise ValueError('min_range must be >= fade_size * 2') + + idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0] + start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0]) + end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1]) + artifact_idx = np.where(end_idx - start_idx > min_range)[0] + weight = np.zeros_like(y_mask) + if len(artifact_idx) > 0: + start_idx = start_idx[artifact_idx] + end_idx = end_idx[artifact_idx] + old_e = None + for s, e in zip(start_idx, end_idx): + if old_e is not None and s - old_e < fade_size: + s = old_e - fade_size * 2 + + if s != 0: + weight[:, :, s:s + fade_size] = np.linspace(0, 1, fade_size) + else: + s -= fade_size + + if e != y_mask.shape[2]: + weight[:, :, e - fade_size:e] = np.linspace(1, 0, fade_size) + else: + e += fade_size + + weight[:, :, s + fade_size:e - fade_size] = 1 + old_e = e + + v_mask = 1 - y_mask + y_mask += weight * v_mask + + mask = y_mask + except Exception as e: + error_name = f'{type(e).__name__}' + traceback_text = ''.join(traceback.format_tb(e.__traceback__)) + message = f'{error_name}: "{e}"\n{traceback_text}"' + print('Post Process Failed: ', message) + + + return mask + +def align_wave_head_and_tail(a, b): + l = min([a[0].size, b[0].size]) + + return a[:l,:l], b[:l,:l] + +def spectrogram_to_wave(spec, hop_length, mid_side, mid_side_b2, reverse, clamp=False): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + wave_left = librosa.istft(spec_left, hop_length=hop_length) + wave_right = librosa.istft(spec_right, hop_length=hop_length) + + if reverse: + return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) + elif mid_side: + return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)]) + elif mid_side_b2: + return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)]) + else: + return np.asfortranarray([wave_left, wave_right]) + +def spectrogram_to_wave_mt(spec, hop_length, mid_side, reverse, mid_side_b2): + import threading + + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + def run_thread(**kwargs): + global wave_left + wave_left = librosa.istft(**kwargs) + + thread = threading.Thread(target=run_thread, kwargs={'stft_matrix': spec_left, 'hop_length': hop_length}) + thread.start() + wave_right = librosa.istft(spec_right, hop_length=hop_length) + thread.join() + + if reverse: + return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) + elif mid_side: + return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)]) + elif mid_side_b2: + return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)]) + else: + return np.asfortranarray([wave_left, wave_right]) + +def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None): + bands_n = len(mp.param['band']) + offset = 0 + + for d in range(1, bands_n + 1): + bp = mp.param['band'][d] + spec_s = np.ndarray(shape=(2, bp['n_fft'] // 2 + 1, spec_m.shape[2]), dtype=complex) + h = bp['crop_stop'] - bp['crop_start'] + spec_s[:, bp['crop_start']:bp['crop_stop'], :] = spec_m[:, offset:offset+h, :] + + offset += h + if d == bands_n: # higher + if extra_bins_h: # if --high_end_process bypass + max_bin = bp['n_fft'] // 2 + spec_s[:, max_bin-extra_bins_h:max_bin, :] = extra_bins[:, :extra_bins_h, :] + if bp['hpf_start'] > 0: + spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1) + if bands_n == 1: + wave = spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']) + else: + wave = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse'])) + else: + sr = mp.param['band'][d+1]['sr'] + if d == 1: # lower + spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop']) + wave = librosa.resample(spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']), bp['sr'], sr, res_type=wav_resolution) + else: # mid + spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1) + spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop']) + wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse'])) + wave = librosa.resample(wave2, bp['sr'], sr, res_type=wav_resolution) + + return wave + +def fft_lp_filter(spec, bin_start, bin_stop): + g = 1.0 + for b in range(bin_start, bin_stop): + g -= 1 / (bin_stop - bin_start) + spec[:, b, :] = g * spec[:, b, :] + + spec[:, bin_stop:, :] *= 0 + + return spec + +def fft_hp_filter(spec, bin_start, bin_stop): + g = 1.0 + for b in range(bin_start, bin_stop, -1): + g -= 1 / (bin_start - bin_stop) + spec[:, b, :] = g * spec[:, b, :] + + spec[:, 0:bin_stop+1, :] *= 0 + + return spec + +def mirroring(a, spec_m, input_high_end, mp): + if 'mirroring' == a: + mirror = np.flip(np.abs(spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :]), 1) + mirror = mirror * np.exp(1.j * np.angle(input_high_end)) + + return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror) + + if 'mirroring2' == a: + mirror = np.flip(np.abs(spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :]), 1) + mi = np.multiply(mirror, input_high_end * 1.7) + + return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi) + +def adjust_aggr(mask, is_non_accom_stem, aggressiveness): + aggr = aggressiveness['value'] + + if aggr != 0: + if is_non_accom_stem: + aggr = 1 - aggr + + aggr = [aggr, aggr] + + if aggressiveness['aggr_correction'] is not None: + aggr[0] += aggressiveness['aggr_correction']['left'] + aggr[1] += aggressiveness['aggr_correction']['right'] + + for ch in range(2): + mask[ch, :aggressiveness['split_bin']] = np.power(mask[ch, :aggressiveness['split_bin']], 1 + aggr[ch] / 3) + mask[ch, aggressiveness['split_bin']:] = np.power(mask[ch, aggressiveness['split_bin']:], 1 + aggr[ch]) + + # if is_non_accom_stem: + # mask = (1.0 - mask) + + return mask + +def stft(wave, nfft, hl): + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + spec_left = librosa.stft(wave_left, nfft, hop_length=hl) + spec_right = librosa.stft(wave_right, nfft, hop_length=hl) + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + +def istft(spec, hl): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + wave_left = librosa.istft(spec_left, hop_length=hl) + wave_right = librosa.istft(spec_right, hop_length=hl) + wave = np.asfortranarray([wave_left, wave_right]) + + return wave + +def spec_effects(wave, algorithm='Default', value=None): + spec = [stft(wave[0],2048,1024), stft(wave[1],2048,1024)] + if algorithm == 'Min_Mag': + v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0]) + wave = istft(v_spec_m,1024) + elif algorithm == 'Max_Mag': + v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0]) + wave = istft(v_spec_m,1024) + elif algorithm == 'Default': + wave = (wave[1] * value) + (wave[0] * (1-value)) + elif algorithm == 'Invert_p': + X_mag = np.abs(spec[0]) + y_mag = np.abs(spec[1]) + max_mag = np.where(X_mag >= y_mag, X_mag, y_mag) + v_spec = spec[1] - max_mag * np.exp(1.j * np.angle(spec[0])) + wave = istft(v_spec,1024) + + return wave + +def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024): + wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length) + + if wave.ndim == 1: + wave = np.asfortranarray([wave,wave]) + + return wave + +def wave_to_spectrogram_no_mp(wave): + + spec = librosa.stft(wave, n_fft=2048, hop_length=1024) + + if spec.ndim == 1: + spec = np.asfortranarray([spec,spec]) + + return spec + +def invert_audio(specs, invert_p=True): + + ln = min([specs[0].shape[2], specs[1].shape[2]]) + specs[0] = specs[0][:,:,:ln] + specs[1] = specs[1][:,:,:ln] + + if invert_p: + X_mag = np.abs(specs[0]) + y_mag = np.abs(specs[1]) + max_mag = np.where(X_mag >= y_mag, X_mag, y_mag) + v_spec = specs[1] - max_mag * np.exp(1.j * np.angle(specs[0])) + else: + specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2) + v_spec = specs[0] - specs[1] + + return v_spec + +def invert_stem(mixture, stem): + + mixture = wave_to_spectrogram_no_mp(mixture) + stem = wave_to_spectrogram_no_mp(stem) + output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem])) + + return -output.T + +def ensembling(a, specs): + for i in range(1, len(specs)): + if i == 1: + spec = specs[0] + + ln = min([spec.shape[2], specs[i].shape[2]]) + spec = spec[:,:,:ln] + specs[i] = specs[i][:,:,:ln] + + if MIN_SPEC == a: + spec = np.where(np.abs(specs[i]) <= np.abs(spec), specs[i], spec) + if MAX_SPEC == a: + spec = np.where(np.abs(specs[i]) >= np.abs(spec), specs[i], spec) + if AVERAGE == a: + spec = np.where(np.abs(specs[i]) == np.abs(spec), specs[i], spec) + + return spec + +def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path): + + wavs_ = [] + + if algorithm == AVERAGE: + output = average_audio(audio_input) + samplerate = 44100 + else: + specs = [] + + for i in range(len(audio_input)): + wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100) + wavs_.append(wave) + spec = wave_to_spectrogram_no_mp(wave) + specs.append(spec) + + wave_shapes = [w.shape[1] for w in wavs_] + target_shape = wavs_[wave_shapes.index(max(wave_shapes))] + + output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs)) + output = to_shape(output, target_shape.shape) + + sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set) + +def to_shape(x, target_shape): + padding_list = [] + for x_dim, target_dim in zip(x.shape, target_shape): + pad_value = (target_dim - x_dim) + pad_tuple = ((0, pad_value)) + padding_list.append(pad_tuple) + + return np.pad(x, tuple(padding_list), mode='constant') + +def to_shape_minimize(x: np.ndarray, target_shape): + + padding_list = [] + for x_dim, target_dim in zip(x.shape, target_shape): + pad_value = (target_dim - x_dim) + pad_tuple = ((0, pad_value)) + padding_list.append(pad_tuple) + + return np.pad(x, tuple(padding_list), mode='constant') + +def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False): + + wav, sr = librosa.load(audio_file, sr=44100, mono=False) + + if wav.ndim == 1: + wav = np.asfortranarray([wav,wav]) + + if is_pitch: + wav_1 = pyrb.pitch_shift(wav[0], sr, rate, rbargs=None) + wav_2 = pyrb.pitch_shift(wav[1], sr, rate, rbargs=None) + else: + wav_1 = pyrb.time_stretch(wav[0], sr, rate, rbargs=None) + wav_2 = pyrb.time_stretch(wav[1], sr, rate, rbargs=None) + + if wav_1.shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1.shape) + if wav_1.shape < wav_2.shape: + wav_1 = to_shape(wav_1, wav_2.shape) + + wav_mix = np.asfortranarray([wav_1, wav_2]) + + sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set) + save_format(export_path) + +def average_audio(audio): + + waves = [] + wave_shapes = [] + final_waves = [] + + for i in range(len(audio)): + wave = librosa.load(audio[i], sr=44100, mono=False) + waves.append(wave[0]) + wave_shapes.append(wave[0].shape[1]) + + wave_shapes_index = wave_shapes.index(max(wave_shapes)) + target_shape = waves[wave_shapes_index] + waves.pop(wave_shapes_index) + final_waves.append(target_shape) + + for n_array in waves: + wav_target = to_shape(n_array, target_shape.shape) + final_waves.append(wav_target) + + waves = sum(final_waves) + waves = waves/len(audio) + + return waves + +def average_dual_sources(wav_1, wav_2, value): + + if wav_1.shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1.shape) + if wav_1.shape < wav_2.shape: + wav_1 = to_shape(wav_1, wav_2.shape) + + wave = (wav_1 * value) + (wav_2 * (1-value)) + + return wave + +def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray): + + if wav_1.shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1.shape) + if wav_1.shape < wav_2.shape: + ln = min([wav_1.shape[1], wav_2.shape[1]]) + wav_2 = wav_2[:,:ln] + + ln = min([wav_1.shape[1], wav_2.shape[1]]) + wav_1 = wav_1[:,:ln] + wav_2 = wav_2[:,:ln] + + return wav_2 + +def align_audio(file1, file2, file2_aligned, file_subtracted, wav_type_set, is_normalization, command_Text, progress_bar_main_var, save_format): + def get_diff(a, b): + corr = np.correlate(a, b, "full") + diff = corr.argmax() - (b.shape[0] - 1) + return diff + + progress_bar_main_var.set(10) + + # read tracks + wav1, sr1 = librosa.load(file1, sr=44100, mono=False) + wav2, sr2 = librosa.load(file2, sr=44100, mono=False) + wav1 = wav1.transpose() + wav2 = wav2.transpose() + + command_Text(f"Audio file shapes: {wav1.shape} / {wav2.shape}\n") + + wav2_org = wav2.copy() + progress_bar_main_var.set(20) + + command_Text("Processing files... \n") + + # pick random position and get diff + + counts = {} # counting up for each diff value + progress = 20 + + check_range = 64 + + base = (64 / check_range) + + for i in range(check_range): + index = int(random.uniform(44100 * 2, min(wav1.shape[0], wav2.shape[0]) - 44100 * 2)) + shift = int(random.uniform(-22050,+22050)) + samp1 = wav1[index :index +44100, 0] # currently use left channel + samp2 = wav2[index+shift:index+shift+44100, 0] + progress += 1 * base + progress_bar_main_var.set(progress) + diff = get_diff(samp1, samp2) + diff -= shift + + if abs(diff) < 22050: + if not diff in counts: + counts[diff] = 0 + counts[diff] += 1 + + # use max counted diff value + max_count = 0 + est_diff = 0 + for diff in counts.keys(): + if counts[diff] > max_count: + max_count = counts[diff] + est_diff = diff + + command_Text(f"Estimated difference is {est_diff} (count: {max_count})\n") + + progress_bar_main_var.set(90) + + audio_files = [] + + def save_aligned_audio(wav2_aligned): + command_Text(f"Aligned File 2 with File 1.\n") + command_Text(f"Saving files... ") + sf.write(file2_aligned, normalize(wav2_aligned, is_normalization), sr2, subtype=wav_type_set) + save_format(file2_aligned) + min_len = min(wav1.shape[0], wav2_aligned.shape[0]) + wav_sub = wav1[:min_len] - wav2_aligned[:min_len] + audio_files.append(file2_aligned) + return min_len, wav_sub + + # make aligned track 2 + if est_diff > 0: + wav2_aligned = np.append(np.zeros((est_diff, 2)), wav2_org, axis=0) + min_len, wav_sub = save_aligned_audio(wav2_aligned) + elif est_diff < 0: + wav2_aligned = wav2_org[-est_diff:] + min_len, wav_sub = save_aligned_audio(wav2_aligned) + else: + command_Text(f"Audio files already aligned.\n") + command_Text(f"Saving inverted track... ") + min_len = min(wav1.shape[0], wav2.shape[0]) + wav_sub = wav1[:min_len] - wav2[:min_len] + + wav_sub = np.clip(wav_sub, -1, +1) + + sf.write(file_subtracted, normalize(wav_sub, is_normalization), sr1, subtype=wav_type_set) + save_format(file_subtracted) + + progress_bar_main_var.set(95) \ No newline at end of file diff --git a/lib_v5/vr_network/__init__.py b/lib_v5/vr_network/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..361b7086d3bebd02a4eb7b20582f2fc5ab299cf2 --- /dev/null +++ b/lib_v5/vr_network/__init__.py @@ -0,0 +1 @@ +# VR init. diff --git a/lib_v5/vr_network/__pycache__/__init__.cpython-310.pyc b/lib_v5/vr_network/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..481b39e2799e07e8cf8509525805c794d48bff6b Binary files /dev/null and b/lib_v5/vr_network/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib_v5/vr_network/__pycache__/layers.cpython-310.pyc b/lib_v5/vr_network/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d5dbd517a72e928f1cb74a015be1d4ab0aa1a52 Binary files /dev/null and b/lib_v5/vr_network/__pycache__/layers.cpython-310.pyc differ diff --git a/lib_v5/vr_network/__pycache__/layers_new.cpython-310.pyc b/lib_v5/vr_network/__pycache__/layers_new.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce5849722d6e06d2913c3942332f2bde5ccb8477 Binary files /dev/null and b/lib_v5/vr_network/__pycache__/layers_new.cpython-310.pyc differ diff --git a/lib_v5/vr_network/__pycache__/model_param_init.cpython-310.pyc b/lib_v5/vr_network/__pycache__/model_param_init.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a8884871d1eaba49de779460c7163987d6b7584 Binary files /dev/null and b/lib_v5/vr_network/__pycache__/model_param_init.cpython-310.pyc differ diff --git a/lib_v5/vr_network/__pycache__/nets.cpython-310.pyc b/lib_v5/vr_network/__pycache__/nets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d71e39f6245f48a27f81cccf3452a44b10bc8ea7 Binary files /dev/null and b/lib_v5/vr_network/__pycache__/nets.cpython-310.pyc differ diff --git a/lib_v5/vr_network/__pycache__/nets_new.cpython-310.pyc b/lib_v5/vr_network/__pycache__/nets_new.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..280f19c5f794f8f1d137d014058930bf46de85f5 Binary files /dev/null and b/lib_v5/vr_network/__pycache__/nets_new.cpython-310.pyc differ diff --git a/lib_v5/vr_network/layers.py b/lib_v5/vr_network/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..5913484a58f5a2f2bf780d6db980885d87950ba4 --- /dev/null +++ b/lib_v5/vr_network/layers.py @@ -0,0 +1,143 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from lib_v5 import spec_utils + +class Conv2DBNActiv(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False), + nn.BatchNorm2d(nout), + activ() + ) + + def __call__(self, x): + return self.conv(x) + +class SeperableConv2DBNActiv(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(SeperableConv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, nin, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + groups=nin, + bias=False), + nn.Conv2d( + nin, nout, + kernel_size=1, + bias=False), + nn.BatchNorm2d(nout), + activ() + ) + + def __call__(self, x): + return self.conv(x) + + +class Encoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ) + + def __call__(self, x): + skip = self.conv1(x) + h = self.conv2(skip) + + return h, skip + + +class Decoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False): + super(Decoder, self).__init__() + self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + h = self.conv(x) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + + def __init__(self, nn_architecture, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + ) + + self.nn_architecture = nn_architecture + self.six_layer = [129605] + self.seven_layer = [537238, 537227, 33966] + + extra_conv = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ) + + self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + self.conv3 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[0], dilations[0], activ=activ) + self.conv4 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[1], dilations[1], activ=activ) + self.conv5 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ) + + if self.nn_architecture in self.six_layer: + self.conv6 = extra_conv + nin_x = 6 + elif self.nn_architecture in self.seven_layer: + self.conv6 = extra_conv + self.conv7 = extra_conv + nin_x = 7 + else: + nin_x = 5 + + self.bottleneck = nn.Sequential( + Conv2DBNActiv(nin * nin_x, nout, 1, 1, 0, activ=activ), + nn.Dropout2d(0.1) + ) + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + + if self.nn_architecture in self.six_layer: + feat6 = self.conv6(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6), dim=1) + elif self.nn_architecture in self.seven_layer: + feat6 = self.conv6(x) + feat7 = self.conv7(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1) + else: + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + + bottle = self.bottleneck(out) + return bottle diff --git a/lib_v5/vr_network/layers_new.py b/lib_v5/vr_network/layers_new.py new file mode 100644 index 0000000000000000000000000000000000000000..33181dd8cdd244f91a684698e075da2ab3ef668e --- /dev/null +++ b/lib_v5/vr_network/layers_new.py @@ -0,0 +1,126 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from lib_v5 import spec_utils + +class Conv2DBNActiv(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + nin, nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False), + nn.BatchNorm2d(nout), + activ() + ) + + def __call__(self, x): + return self.conv(x) + +class Encoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ) + self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + + def __call__(self, x): + h = self.conv1(x) + h = self.conv2(h) + + return h + + +class Decoder(nn.Module): + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False): + super(Decoder, self).__init__() + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, x, skip=None): + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + + if skip is not None: + skip = spec_utils.crop_center(skip, x) + x = torch.cat([x, skip], dim=1) + + h = self.conv1(x) + # h = self.conv2(h) + + if self.dropout is not None: + h = self.dropout(h) + + return h + + +class ASPPModule(nn.Module): + + def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False): + super(ASPPModule, self).__init__() + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) + ) + self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) + self.conv3 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def forward(self, x): + _, _, h, w = x.size() + feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + out = self.bottleneck(out) + + if self.dropout is not None: + out = self.dropout(out) + + return out + + +class LSTMModule(nn.Module): + + def __init__(self, nin_conv, nin_lstm, nout_lstm): + super(LSTMModule, self).__init__() + self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0) + self.lstm = nn.LSTM( + input_size=nin_lstm, + hidden_size=nout_lstm // 2, + bidirectional=True + ) + self.dense = nn.Sequential( + nn.Linear(nout_lstm, nin_lstm), + nn.BatchNorm1d(nin_lstm), + nn.ReLU() + ) + + def forward(self, x): + N, _, nbins, nframes = x.size() + h = self.conv(x)[:, 0] # N, nbins, nframes + h = h.permute(2, 0, 1) # nframes, N, nbins + h, _ = self.lstm(h) + h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins + h = h.reshape(nframes, N, 1, nbins) + h = h.permute(1, 2, 3, 0) + + return h diff --git a/lib_v5/vr_network/model_param_init.py b/lib_v5/vr_network/model_param_init.py new file mode 100644 index 0000000000000000000000000000000000000000..c859d2fbb2db2e94d8838921f421c655f7f31f34 --- /dev/null +++ b/lib_v5/vr_network/model_param_init.py @@ -0,0 +1,59 @@ +import json +import pathlib + +default_param = {} +default_param['bins'] = 768 +default_param['unstable_bins'] = 9 # training only +default_param['reduction_bins'] = 762 # training only +default_param['sr'] = 44100 +default_param['pre_filter_start'] = 757 +default_param['pre_filter_stop'] = 768 +default_param['band'] = {} + + +default_param['band'][1] = { + 'sr': 11025, + 'hl': 128, + 'n_fft': 960, + 'crop_start': 0, + 'crop_stop': 245, + 'lpf_start': 61, # inference only + 'res_type': 'polyphase' +} + +default_param['band'][2] = { + 'sr': 44100, + 'hl': 512, + 'n_fft': 1536, + 'crop_start': 24, + 'crop_stop': 547, + 'hpf_start': 81, # inference only + 'res_type': 'sinc_best' +} + + +def int_keys(d): + r = {} + for k, v in d: + if k.isdigit(): + k = int(k) + r[k] = v + return r + + +class ModelParameters(object): + def __init__(self, config_path=''): + if '.pth' == pathlib.Path(config_path).suffix: + import zipfile + + with zipfile.ZipFile(config_path, 'r') as zip: + self.param = json.loads(zip.read('param.json'), object_pairs_hook=int_keys) + elif '.json' == pathlib.Path(config_path).suffix: + with open(config_path, 'r') as f: + self.param = json.loads(f.read(), object_pairs_hook=int_keys) + else: + self.param = default_param + + for k in ['mid_side', 'mid_side_b', 'mid_side_b2', 'stereo_w', 'stereo_n', 'reverse']: + if not k in self.param: + self.param[k] = False \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr16000_hl512.json b/lib_v5/vr_network/modelparams/1band_sr16000_hl512.json new file mode 100644 index 0000000000000000000000000000000000000000..72cb4499867ad2827185e85687f06fb73d33eced --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr16000_hl512.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 16000, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 16000, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr32000_hl512.json b/lib_v5/vr_network/modelparams/1band_sr32000_hl512.json new file mode 100644 index 0000000000000000000000000000000000000000..3c00ecf0a105e55a6a86a3c32db301a2635b5b41 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr32000_hl512.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 32000, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "kaiser_fast" + } + }, + "sr": 32000, + "pre_filter_start": 1000, + "pre_filter_stop": 1021 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr33075_hl384.json b/lib_v5/vr_network/modelparams/1band_sr33075_hl384.json new file mode 100644 index 0000000000000000000000000000000000000000..55666ac9a8d0547751fb4b4d3bffb1ee2c956913 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr33075_hl384.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 33075, + "hl": 384, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 33075, + "pre_filter_start": 1000, + "pre_filter_stop": 1021 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json new file mode 100644 index 0000000000000000000000000000000000000000..665abe20eb3cc39fe0f8493dad8f25f6ef634a14 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 1024, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl256.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl256.json new file mode 100644 index 0000000000000000000000000000000000000000..0e8b16f89b0231d06eabe8d2f7c2670c7caa2272 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl256.json @@ -0,0 +1,19 @@ +{ + "bins": 256, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 256, + "n_fft": 512, + "crop_start": 0, + "crop_stop": 256, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 256, + "pre_filter_stop": 256 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl512.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl512.json new file mode 100644 index 0000000000000000000000000000000000000000..3b38fcaf60ba204e03a47f5bd3f5bcfe75e1983a --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl512.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 1024, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 1024 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json new file mode 100644 index 0000000000000000000000000000000000000000..630df3524e340f43a1ddb7b33ff02cc91fc1cb47 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json @@ -0,0 +1,19 @@ +{ + "bins": 1024, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 512, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 700, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 1023, + "pre_filter_stop": 700 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json b/lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json new file mode 100644 index 0000000000000000000000000000000000000000..be493b511baab69e9c211933114be454e1901141 --- /dev/null +++ b/lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json @@ -0,0 +1,19 @@ +{ + "bins": 512, + "unstable_bins": 0, + "reduction_bins": 0, + "band": { + "1": { + "sr": 44100, + "hl": 512, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 512, + "hpf_start": -1, + "res_type": "sinc_best" + } + }, + "sr": 44100, + "pre_filter_start": 511, + "pre_filter_stop": 512 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/2band_32000.json b/lib_v5/vr_network/modelparams/2band_32000.json new file mode 100644 index 0000000000000000000000000000000000000000..ab9cf1150a818eb6252105408311be0a40d423b3 --- /dev/null +++ b/lib_v5/vr_network/modelparams/2band_32000.json @@ -0,0 +1,30 @@ +{ + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 705, + "band": { + "1": { + "sr": 6000, + "hl": 66, + "n_fft": 512, + "crop_start": 0, + "crop_stop": 240, + "lpf_start": 60, + "lpf_stop": 118, + "res_type": "sinc_fastest" + }, + "2": { + "sr": 32000, + "hl": 352, + "n_fft": 1024, + "crop_start": 22, + "crop_stop": 505, + "hpf_start": 44, + "hpf_stop": 23, + "res_type": "sinc_medium" + } + }, + "sr": 32000, + "pre_filter_start": 710, + "pre_filter_stop": 731 +} diff --git a/lib_v5/vr_network/modelparams/2band_44100_lofi.json b/lib_v5/vr_network/modelparams/2band_44100_lofi.json new file mode 100644 index 0000000000000000000000000000000000000000..7faa216d7b49aeece24123dbdd868847a1dbc03c --- /dev/null +++ b/lib_v5/vr_network/modelparams/2band_44100_lofi.json @@ -0,0 +1,30 @@ +{ + "bins": 512, + "unstable_bins": 7, + "reduction_bins": 510, + "band": { + "1": { + "sr": 11025, + "hl": 160, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 192, + "lpf_start": 41, + "lpf_stop": 139, + "res_type": "sinc_fastest" + }, + "2": { + "sr": 44100, + "hl": 640, + "n_fft": 1024, + "crop_start": 10, + "crop_stop": 320, + "hpf_start": 47, + "hpf_stop": 15, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 510, + "pre_filter_stop": 512 +} diff --git a/lib_v5/vr_network/modelparams/2band_48000.json b/lib_v5/vr_network/modelparams/2band_48000.json new file mode 100644 index 0000000000000000000000000000000000000000..7e78175052b09cb1a32345e54006475992712f9a --- /dev/null +++ b/lib_v5/vr_network/modelparams/2band_48000.json @@ -0,0 +1,30 @@ +{ + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 705, + "band": { + "1": { + "sr": 6000, + "hl": 66, + "n_fft": 512, + "crop_start": 0, + "crop_stop": 240, + "lpf_start": 60, + "lpf_stop": 240, + "res_type": "sinc_fastest" + }, + "2": { + "sr": 48000, + "hl": 528, + "n_fft": 1536, + "crop_start": 22, + "crop_stop": 505, + "hpf_start": 82, + "hpf_stop": 22, + "res_type": "sinc_medium" + } + }, + "sr": 48000, + "pre_filter_start": 710, + "pre_filter_stop": 731 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/3band_44100.json b/lib_v5/vr_network/modelparams/3band_44100.json new file mode 100644 index 0000000000000000000000000000000000000000..d881d767ff83fbac0e18dfe2587ef16925b29b3c --- /dev/null +++ b/lib_v5/vr_network/modelparams/3band_44100.json @@ -0,0 +1,42 @@ +{ + "bins": 768, + "unstable_bins": 5, + "reduction_bins": 733, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 278, + "lpf_start": 28, + "lpf_stop": 140, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 256, + "n_fft": 768, + "crop_start": 14, + "crop_stop": 322, + "hpf_start": 70, + "hpf_stop": 14, + "lpf_start": 283, + "lpf_stop": 314, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 131, + "crop_stop": 313, + "hpf_start": 154, + "hpf_stop": 141, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 757, + "pre_filter_stop": 768 +} diff --git a/lib_v5/vr_network/modelparams/3band_44100_mid.json b/lib_v5/vr_network/modelparams/3band_44100_mid.json new file mode 100644 index 0000000000000000000000000000000000000000..77ec198573b19f36519a028a509767d30764c0e2 --- /dev/null +++ b/lib_v5/vr_network/modelparams/3band_44100_mid.json @@ -0,0 +1,43 @@ +{ + "mid_side": true, + "bins": 768, + "unstable_bins": 5, + "reduction_bins": 733, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 278, + "lpf_start": 28, + "lpf_stop": 140, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 256, + "n_fft": 768, + "crop_start": 14, + "crop_stop": 322, + "hpf_start": 70, + "hpf_stop": 14, + "lpf_start": 283, + "lpf_stop": 314, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 131, + "crop_stop": 313, + "hpf_start": 154, + "hpf_stop": 141, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 757, + "pre_filter_stop": 768 +} diff --git a/lib_v5/vr_network/modelparams/3band_44100_msb2.json b/lib_v5/vr_network/modelparams/3band_44100_msb2.json new file mode 100644 index 0000000000000000000000000000000000000000..85ee8a7d44541c9176e85ea3dce8728d34990938 --- /dev/null +++ b/lib_v5/vr_network/modelparams/3band_44100_msb2.json @@ -0,0 +1,43 @@ +{ + "mid_side_b2": true, + "bins": 640, + "unstable_bins": 7, + "reduction_bins": 565, + "band": { + "1": { + "sr": 11025, + "hl": 108, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 187, + "lpf_start": 92, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 216, + "n_fft": 768, + "crop_start": 0, + "crop_stop": 212, + "hpf_start": 68, + "hpf_stop": 34, + "lpf_start": 174, + "lpf_stop": 209, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 432, + "n_fft": 640, + "crop_start": 66, + "crop_stop": 307, + "hpf_start": 86, + "hpf_stop": 72, + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 639, + "pre_filter_stop": 640 +} diff --git a/lib_v5/vr_network/modelparams/4band_44100.json b/lib_v5/vr_network/modelparams/4band_44100.json new file mode 100644 index 0000000000000000000000000000000000000000..df123754204372aa50d464fbe9102a401f48cc73 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100.json @@ -0,0 +1,54 @@ +{ + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} diff --git a/lib_v5/vr_network/modelparams/4band_44100_mid.json b/lib_v5/vr_network/modelparams/4band_44100_mid.json new file mode 100644 index 0000000000000000000000000000000000000000..e91b699eb63d3382c3b9e9edf46d40ed91d6122b --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_mid.json @@ -0,0 +1,55 @@ +{ + "bins": 768, + "unstable_bins": 7, + "mid_side": true, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} diff --git a/lib_v5/vr_network/modelparams/4band_44100_msb.json b/lib_v5/vr_network/modelparams/4band_44100_msb.json new file mode 100644 index 0000000000000000000000000000000000000000..f852f280ec9d98fc1b65cec688290eaafec61b84 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_msb.json @@ -0,0 +1,55 @@ +{ + "mid_side_b": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_44100_msb2.json b/lib_v5/vr_network/modelparams/4band_44100_msb2.json new file mode 100644 index 0000000000000000000000000000000000000000..f852f280ec9d98fc1b65cec688290eaafec61b84 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_msb2.json @@ -0,0 +1,55 @@ +{ + "mid_side_b": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_44100_reverse.json b/lib_v5/vr_network/modelparams/4band_44100_reverse.json new file mode 100644 index 0000000000000000000000000000000000000000..7a07d5541bd83dc1caa20b531c3b43a2ffccac88 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_reverse.json @@ -0,0 +1,55 @@ +{ + "reverse": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_44100_sw.json b/lib_v5/vr_network/modelparams/4band_44100_sw.json new file mode 100644 index 0000000000000000000000000000000000000000..ba0cf342106de793e6ec3e876854c7fd451fbf76 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_44100_sw.json @@ -0,0 +1,55 @@ +{ + "stereo_w": true, + "bins": 768, + "unstable_bins": 7, + "reduction_bins": 668, + "band": { + "1": { + "sr": 11025, + "hl": 128, + "n_fft": 1024, + "crop_start": 0, + "crop_stop": 186, + "lpf_start": 37, + "lpf_stop": 73, + "res_type": "polyphase" + }, + "2": { + "sr": 11025, + "hl": 128, + "n_fft": 512, + "crop_start": 4, + "crop_stop": 185, + "hpf_start": 36, + "hpf_stop": 18, + "lpf_start": 93, + "lpf_stop": 185, + "res_type": "polyphase" + }, + "3": { + "sr": 22050, + "hl": 256, + "n_fft": 512, + "crop_start": 46, + "crop_stop": 186, + "hpf_start": 93, + "hpf_stop": 46, + "lpf_start": 164, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 512, + "n_fft": 768, + "crop_start": 121, + "crop_stop": 382, + "hpf_start": 138, + "hpf_stop": 123, + "res_type": "sinc_medium" + } + }, + "sr": 44100, + "pre_filter_start": 740, + "pre_filter_stop": 768 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_v2.json b/lib_v5/vr_network/modelparams/4band_v2.json new file mode 100644 index 0000000000000000000000000000000000000000..33281a0cf9916fc33558ddfda7a0287a2547faf4 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_v2.json @@ -0,0 +1,54 @@ +{ + "bins": 672, + "unstable_bins": 8, + "reduction_bins": 637, + "band": { + "1": { + "sr": 7350, + "hl": 80, + "n_fft": 640, + "crop_start": 0, + "crop_stop": 85, + "lpf_start": 25, + "lpf_stop": 53, + "res_type": "polyphase" + }, + "2": { + "sr": 7350, + "hl": 80, + "n_fft": 320, + "crop_start": 4, + "crop_stop": 87, + "hpf_start": 25, + "hpf_stop": 12, + "lpf_start": 31, + "lpf_stop": 62, + "res_type": "polyphase" + }, + "3": { + "sr": 14700, + "hl": 160, + "n_fft": 512, + "crop_start": 17, + "crop_stop": 216, + "hpf_start": 48, + "hpf_stop": 24, + "lpf_start": 139, + "lpf_stop": 210, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 480, + "n_fft": 960, + "crop_start": 78, + "crop_stop": 383, + "hpf_start": 130, + "hpf_stop": 86, + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 668, + "pre_filter_stop": 672 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_v2_sn.json b/lib_v5/vr_network/modelparams/4band_v2_sn.json new file mode 100644 index 0000000000000000000000000000000000000000..2e5c770fe188779bf6b0873190b7a324d6a867b2 --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_v2_sn.json @@ -0,0 +1,55 @@ +{ + "bins": 672, + "unstable_bins": 8, + "reduction_bins": 637, + "band": { + "1": { + "sr": 7350, + "hl": 80, + "n_fft": 640, + "crop_start": 0, + "crop_stop": 85, + "lpf_start": 25, + "lpf_stop": 53, + "res_type": "polyphase" + }, + "2": { + "sr": 7350, + "hl": 80, + "n_fft": 320, + "crop_start": 4, + "crop_stop": 87, + "hpf_start": 25, + "hpf_stop": 12, + "lpf_start": 31, + "lpf_stop": 62, + "res_type": "polyphase" + }, + "3": { + "sr": 14700, + "hl": 160, + "n_fft": 512, + "crop_start": 17, + "crop_stop": 216, + "hpf_start": 48, + "hpf_stop": 24, + "lpf_start": 139, + "lpf_stop": 210, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 480, + "n_fft": 960, + "crop_start": 78, + "crop_stop": 383, + "hpf_start": 130, + "hpf_stop": 86, + "convert_channels": "stereo_n", + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 668, + "pre_filter_stop": 672 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/4band_v3.json b/lib_v5/vr_network/modelparams/4band_v3.json new file mode 100644 index 0000000000000000000000000000000000000000..edb908b8853c6359d1e98ae381888d1a9906ca0f --- /dev/null +++ b/lib_v5/vr_network/modelparams/4band_v3.json @@ -0,0 +1,54 @@ +{ + "bins": 672, + "unstable_bins": 8, + "reduction_bins": 530, + "band": { + "1": { + "sr": 7350, + "hl": 80, + "n_fft": 640, + "crop_start": 0, + "crop_stop": 85, + "lpf_start": 25, + "lpf_stop": 53, + "res_type": "polyphase" + }, + "2": { + "sr": 7350, + "hl": 80, + "n_fft": 320, + "crop_start": 4, + "crop_stop": 87, + "hpf_start": 25, + "hpf_stop": 12, + "lpf_start": 31, + "lpf_stop": 62, + "res_type": "polyphase" + }, + "3": { + "sr": 14700, + "hl": 160, + "n_fft": 512, + "crop_start": 17, + "crop_stop": 216, + "hpf_start": 48, + "hpf_stop": 24, + "lpf_start": 139, + "lpf_stop": 210, + "res_type": "polyphase" + }, + "4": { + "sr": 44100, + "hl": 480, + "n_fft": 960, + "crop_start": 78, + "crop_stop": 383, + "hpf_start": 130, + "hpf_stop": 86, + "res_type": "kaiser_fast" + } + }, + "sr": 44100, + "pre_filter_start": 668, + "pre_filter_stop": 672 +} \ No newline at end of file diff --git a/lib_v5/vr_network/modelparams/ensemble.json b/lib_v5/vr_network/modelparams/ensemble.json new file mode 100644 index 0000000000000000000000000000000000000000..ee69beb46fc82f34619c5e48761e329fcabbbd00 --- /dev/null +++ b/lib_v5/vr_network/modelparams/ensemble.json @@ -0,0 +1,43 @@ +{ + "mid_side_b2": true, + "bins": 1280, + "unstable_bins": 7, + "reduction_bins": 565, + "band": { + "1": { + "sr": 11025, + "hl": 108, + "n_fft": 2048, + "crop_start": 0, + "crop_stop": 374, + "lpf_start": 92, + "lpf_stop": 186, + "res_type": "polyphase" + }, + "2": { + "sr": 22050, + "hl": 216, + "n_fft": 1536, + "crop_start": 0, + "crop_stop": 424, + "hpf_start": 68, + "hpf_stop": 34, + "lpf_start": 348, + "lpf_stop": 418, + "res_type": "polyphase" + }, + "3": { + "sr": 44100, + "hl": 432, + "n_fft": 1280, + "crop_start": 132, + "crop_stop": 614, + "hpf_start": 172, + "hpf_stop": 144, + "res_type": "polyphase" + } + }, + "sr": 44100, + "pre_filter_start": 1280, + "pre_filter_stop": 1280 +} \ No newline at end of file diff --git a/lib_v5/vr_network/nets.py b/lib_v5/vr_network/nets.py new file mode 100644 index 0000000000000000000000000000000000000000..064cad9e4a2a43681c09336d3ce8c3d1149bbd29 --- /dev/null +++ b/lib_v5/vr_network/nets.py @@ -0,0 +1,166 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from . import layers + +class BaseASPPNet(nn.Module): + + def __init__(self, nn_architecture, nin, ch, dilations=(4, 8, 16)): + super(BaseASPPNet, self).__init__() + self.nn_architecture = nn_architecture + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + if self.nn_architecture == 129605: + self.enc5 = layers.Encoder(ch * 8, ch * 16, 3, 2, 1) + self.aspp = layers.ASPPModule(nn_architecture, ch * 16, ch * 32, dilations) + self.dec5 = layers.Decoder(ch * (16 + 32), ch * 16, 3, 1, 1) + else: + self.aspp = layers.ASPPModule(nn_architecture, ch * 8, ch * 16, dilations) + + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, x): + h, e1 = self.enc1(x) + h, e2 = self.enc2(h) + h, e3 = self.enc3(h) + h, e4 = self.enc4(h) + + if self.nn_architecture == 129605: + h, e5 = self.enc5(h) + h = self.aspp(h) + h = self.dec5(h, e5) + else: + h = self.aspp(h) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = self.dec1(h, e1) + + return h + +def determine_model_capacity(n_fft_bins, nn_architecture): + + sp_model_arch = [31191, 33966, 129605] + hp_model_arch = [123821, 123812] + hp2_model_arch = [537238, 537227] + + if nn_architecture in sp_model_arch: + model_capacity_data = [ + (2, 16), + (2, 16), + (18, 8, 1, 1, 0), + (8, 16), + (34, 16, 1, 1, 0), + (16, 32), + (32, 2, 1), + (16, 2, 1), + (16, 2, 1), + ] + + if nn_architecture in hp_model_arch: + model_capacity_data = [ + (2, 32), + (2, 32), + (34, 16, 1, 1, 0), + (16, 32), + (66, 32, 1, 1, 0), + (32, 64), + (64, 2, 1), + (32, 2, 1), + (32, 2, 1), + ] + + if nn_architecture in hp2_model_arch: + model_capacity_data = [ + (2, 64), + (2, 64), + (66, 32, 1, 1, 0), + (32, 64), + (130, 64, 1, 1, 0), + (64, 128), + (128, 2, 1), + (64, 2, 1), + (64, 2, 1), + ] + + cascaded = CascadedASPPNet + model = cascaded(n_fft_bins, model_capacity_data, nn_architecture) + + return model + +class CascadedASPPNet(nn.Module): + + def __init__(self, n_fft, model_capacity_data, nn_architecture): + super(CascadedASPPNet, self).__init__() + self.stg1_low_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[0]) + self.stg1_high_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[1]) + + self.stg2_bridge = layers.Conv2DBNActiv(*model_capacity_data[2]) + self.stg2_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[3]) + + self.stg3_bridge = layers.Conv2DBNActiv(*model_capacity_data[4]) + self.stg3_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[5]) + + self.out = nn.Conv2d(*model_capacity_data[6], bias=False) + self.aux1_out = nn.Conv2d(*model_capacity_data[7], bias=False) + self.aux2_out = nn.Conv2d(*model_capacity_data[8], bias=False) + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, x): + mix = x.detach() + x = x.clone() + + x = x[:, :, :self.max_bin] + + bandw = x.size()[2] // 2 + aux1 = torch.cat([ + self.stg1_low_band_net(x[:, :, :bandw]), + self.stg1_high_band_net(x[:, :, bandw:]) + ], dim=2) + + h = torch.cat([x, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(h)) + + h = torch.cat([x, aux1, aux2], dim=1) + h = self.stg3_full_band_net(self.stg3_bridge(h)) + + mask = torch.sigmoid(self.out(h)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode='replicate') + + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode='replicate') + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode='replicate') + return mask * mix, aux1 * mix, aux2 * mix + else: + return mask# * mix + + def predict_mask(self, x): + mask = self.forward(x) + + if self.offset > 0: + mask = mask[:, :, :, self.offset:-self.offset] + + return mask \ No newline at end of file diff --git a/lib_v5/vr_network/nets_new.py b/lib_v5/vr_network/nets_new.py new file mode 100644 index 0000000000000000000000000000000000000000..db8260a8e8b8b7d0ebd6b54a43eff76eaea08e94 --- /dev/null +++ b/lib_v5/vr_network/nets_new.py @@ -0,0 +1,125 @@ +import torch +from torch import nn +import torch.nn.functional as F +from . import layers_new as layers + +class BaseNet(nn.Module): + + def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))): + super(BaseNet, self).__init__() + self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1) + self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1) + self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1) + self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1) + self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1) + + self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True) + + self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1) + self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1) + self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1) + self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm) + self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1) + + def __call__(self, x): + e1 = self.enc1(x) + e2 = self.enc2(e1) + e3 = self.enc3(e2) + e4 = self.enc4(e3) + e5 = self.enc5(e4) + + h = self.aspp(e5) + + h = self.dec4(h, e4) + h = self.dec3(h, e3) + h = self.dec2(h, e2) + h = torch.cat([h, self.lstm_dec2(h)], dim=1) + h = self.dec1(h, e1) + + return h + +class CascadedNet(nn.Module): + + def __init__(self, n_fft, nn_arch_size, nout=32, nout_lstm=128): + super(CascadedNet, self).__init__() + + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + self.nin_lstm = self.max_bin // 2 + self.offset = 64 + nout = 64 if nn_arch_size == 218409 else nout + + self.stg1_low_band_net = nn.Sequential( + BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm), + layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0) + ) + + self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2) + + self.stg2_low_band_net = nn.Sequential( + BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm), + layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0) + ) + self.stg2_high_band_net = BaseNet(nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2) + + self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm) + + self.out = nn.Conv2d(nout, 2, 1, bias=False) + self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False) + + def forward(self, x): + x = x[:, :, :self.max_bin] + + bandw = x.size()[2] // 2 + l1_in = x[:, :, :bandw] + h1_in = x[:, :, bandw:] + l1 = self.stg1_low_band_net(l1_in) + h1 = self.stg1_high_band_net(h1_in) + aux1 = torch.cat([l1, h1], dim=2) + + l2_in = torch.cat([l1_in, l1], dim=1) + h2_in = torch.cat([h1_in, h1], dim=1) + l2 = self.stg2_low_band_net(l2_in) + h2 = self.stg2_high_band_net(h2_in) + aux2 = torch.cat([l2, h2], dim=2) + + f3_in = torch.cat([x, aux1, aux2], dim=1) + f3 = self.stg3_full_band_net(f3_in) + + mask = torch.sigmoid(self.out(f3)) + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode='replicate' + ) + + if self.training: + aux = torch.cat([aux1, aux2], dim=1) + aux = torch.sigmoid(self.aux_out(aux)) + aux = F.pad( + input=aux, + pad=(0, 0, 0, self.output_bin - aux.size()[2]), + mode='replicate' + ) + return mask, aux + else: + return mask + + def predict_mask(self, x): + mask = self.forward(x) + + if self.offset > 0: + mask = mask[:, :, :, self.offset:-self.offset] + assert mask.size()[3] > 0 + + return mask + + def predict(self, x): + mask = self.forward(x) + pred_mag = x * mask + + if self.offset > 0: + pred_mag = pred_mag[:, :, :, self.offset:-self.offset] + assert pred_mag.size()[3] > 0 + + return pred_mag diff --git a/models/Demucs_Models/model_data/model_data.json b/models/Demucs_Models/model_data/model_data.json new file mode 100644 index 0000000000000000000000000000000000000000..5588db2b4fb4b0fa9e4f0e3f6faa50d2385d3c62 --- /dev/null +++ b/models/Demucs_Models/model_data/model_data.json @@ -0,0 +1,219 @@ +{ + "0ddfc0eb5792638ad5dc27850236c246": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "26d308f91f3423a67dc69a6d12a8793d": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 8192, + "primary_stem": "Other" + }, + "2cdd429caac38f0194b133884160f2c6": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "2f5501189a2f6db6349916fabe8c90de": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "398580b6d5d973af3120df54cee6759d": { + "compensate": 1.75, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "488b3e6f8bd3717d9d7c428476be2d75": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "4910e7827f335048bdac11fa967772f9": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 7, + "mdx_n_fft_scale_set": 4096, + "primary_stem": "Drums" + }, + "53c4baf4d12c3e6c3831bb8f5b532b93": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "5d343409ef0df48c7d78cce9f0106781": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "5f6483271e1efb9bfb59e4a3e6d4d098": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "65ab5919372a128e4167f5e01a8fda85": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 8192, + "primary_stem": "Other" + }, + "6703e39f36f18aa7855ee1047765621d": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 16384, + "primary_stem": "Bass" + }, + "6b31de20e84392859a3d09d43f089515": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "867595e9de46f6ab699008295df62798": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "a3cd63058945e777505c01d2507daf37": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "b33d9b3950b6cbf5fe90a32608924700": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "c3b29bdce8c4fa17ec609e16220330ab": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 16384, + "primary_stem": "Bass" + }, + "ceed671467c1f64ebdfac8a2490d0d52": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "d2a1376f310e4f7fa37fb9b5774eb701": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "d7bff498db9324db933d913388cba6be": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "d94058f8c7f1fae4164868ae8ae66b20": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "dc41ede5961d50f277eb846db17f5319": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 4096, + "primary_stem": "Drums" + }, + "e5572e58abf111f80d8241d2e44e7fa4": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "e7324c873b1f615c35c1967f912db92a": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "1c56ec0224f1d559c42fd6fd2a67b154": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 5120, + "primary_stem": "Instrumental" + }, + "f2df6d6863d8f435436d8b561594ff49": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "b06327a00d5e5fbc7d96e1781bbdb596": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Instrumental" + }, + "94ff780b977d3ca07c7a343dab2e25dd": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Instrumental" + }, + "73492b58195c3b52d34590d5474452f6": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "1d64a6d2c30f709b8c9b4ce1366d96ee": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 5120, + "primary_stem": "Instrumental" + }, + "203f2a3955221b64df85a41af87cf8f0": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Instrumental" + } +} \ No newline at end of file diff --git a/models/Demucs_Models/model_data/model_name_mapper.json b/models/Demucs_Models/model_data/model_name_mapper.json new file mode 100644 index 0000000000000000000000000000000000000000..249531114fbbc5eff774c70a76d832834b268e8e --- /dev/null +++ b/models/Demucs_Models/model_data/model_name_mapper.json @@ -0,0 +1,34 @@ +{ + "tasnet.th": "v1 | Tasnet", + "tasnet_extra.th": "v1 | Tasnet_extra", + "demucs.th": "v1 | Demucs", + "demucs_extra.th": "v1 | Demucs_extra", + "light.th": "v1 | Light", + "light_extra.th": "v1 | Light_extra", + "tasnet.th.gz": "v1 | Tasnet.gz", + "tasnet_extra.th.gz": "v1 | Tasnet_extra.gz", + "demucs.th.gz": "v1 | Demucs_extra.gz", + "light.th.gz": "v1 | Light.gz", + "light_extra.th.gz": "v1 | Light_extra.gz", + "tasnet-beb46fac.th": "v2 | Tasnet", + "tasnet_extra-df3777b2.th": "v2 | Tasnet_extra", + "demucs48_hq-28a1282c.th": "v2 | Demucs48_hq", + "demucs-e07c671f.th": "v2 | Demucs", + "demucs_extra-3646af93.th": "v2 | Demucs_extra", + "demucs_unittest-09ebc15f.th": "v2 | Demucs_unittest", + "mdx.yaml": "v3 | mdx", + "mdx_extra.yaml": "v3 | mdx_extra", + "mdx_extra_q.yaml": "v3 | mdx_extra_q", + "mdx_q.yaml": "v3 | mdx_q", + "repro_mdx_a.yaml": "v3 | repro_mdx_a", + "repro_mdx_a_hybrid_only.yaml": "v3 | repro_mdx_a_hybrid", + "repro_mdx_a_time_only.yaml": "v3 | repro_mdx_a_time", + "UVR_Demucs_Model_1.yaml": "v3 | UVR_Model_1", + "UVR_Demucs_Model_2.yaml": "v3 | UVR_Model_2", + "UVR_Demucs_Model_Bag.yaml": "v3 | UVR_Model_Bag", + "hdemucs_mmi.yaml": "v4 | hdemucs_mmi", + "htdemucs.yaml": "v4 | htdemucs", + "htdemucs_ft.yaml": "v4 | htdemucs_ft", + "htdemucs_6s.yaml": "v4 | htdemucs_6s", + "UVR_Demucs_Model_ht.yaml": "v4 | UVR_Model_ht" +} \ No newline at end of file diff --git a/models/Demucs_Models/v3_v4_repo/demucs_models.txt b/models/Demucs_Models/v3_v4_repo/demucs_models.txt new file mode 100644 index 0000000000000000000000000000000000000000..584c89e54b6df8077049cfc49686092c41b886a1 --- /dev/null +++ b/models/Demucs_Models/v3_v4_repo/demucs_models.txt @@ -0,0 +1 @@ +Demucs v3 and v4 models go here. \ No newline at end of file diff --git a/models/MDX_Net_Models/model_data/model_data.json b/models/MDX_Net_Models/model_data/model_data.json new file mode 100644 index 0000000000000000000000000000000000000000..3973d7b2b181510214e5ed6b1832c0d3e15b1c13 --- /dev/null +++ b/models/MDX_Net_Models/model_data/model_data.json @@ -0,0 +1,254 @@ +{ + "0ddfc0eb5792638ad5dc27850236c246": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "26d308f91f3423a67dc69a6d12a8793d": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 8192, + "primary_stem": "Other" + }, + "2cdd429caac38f0194b133884160f2c6": { + "compensate": 1.045, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "2f5501189a2f6db6349916fabe8c90de": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "398580b6d5d973af3120df54cee6759d": { + "compensate": 1.75, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "488b3e6f8bd3717d9d7c428476be2d75": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "4910e7827f335048bdac11fa967772f9": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 7, + "mdx_n_fft_scale_set": 4096, + "primary_stem": "Drums" + }, + "53c4baf4d12c3e6c3831bb8f5b532b93": { + "compensate": 1.043, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "5d343409ef0df48c7d78cce9f0106781": { + "compensate": 1.075, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "5f6483271e1efb9bfb59e4a3e6d4d098": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "65ab5919372a128e4167f5e01a8fda85": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 8192, + "primary_stem": "Other" + }, + "6703e39f36f18aa7855ee1047765621d": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 16384, + "primary_stem": "Bass" + }, + "6b31de20e84392859a3d09d43f089515": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "867595e9de46f6ab699008295df62798": { + "compensate": 1.030, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "a3cd63058945e777505c01d2507daf37": { + "compensate": 1.030, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "b33d9b3950b6cbf5fe90a32608924700": { + "compensate": 1.030, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "c3b29bdce8c4fa17ec609e16220330ab": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 16384, + "primary_stem": "Bass" + }, + "ceed671467c1f64ebdfac8a2490d0d52": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "d2a1376f310e4f7fa37fb9b5774eb701": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "d7bff498db9324db933d913388cba6be": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "d94058f8c7f1fae4164868ae8ae66b20": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Vocals" + }, + "dc41ede5961d50f277eb846db17f5319": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 4096, + "primary_stem": "Drums" + }, + "e5572e58abf111f80d8241d2e44e7fa4": { + "compensate": 1.028, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "e7324c873b1f615c35c1967f912db92a": { + "compensate": 1.030, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "1c56ec0224f1d559c42fd6fd2a67b154": { + "compensate": 1.025, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 5120, + "primary_stem": "Instrumental" + }, + "f2df6d6863d8f435436d8b561594ff49": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + }, + "b06327a00d5e5fbc7d96e1781bbdb596": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Instrumental" + }, + "94ff780b977d3ca07c7a343dab2e25dd": { + "compensate": 1.039, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Instrumental" + }, + "73492b58195c3b52d34590d5474452f6": { + "compensate": 1.043, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Vocals" + }, + "1d64a6d2c30f709b8c9b4ce1366d96ee": { + "compensate": 1.035, + "mdx_dim_f_set": 2048, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 5120, + "primary_stem": "Instrumental" + }, + "203f2a3955221b64df85a41af87cf8f0": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Instrumental" + }, + "291c2049608edb52648b96e27eb80e95": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Instrumental" + }, + "ead8d05dab12ec571d67549b3aab03fc": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Instrumental" + }, + "cc63408db3d80b4d85b0287d1d7c9632": { + "compensate": 1.033, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Instrumental" + }, + "cd5b2989ad863f116c855db1dfe24e39": { + "compensate": 1.035, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 9, + "mdx_n_fft_scale_set": 6144, + "primary_stem": "Other" + }, + "b6bccda408a436db8500083ef3491e8b": { + "compensate": 1.020, + "mdx_dim_f_set": 3072, + "mdx_dim_t_set": 8, + "mdx_n_fft_scale_set": 7680, + "primary_stem": "Instrumental" + } +} diff --git a/models/MDX_Net_Models/model_data/model_name_mapper.json b/models/MDX_Net_Models/model_data/model_name_mapper.json new file mode 100644 index 0000000000000000000000000000000000000000..6ff2c51f5a47a0394f3e4a6da6f57d494db01bda --- /dev/null +++ b/models/MDX_Net_Models/model_data/model_name_mapper.json @@ -0,0 +1,17 @@ +{ + "UVR_MDXNET_1_9703": "UVR-MDX-NET 1", + "UVR_MDXNET_2_9682": "UVR-MDX-NET 2", + "UVR_MDXNET_3_9662": "UVR-MDX-NET 3", + "UVR_MDXNET_KARA": "UVR-MDX-NET Karaoke", + "UVR_MDXNET_Main": "UVR-MDX-NET Main", + "UVR-MDX-NET-Inst_1": "UVR-MDX-NET Inst 1", + "UVR-MDX-NET-Inst_2": "UVR-MDX-NET Inst 2", + "UVR-MDX-NET-Inst_3": "UVR-MDX-NET Inst 3", + "UVR-MDX-NET-Inst_Main": "UVR-MDX-NET Inst Main", + "UVR-MDX-NET-Inst_HQ_1": "UVR-MDX-NET Inst HQ 1", + "UVR-MDX-NET-Inst_HQ_2": "UVR-MDX-NET Inst HQ 2", + "UVR_MDXNET_KARA_2": "UVR-MDX-NET Karaoke 2", + "Kim_Vocal_1": "Kim Vocal 1", + "Kim_Inst": "Kim Inst", + "Reverb_HQ_By_FoxJoy": "Reverb HQ" +} diff --git a/models/VR_Models/model_data/model_data.json b/models/VR_Models/model_data/model_data.json new file mode 100644 index 0000000000000000000000000000000000000000..6501437f5523ea65bdc27b8fc5d1b11472db1ae2 --- /dev/null +++ b/models/VR_Models/model_data/model_data.json @@ -0,0 +1,110 @@ +{ + "0d0e6d143046b0eecc41a22e60224582": { + "vr_model_param": "3band_44100_mid", + "primary_stem": "Instrumental" + }, + "18b52f873021a0af556fb4ecd552bb8e": { + "vr_model_param": "2band_32000", + "primary_stem": "Instrumental" + }, + "1fc66027c82b499c7d8f55f79e64cadc": { + "vr_model_param": "2band_32000", + "primary_stem": "Instrumental" + }, + "2aa34fbc01f8e6d2bf509726481e7142": { + "vr_model_param": "4band_44100", + "primary_stem": "No Piano" + }, + "3e18f639b11abea7361db1a4a91c2559": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "570b5f50054609a17741369a35007ddd": { + "vr_model_param": "4band_v3", + "primary_stem": "Instrumental" + }, + "5a6e24c1b530f2dab045a522ef89b751": { + "vr_model_param": "1band_sr44100_hl512", + "primary_stem": "Instrumental" + }, + "6b5916069a49be3fe29d4397ecfd73fa": { + "vr_model_param": "3band_44100_msb2", + "primary_stem": "Instrumental" + }, + "74b3bc5fa2b69f29baf7839b858bc679": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "827213b316df36b52a1f3d04fec89369": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "911d4048eee7223eca4ee0efb7d29256": { + "vr_model_param": "4band_44100", + "primary_stem": "Vocals" + }, + "941f3f7f0b0341f12087aacdfef644b1": { + "vr_model_param": "4band_v2", + "primary_stem": "Instrumental" + }, + "a02827cf69d75781a35c0e8a327f3195": { + "vr_model_param": "1band_sr33075_hl384", + "primary_stem": "Instrumental" + }, + "b165fbff113c959dba5303b74c6484bc": { + "vr_model_param": "3band_44100", + "primary_stem": "Instrumental" + }, + "b5f988cd3e891dca7253bf5f0f3427c7": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "b99c35723bc35cb11ed14a4780006a80": { + "vr_model_param": "1band_sr44100_hl1024", + "primary_stem": "Instrumental" + }, + "ba02fd25b71d620eebbdb49e18e4c336": { + "vr_model_param": "3band_44100_mid", + "primary_stem": "Instrumental" + }, + "c4476ef424d8cba65f38d8d04e8514e2": { + "vr_model_param": "3band_44100_msb2", + "primary_stem": "Instrumental" + }, + "da2d37b8be2972e550a409bae08335aa": { + "vr_model_param": "4band_44100", + "primary_stem": "Vocals" + }, + "db57205d3133e39df8e050b435a78c80": { + "vr_model_param": "4band_44100", + "primary_stem": "Instrumental" + }, + "ea83b08e32ec2303456fe50659035f69": { + "vr_model_param": "4band_v3", + "primary_stem": "Instrumental" + }, + "f6ea8473ff86017b5ebd586ccacf156b": { + "vr_model_param": "4band_v2_sn", + "primary_stem": "Instrumental" + }, + "fd297a61eafc9d829033f8b987c39a3d": { + "vr_model_param": "1band_sr32000_hl512", + "primary_stem": "Instrumental" + }, + "0ec76fd9e65f81d8b4fbd13af4826ed8": { + "vr_model_param": "4band_v3", + "primary_stem": "No Woodwinds" + }, + "6857b2972e1754913aad0c9a1678c753": { + "vr_model_param": "4band_v3", + "primary_stem": "No Other", + "nout":48, + "nout_lstm":128 + }, + "f200a145434efc7dcf0cd093f517ed52": { + "vr_model_param": "4band_v3", + "primary_stem": "No Other", + "nout":48, + "nout_lstm":128 + } +} diff --git a/models/download_checks.json b/models/download_checks.json new file mode 100644 index 0000000000000000000000000000000000000000..c2bc27cce8f26ce05fac2e646aa4962e362b15f0 --- /dev/null +++ b/models/download_checks.json @@ -0,0 +1,201 @@ +{ + "current_version": "UVR_Patch_3_31_23_5_5", + "current_version_mac": "UVR_Patch_3_31_23_5_5", + "current_version_linux": "UVR_Patch_3_31_23_5_5", + "vr_download_list": { + "VR Arch Single Model v5: 1_HP-UVR": "1_HP-UVR.pth", + "VR Arch Single Model v5: 2_HP-UVR": "2_HP-UVR.pth", + "VR Arch Single Model v5: 3_HP-Vocal-UVR": "3_HP-Vocal-UVR.pth", + "VR Arch Single Model v5: 4_HP-Vocal-UVR": "4_HP-Vocal-UVR.pth", + "VR Arch Single Model v5: 5_HP-Karaoke-UVR": "5_HP-Karaoke-UVR.pth", + "VR Arch Single Model v5: 6_HP-Karaoke-UVR": "6_HP-Karaoke-UVR.pth", + "VR Arch Single Model v5: 7_HP2-UVR": "7_HP2-UVR.pth", + "VR Arch Single Model v5: 8_HP2-UVR": "8_HP2-UVR.pth", + "VR Arch Single Model v5: 9_HP2-UVR": "9_HP2-UVR.pth", + "VR Arch Single Model v5: 10_SP-UVR-2B-32000-1": "10_SP-UVR-2B-32000-1.pth", + "VR Arch Single Model v5: 11_SP-UVR-2B-32000-2": "11_SP-UVR-2B-32000-2.pth", + "VR Arch Single Model v5: 12_SP-UVR-3B-44100": "12_SP-UVR-3B-44100.pth", + "VR Arch Single Model v5: 13_SP-UVR-4B-44100-1": "13_SP-UVR-4B-44100-1.pth", + "VR Arch Single Model v5: 14_SP-UVR-4B-44100-2": "14_SP-UVR-4B-44100-2.pth", + "VR Arch Single Model v5: 15_SP-UVR-MID-44100-1": "15_SP-UVR-MID-44100-1.pth", + "VR Arch Single Model v5: 16_SP-UVR-MID-44100-2": "16_SP-UVR-MID-44100-2.pth", + "VR Arch Single Model v5: 17_HP-Wind_Inst-UVR": "17_HP-Wind_Inst-UVR.pth", + "VR Arch Single Model v5: UVR-De-Echo-Aggressive by FoxJoy": "UVR-De-Echo-Aggressive.pth", + "VR Arch Single Model v5: UVR-De-Echo-Normal by FoxJoy": "UVR-De-Echo-Normal.pth", + "VR Arch Single Model v4: MGM_HIGHEND_v4": "MGM_HIGHEND_v4.pth", + "VR Arch Single Model v4: MGM_LOWEND_A_v4": "MGM_LOWEND_A_v4.pth", + "VR Arch Single Model v4: MGM_LOWEND_B_v4": "MGM_LOWEND_B_v4.pth", + "VR Arch Single Model v4: MGM_MAIN_v4": "MGM_MAIN_v4.pth" + }, + + "mdx_download_list": { + "MDX-Net Model: UVR-MDX-NET Inst HQ 1": "UVR-MDX-NET-Inst_HQ_1.onnx", + "MDX-Net Model: UVR-MDX-NET Inst HQ 2": "UVR-MDX-NET-Inst_HQ_2.onnx", + "MDX-Net Model: UVR-MDX-NET Main": "UVR_MDXNET_Main.onnx", + "MDX-Net Model: UVR-MDX-NET Inst Main": "UVR-MDX-NET-Inst_Main.onnx", + "MDX-Net Model: UVR-MDX-NET 1": "UVR_MDXNET_1_9703.onnx", + "MDX-Net Model: UVR-MDX-NET 2": "UVR_MDXNET_2_9682.onnx", + "MDX-Net Model: UVR-MDX-NET 3": "UVR_MDXNET_3_9662.onnx", + "MDX-Net Model: UVR-MDX-NET Inst 1": "UVR-MDX-NET-Inst_1.onnx", + "MDX-Net Model: UVR-MDX-NET Inst 2": "UVR-MDX-NET-Inst_2.onnx", + "MDX-Net Model: UVR-MDX-NET Inst 3": "UVR-MDX-NET-Inst_3.onnx", + "MDX-Net Model: UVR-MDX-NET Karaoke": "UVR_MDXNET_KARA.onnx", + "MDX-Net Model: UVR-MDX-NET Karaoke 2": "UVR_MDXNET_KARA_2.onnx", + "MDX-Net Model: UVR_MDXNET_9482": "UVR_MDXNET_9482.onnx", + "MDX-Net Model: Kim Vocal 1": "Kim_Vocal_1.onnx", + "MDX-Net Model: Kim Inst": "Kim_Inst.onnx", + "MDX-Net Model: Reverb HQ By FoxJoy": "Reverb_HQ_By_FoxJoy.onnx", + "MDX-Net Model: kuielab_a_vocals": "kuielab_a_vocals.onnx", + "MDX-Net Model: kuielab_a_other": "kuielab_a_other.onnx", + "MDX-Net Model: kuielab_a_bass": "kuielab_a_bass.onnx", + "MDX-Net Model: kuielab_a_drums": "kuielab_a_drums.onnx", + "MDX-Net Model: kuielab_b_vocals": "kuielab_b_vocals.onnx", + "MDX-Net Model: kuielab_b_other": "kuielab_b_other.onnx", + "MDX-Net Model: kuielab_b_bass": "kuielab_b_bass.onnx", + "MDX-Net Model: kuielab_b_drums": "kuielab_b_drums.onnx" + }, + + "demucs_download_list":{ + + "Demucs v4: htdemucs_ft":{ + "f7e0c4bc-ba3fe64a.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th", + "d12395a8-e57c48e6.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th", + "92cfc3b6-ef3bcb9c.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th", + "04573f0d-f3cf25b2.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th", + "htdemucs_ft.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_ft.yaml" + }, + + "Demucs v4: htdemucs":{ + "955717e8-8726e21a.th": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/955717e8-8726e21a.th", + "htdemucs.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs.yaml" + }, + + "Demucs v4: hdemucs_mmi":{ + "75fc33f5-1941ce65.th": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/75fc33f5-1941ce65.th", + "hdemucs_mmi.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/hdemucs_mmi.yaml" + }, + "Demucs v4: htdemucs_6s":{ + "5c90dfd2-34c22ccb.th": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/5c90dfd2-34c22ccb.th", + "htdemucs_6s.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_6s.yaml" + }, + "Demucs v3: mdx":{ + "0d19c1c6-0f06f20e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/0d19c1c6-0f06f20e.th", + "7ecf8ec1-70f50cc9.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/7ecf8ec1-70f50cc9.th", + "c511e2ab-fe698775.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/c511e2ab-fe698775.th", + "7d865c68-3d5dd56b.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/7d865c68-3d5dd56b.th", + "mdx.yaml": "https://raw.githubusercontent.com/facebookresearch/demucs/main/demucs/remote/mdx.yaml" + }, + + "Demucs v3: mdx_q":{ + "6b9c2ca1-3fd82607.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/6b9c2ca1-3fd82607.th", + "b72baf4e-8778635e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/b72baf4e-8778635e.th", + "42e558d4-196e0e1b.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/42e558d4-196e0e1b.th", + "305bc58f-18378783.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/305bc58f-18378783.th", + "mdx_q.yaml": "https://raw.githubusercontent.com/facebookresearch/demucs/main/demucs/remote/mdx_q.yaml" + }, + + "Demucs v3: mdx_extra":{ + "e51eebcc-c1b80bdd.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/e51eebcc-c1b80bdd.th", + "a1d90b5c-ae9d2452.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/a1d90b5c-ae9d2452.th", + "5d2d6c55-db83574e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/5d2d6c55-db83574e.th", + "cfa93e08-61801ae1.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/cfa93e08-61801ae1.th", + "mdx_extra.yaml": "https://raw.githubusercontent.com/facebookresearch/demucs/main/demucs/remote/mdx_extra.yaml" + }, + + "Demucs v3: mdx_extra_q": { + "83fc094f-4a16d450.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/83fc094f-4a16d450.th", + "464b36d7-e5a9386e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/464b36d7-e5a9386e.th", + "14fc6a69-a89dd0ee.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/14fc6a69-a89dd0ee.th", + "7fd6ef75-a905dd85.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/7fd6ef75-a905dd85.th", + "mdx_extra_q.yaml": "https://raw.githubusercontent.com/facebookresearch/demucs/main/demucs/remote/mdx_extra_q.yaml" + }, + + "Demucs v3: UVR Model":{ + "ebf34a2db.th": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/ebf34a2db.th", + "UVR_Demucs_Model_1.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR_Demucs_Model_1.yaml" + }, + + "Demucs v3: repro_mdx_a":{ + "9a6b4851-03af0aa6.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/9a6b4851-03af0aa6.th", + "1ef250f1-592467ce.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/1ef250f1-592467ce.th", + "fa0cb7f9-100d8bf4.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/fa0cb7f9-100d8bf4.th", + "902315c2-b39ce9c9.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/902315c2-b39ce9c9.th", + "repro_mdx_a.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/repro_mdx_a.yaml" + }, + + "Demucs v3: repro_mdx_a_time_only":{ + "9a6b4851-03af0aa6.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/9a6b4851-03af0aa6.th", + "1ef250f1-592467ce.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/1ef250f1-592467ce.th", + "repro_mdx_a_time_only.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/repro_mdx_a_time_only.yaml" + }, + + "Demucs v3: repro_mdx_a_hybrid_only":{ + "fa0cb7f9-100d8bf4.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/fa0cb7f9-100d8bf4.th", + "902315c2-b39ce9c9.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/902315c2-b39ce9c9.th", + "repro_mdx_a_hybrid_only.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/repro_mdx_a_hybrid_only.yaml" + }, + + "Demucs v2: demucs": { + "demucs-e07c671f.th": "https://dl.fbaipublicfiles.com/demucs/v3.0/demucs-e07c671f.th" + }, + + "Demucs v2: demucs_extra": { + "demucs_extra-3646af93.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/demucs_extra-3646af93.th" + }, + + "Demucs v2: demucs48_hq": { + "demucs48_hq-28a1282c.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/demucs48_hq-28a1282c.th" + }, + + "Demucs v2: tasnet": { + "tasnet-beb46fac.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/tasnet-beb46fac.th" + }, + + "Demucs v2: tasnet_extra": { + "tasnet_extra-df3777b2.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/tasnet_extra-df3777b2.th" + }, + + "Demucs v2: demucs_unittest": { + "demucs_unittest-09ebc15f.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/demucs_unittest-09ebc15f.th" + }, + + "Demucs v1: demucs": { + "demucs.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/demucs.th" + }, + + "Demucs v1: demucs_extra": { + "demucs_extra.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/demucs_extra.th" + }, + + "Demucs v1: light": { + "light.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/light.th" + }, + + "Demucs v1: light_extra": { + "light_extra.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/light_extra.th" + }, + + "Demucs v1: tasnet": { + "tasnet.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/tasnet.th" + }, + + "Demucs v1: tasnet_extra": { + "tasnet_extra.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/tasnet_extra.th" + } + }, + + "mdx_download_vip_list": { + "MDX-Net Model VIP: UVR-MDX-NET_Main_340": "UVR-MDX-NET_Main_340.onnx", + "MDX-Net Model VIP: UVR-MDX-NET_Main_390": "UVR-MDX-NET_Main_390.onnx", + "MDX-Net Model VIP: UVR-MDX-NET_Main_406": "UVR-MDX-NET_Main_406.onnx", + "MDX-Net Model VIP: UVR-MDX-NET_Main_427": "UVR-MDX-NET_Main_427.onnx", + "MDX-Net Model VIP: UVR-MDX-NET_Main_438": "UVR-MDX-NET_Main_438.onnx", + "MDX-Net Model VIP: UVR-MDX-NET_Inst_82_beta": "UVR-MDX-NET_Inst_82_beta.onnx", + "MDX-Net Model VIP: UVR-MDX-NET_Inst_90_beta": "UVR-MDX-NET_Inst_90_beta.onnx", + "MDX-Net Model VIP: UVR-MDX-NET_Inst_187_beta": "UVR-MDX-NET_Inst_187_beta.onnx", + "MDX-Net Model VIP: UVR-MDX-NET-Inst_full_292": "UVR-MDX-NET-Inst_full_292.onnx" + }, + + "vr_download_vip_list": [], + "demucs_download_vip_list": [] +} diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..5ea7c73c110e77879f55380ebd7d49e2bf2f2430 --- /dev/null +++ b/packages.txt @@ -0,0 +1,3 @@ +git-lfs +aria2 -y +ffmpeg \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2da18fea066172956a268917a19d93afc803950c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,43 @@ +altgraph==0.17.3 +audioread==3.0.0 +certifi==2022.12.7 +cffi==1.15.1 +cryptography==3.4.6 +diffq==0.2.3 +git+https://github.com/NathanEpstein/Dora@7119720 +einops==0.6.0 +future==0.18.2 +julius==0.2.7 +kthread==0.2.3 +librosa==0.9.2 +llvmlite==0.39.1 +natsort==8.2.0 +numba==0.56.4 +numpy==1.23.4 +omegaconf==2.2.3 +opencv-python==4.6.0.66 +onnx +onnxruntime==1.13.1 +Pillow==9.3.0 +playsound==1.3.0 +psutil==5.9.4 +pydub==0.25.1 +pyglet==1.5.23 +pyperclip==1.8.2 +pyrubberband==0.3.0 +pytorch_lightning==2.0.0 +PyYAML==6.0 +resampy==0.2.2 +scipy==1.9.3 +soundfile==0.11.0 +soundstretch==1.2 +torch==1.13.1 +tqdm +urllib3==1.26.12 +wget==3.2 +samplerate==0.1.0 +screeninfo==0.8.1 +PySoundFile==0.9.0.post1; sys_platform != 'windows' +SoundFile==0.9.0; sys_platform == 'windows' + +gradio >= 3.19 \ No newline at end of file diff --git a/separate.py b/separate.py new file mode 100644 index 0000000000000000000000000000000000000000..c0adade0b8c0dc48b4c503ee66cd1ce50be281cd --- /dev/null +++ b/separate.py @@ -0,0 +1,942 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +from demucs.apply import apply_model, demucs_segments +from demucs.hdemucs import HDemucs +from demucs.model_v2 import auto_load_demucs_model_v2 +from demucs.pretrained import get_model as _gm +from demucs.utils import apply_model_v1 +from demucs.utils import apply_model_v2 +from lib_v5 import spec_utils +from lib_v5.vr_network import nets +from lib_v5.vr_network import nets_new +#from lib_v5.vr_network.model_param_init import ModelParameters +from pathlib import Path +from gui_data.constants import * +from gui_data.error_handling import * +import audioread +import gzip +import librosa +import math +import numpy as np +import onnxruntime as ort +import os +import torch +import warnings +import pydub +import soundfile as sf +import traceback +import lib_v5.mdxnet as MdxnetSet + +if TYPE_CHECKING: + from UVR import ModelData + +warnings.filterwarnings("ignore") +cpu = torch.device('cpu') + +class SeperateAttributes: + def __init__(self, model_data: ModelData, process_data: dict, main_model_primary_stem_4_stem=None, main_process_method=None): + + self.list_all_models: list + self.process_data = process_data + self.progress_value = 0 + self.set_progress_bar = process_data['set_progress_bar'] + self.write_to_console = process_data['write_to_console'] + self.audio_file = process_data['audio_file'] + self.audio_file_base = process_data['audio_file_base'] + self.export_path = process_data['export_path'] + self.cached_source_callback = process_data['cached_source_callback'] + self.cached_model_source_holder = process_data['cached_model_source_holder'] + self.is_4_stem_ensemble = process_data['is_4_stem_ensemble'] + self.list_all_models = process_data['list_all_models'] + self.process_iteration = process_data['process_iteration'] + self.mixer_path = model_data.mixer_path + self.model_samplerate = model_data.model_samplerate + self.model_capacity = model_data.model_capacity + self.is_vr_51_model = model_data.is_vr_51_model + self.is_pre_proc_model = model_data.is_pre_proc_model + self.is_secondary_model_activated = model_data.is_secondary_model_activated if not self.is_pre_proc_model else False + self.is_secondary_model = model_data.is_secondary_model if not self.is_pre_proc_model else True + self.process_method = model_data.process_method + self.model_path = model_data.model_path + self.model_name = model_data.model_name + self.model_basename = model_data.model_basename + self.wav_type_set = model_data.wav_type_set + self.mp3_bit_set = model_data.mp3_bit_set + self.save_format = model_data.save_format + self.is_gpu_conversion = model_data.is_gpu_conversion + self.is_normalization = model_data.is_normalization + self.is_primary_stem_only = model_data.is_primary_stem_only if not self.is_secondary_model else model_data.is_primary_model_primary_stem_only + self.is_secondary_stem_only = model_data.is_secondary_stem_only if not self.is_secondary_model else model_data.is_primary_model_secondary_stem_only + self.is_ensemble_mode = model_data.is_ensemble_mode + self.secondary_model = model_data.secondary_model # + self.primary_model_primary_stem = model_data.primary_model_primary_stem + self.primary_stem = model_data.primary_stem # + self.secondary_stem = model_data.secondary_stem # + self.is_invert_spec = model_data.is_invert_spec # + self.is_mixer_mode = model_data.is_mixer_mode # + self.secondary_model_scale = model_data.secondary_model_scale # + self.is_demucs_pre_proc_model_inst_mix = model_data.is_demucs_pre_proc_model_inst_mix # + self.primary_source_map = {} + self.secondary_source_map = {} + self.primary_source = None + self.secondary_source = None + self.secondary_source_primary = None + self.secondary_source_secondary = None + + if not model_data.process_method == DEMUCS_ARCH_TYPE: + if process_data['is_ensemble_master'] and not self.is_4_stem_ensemble: + if not model_data.ensemble_primary_stem == self.primary_stem: + self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only + + if self.is_secondary_model and not process_data['is_ensemble_master']: + if not self.primary_model_primary_stem == self.primary_stem and not main_model_primary_stem_4_stem: + self.is_primary_stem_only, self.is_secondary_stem_only = self.is_secondary_stem_only, self.is_primary_stem_only + + if main_model_primary_stem_4_stem: + self.is_primary_stem_only = True if main_model_primary_stem_4_stem == self.primary_stem else False + self.is_secondary_stem_only = True if not main_model_primary_stem_4_stem == self.primary_stem else False + + if self.is_pre_proc_model: + self.is_primary_stem_only = True if self.primary_stem == INST_STEM else False + self.is_secondary_stem_only = True if self.secondary_stem == INST_STEM else False + + if model_data.process_method == MDX_ARCH_TYPE: + self.is_mdx_ckpt = model_data.is_mdx_ckpt + self.primary_model_name, self.primary_sources = self.cached_source_callback(MDX_ARCH_TYPE, model_name=self.model_basename) + self.is_denoise = model_data.is_denoise + self.mdx_batch_size = model_data.mdx_batch_size + self.compensate = model_data.compensate + self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set + self.n_fft = model_data.mdx_n_fft_scale_set + self.chunks = model_data.chunks + self.margin = model_data.margin + self.adjust = 1 + self.dim_c = 4 + self.hop = 1024 + + if self.is_gpu_conversion >= 0 and torch.cuda.is_available(): + self.device, self.run_type = torch.device('cuda:0'), ['CUDAExecutionProvider'] + else: + self.device, self.run_type = torch.device('cpu'), ['CPUExecutionProvider'] + + if model_data.process_method == DEMUCS_ARCH_TYPE: + self.demucs_stems = model_data.demucs_stems if not main_process_method in [MDX_ARCH_TYPE, VR_ARCH_TYPE] else None + self.secondary_model_4_stem = model_data.secondary_model_4_stem + self.secondary_model_4_stem_scale = model_data.secondary_model_4_stem_scale + self.primary_stem = model_data.ensemble_primary_stem if process_data['is_ensemble_master'] else model_data.primary_stem + self.secondary_stem = model_data.ensemble_secondary_stem if process_data['is_ensemble_master'] else model_data.secondary_stem + self.is_chunk_demucs = model_data.is_chunk_demucs + self.segment = model_data.segment + self.demucs_version = model_data.demucs_version + self.demucs_source_list = model_data.demucs_source_list + self.demucs_source_map = model_data.demucs_source_map + self.is_demucs_combine_stems = model_data.is_demucs_combine_stems + self.demucs_stem_count = model_data.demucs_stem_count + self.pre_proc_model = model_data.pre_proc_model + + if self.is_secondary_model and not process_data['is_ensemble_master']: + if not self.demucs_stem_count == 2 and model_data.primary_model_primary_stem == INST_STEM: + self.primary_stem = VOCAL_STEM + self.secondary_stem = INST_STEM + else: + self.primary_stem = model_data.primary_model_primary_stem + self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem] + + if self.is_chunk_demucs: + self.chunks_demucs = model_data.chunks_demucs + self.margin_demucs = model_data.margin_demucs + else: + self.chunks_demucs = 0 + self.margin_demucs = 44100 + + self.shifts = model_data.shifts + self.is_split_mode = model_data.is_split_mode if not self.demucs_version == DEMUCS_V4 else True + self.overlap = model_data.overlap + self.primary_model_name, self.primary_sources = self.cached_source_callback(DEMUCS_ARCH_TYPE, model_name=self.model_basename) + + if model_data.process_method == VR_ARCH_TYPE: + self.primary_model_name, self.primary_sources = self.cached_source_callback(VR_ARCH_TYPE, model_name=self.model_basename) + self.mp = model_data.vr_model_param + self.high_end_process = model_data.is_high_end_process + self.is_tta = model_data.is_tta + self.is_post_process = model_data.is_post_process + self.is_gpu_conversion = model_data.is_gpu_conversion + self.batch_size = model_data.batch_size + self.window_size = model_data.window_size + self.input_high_end_h = None + self.post_process_threshold = model_data.post_process_threshold + self.aggressiveness = {'value': model_data.aggression_setting, + 'split_bin': self.mp.param['band'][1]['crop_stop'], + 'aggr_correction': self.mp.param.get('aggr_correction')} + + def start_inference_console_write(self): + + if self.is_secondary_model and not self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_2_SEC(self.process_method, self.model_basename)) + + if self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_2_PRE(self.process_method, self.model_basename)) + + def running_inference_console_write(self, is_no_write=False): + + self.write_to_console(DONE, base_text='') if not is_no_write else None + self.set_progress_bar(0.05) if not is_no_write else None + + if self.is_secondary_model and not self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_1_SEC) + elif self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_1_PRE) + else: + self.write_to_console(INFERENCE_STEP_1) + + def running_inference_progress_bar(self, length, is_match_mix=False): + if not is_match_mix: + self.progress_value += 1 + + if (0.8/length*self.progress_value) >= 0.8: + length = self.progress_value + 1 + + self.set_progress_bar(0.1, (0.8/length*self.progress_value)) + + def load_cached_sources(self, is_4_stem_demucs=False): + + if self.is_secondary_model and not self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_2_SEC_CACHED_MODOEL(self.process_method, self.model_basename)) + elif self.is_pre_proc_model: + self.write_to_console(INFERENCE_STEP_2_PRE_CACHED_MODOEL(self.process_method, self.model_basename)) + else: + self.write_to_console(INFERENCE_STEP_2_PRIMARY_CACHED) + + if not is_4_stem_demucs: + primary_stem, secondary_stem = gather_sources(self.primary_stem, self.secondary_stem, self.primary_sources) + + return primary_stem, secondary_stem + + def cache_source(self, secondary_sources): + + model_occurrences = self.list_all_models.count(self.model_basename) + + if not model_occurrences <= 1: + if self.process_method == MDX_ARCH_TYPE: + self.cached_model_source_holder(MDX_ARCH_TYPE, secondary_sources, self.model_basename) + + if self.process_method == VR_ARCH_TYPE: + self.cached_model_source_holder(VR_ARCH_TYPE, secondary_sources, self.model_basename) + + if self.process_method == DEMUCS_ARCH_TYPE: + self.cached_model_source_holder(DEMUCS_ARCH_TYPE, secondary_sources, self.model_basename) + + def write_audio(self, stem_path, stem_source, samplerate, secondary_model_source=None, model_scale=None): + + if not self.is_secondary_model: + if self.is_secondary_model_activated: + if isinstance(secondary_model_source, np.ndarray): + secondary_model_scale = model_scale if model_scale else self.secondary_model_scale + stem_source = spec_utils.average_dual_sources(stem_source, secondary_model_source, secondary_model_scale) + + sf.write(stem_path, stem_source, samplerate, subtype=self.wav_type_set) + save_format(stem_path, self.save_format, self.mp3_bit_set) if not self.is_ensemble_mode else None + + self.write_to_console(DONE, base_text='') + self.set_progress_bar(0.95) + + def run_mixer(self, mix, sources): + try: + if self.is_mixer_mode and len(sources) == 4: + mixer = MdxnetSet.Mixer(self.device, self.mixer_path).eval() + with torch.no_grad(): + mix = torch.tensor(mix, dtype=torch.float32) + sources_ = torch.tensor(sources).detach() + x = torch.cat([sources_, mix.unsqueeze(0)], 0) + sources_ = mixer(x) + final_source = np.array(sources_) + else: + final_source = sources + except Exception as e: + error_name = f'{type(e).__name__}' + traceback_text = ''.join(traceback.format_tb(e.__traceback__)) + message = f'{error_name}: "{e}"\n{traceback_text}"' + print('Mixer Failed: ', message) + final_source = sources + + return final_source + +class SeperateMDX(SeperateAttributes): + + def seperate(self): + samplerate = 44100 + + if self.primary_model_name == self.model_basename and self.primary_sources: + self.primary_source, self.secondary_source = self.load_cached_sources() + else: + self.start_inference_console_write() + + if self.is_mdx_ckpt: + model_params = torch.load(self.model_path, map_location=lambda storage, loc: storage)['hyper_parameters'] + self.dim_c, self.hop = model_params['dim_c'], model_params['hop_length'] + separator = MdxnetSet.ConvTDFNet(**model_params) + self.model_run = separator.load_from_checkpoint(self.model_path).to(self.device).eval() + else: + ort_ = ort.InferenceSession(self.model_path, providers=self.run_type) + self.model_run = lambda spek:ort_.run(None, {'input': spek.cpu().numpy()})[0] + + self.initialize_model_settings() + self.running_inference_console_write() + mdx_net_cut = True if self.primary_stem in MDX_NET_FREQ_CUT else False + mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks, self.margin, mdx_net_cut=mdx_net_cut) + source = self.demix_base(mix, is_ckpt=self.is_mdx_ckpt)[0] + self.write_to_console(DONE, base_text='') + + if self.is_secondary_model_activated: + if self.secondary_model: + self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method) + + if not self.is_secondary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav') + if not isinstance(self.primary_source, np.ndarray): + self.primary_source = spec_utils.normalize(source, self.is_normalization).T + self.primary_source_map = {self.primary_stem: self.primary_source} + self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary) + + if not self.is_primary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav') + if not isinstance(self.secondary_source, np.ndarray): + raw_mix = self.demix_base(raw_mix, is_match_mix=True)[0] if mdx_net_cut else raw_mix + self.secondary_source, raw_mix = spec_utils.normalize_two_stem(source*self.compensate, raw_mix, self.is_normalization) + + if self.is_invert_spec: + self.secondary_source = spec_utils.invert_stem(raw_mix, self.secondary_source) + else: + self.secondary_source = (-self.secondary_source.T+raw_mix.T) + + self.secondary_source_map = {self.secondary_stem: self.secondary_source} + self.write_audio(secondary_stem_path, self.secondary_source, samplerate, self.secondary_source_secondary) + + torch.cuda.empty_cache() + secondary_sources = {**self.primary_source_map, **self.secondary_source_map} + + self.cache_source(secondary_sources) + + if self.is_secondary_model: + return secondary_sources + + def initialize_model_settings(self): + self.n_bins = self.n_fft//2+1 + self.trim = self.n_fft//2 + self.chunk_size = self.hop * (self.dim_t-1) + self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device) + self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins-self.dim_f, self.dim_t]).to(self.device) + self.gen_size = self.chunk_size-2*self.trim + + def initialize_mix(self, mix, is_ckpt=False): + if is_ckpt: + pad = self.gen_size + self.trim - ((mix.shape[-1]) % self.gen_size) + mixture = np.concatenate((np.zeros((2, self.trim), dtype='float32'),mix, np.zeros((2, pad), dtype='float32')), 1) + num_chunks = mixture.shape[-1] // self.gen_size + mix_waves = [mixture[:, i * self.gen_size: i * self.gen_size + self.chunk_size] for i in range(num_chunks)] + else: + mix_waves = [] + n_sample = mix.shape[1] + pad = self.gen_size - n_sample%self.gen_size + mix_p = np.concatenate((np.zeros((2,self.trim)), mix, np.zeros((2,pad)), np.zeros((2,self.trim))), 1) + i = 0 + while i < n_sample + pad: + waves = np.array(mix_p[:, i:i+self.chunk_size]) + mix_waves.append(waves) + i += self.gen_size + + mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device) + + return mix_waves, pad + + def demix_base(self, mix, is_ckpt=False, is_match_mix=False): + chunked_sources = [] + for slice in mix: + sources = [] + tar_waves_ = [] + mix_p = mix[slice] + mix_waves, pad = self.initialize_mix(mix_p, is_ckpt=is_ckpt) + mix_waves = mix_waves.split(self.mdx_batch_size) + pad = mix_p.shape[-1] if is_ckpt else -pad + with torch.no_grad(): + for mix_wave in mix_waves: + self.running_inference_progress_bar(len(mix)*len(mix_waves), is_match_mix=is_match_mix) + tar_waves = self.run_model(mix_wave, is_ckpt=is_ckpt, is_match_mix=is_match_mix) + tar_waves_.append(tar_waves) + tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim:-self.trim] if is_ckpt else tar_waves_ + tar_waves = np.concatenate(tar_waves_, axis=-1)[:, :pad] + start = 0 if slice == 0 else self.margin + end = None if slice == list(mix.keys())[::-1][0] or self.margin == 0 else -self.margin + sources.append(tar_waves[:,start:end]*(1/self.adjust)) + chunked_sources.append(sources) + sources = np.concatenate(chunked_sources, axis=-1) + + return sources + + def run_model(self, mix, is_ckpt=False, is_match_mix=False): + + spek = self.stft(mix.to(self.device))*self.adjust + spek[:, :, :3, :] *= 0 + + if is_match_mix: + spec_pred = spek.cpu().numpy() + else: + spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5 if self.is_denoise else self.model_run(spek) + + if is_ckpt: + return self.istft(spec_pred).cpu().detach().numpy() + else: + return self.istft(torch.tensor(spec_pred).to(self.device)).to(cpu)[:,:,self.trim:-self.trim].transpose(0,1).reshape(2, -1).numpy() + + def stft(self, x): + x = x.reshape([-1, self.chunk_size]) + x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True,return_complex=True) + x=torch.view_as_real(x) + x = x.permute([0,3,1,2]) + x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t]) + return x[:,:,:self.dim_f] + + def istft(self, x, freq_pad=None): + freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad + x = torch.cat([x, freq_pad], -2) + x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t]) + x = x.permute([0,2,3,1]) + x=x.contiguous() + x=torch.view_as_complex(x) + x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) + return x.reshape([-1,2,self.chunk_size]) + +class SeperateDemucs(SeperateAttributes): + + def seperate(self): + + samplerate = 44100 + source = None + model_scale = None + stem_source = None + stem_source_secondary = None + inst_mix = None + inst_raw_mix = None + raw_mix = None + inst_source = None + is_no_write = False + is_no_piano_guitar = False + + if self.primary_model_name == self.model_basename and type(self.primary_sources) is dict and not self.pre_proc_model: + self.primary_source, self.secondary_source = self.load_cached_sources() + elif self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and not self.pre_proc_model: + source = self.primary_sources + self.load_cached_sources(is_4_stem_demucs=True) + else: + self.start_inference_console_write() + + if self.is_gpu_conversion >= 0: + self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + else: + self.device = torch.device('cpu') + + if self.demucs_version == DEMUCS_V1: + if str(self.model_path).endswith(".gz"): + self.model_path = gzip.open(self.model_path, "rb") + klass, args, kwargs, state = torch.load(self.model_path) + self.demucs = klass(*args, **kwargs) + self.demucs.to(self.device) + self.demucs.load_state_dict(state) + elif self.demucs_version == DEMUCS_V2: + self.demucs = auto_load_demucs_model_v2(self.demucs_source_list, self.model_path) + self.demucs.to(self.device) + self.demucs.load_state_dict(torch.load(self.model_path)) + self.demucs.eval() + else: + self.demucs = HDemucs(sources=self.demucs_source_list) + self.demucs = _gm(name=os.path.splitext(os.path.basename(self.model_path))[0], + repo=Path(os.path.dirname(self.model_path))) + self.demucs = demucs_segments(self.segment, self.demucs) + self.demucs.to(self.device) + self.demucs.eval() + + if self.pre_proc_model: + if self.primary_stem not in [VOCAL_STEM, INST_STEM]: + is_no_write = True + self.write_to_console(DONE, base_text='') + mix_no_voc = process_secondary_model(self.pre_proc_model, self.process_data, is_pre_proc_model=True) + inst_mix, inst_raw_mix, inst_samplerate = prepare_mix(mix_no_voc[INST_STEM], self.chunks_demucs, self.margin_demucs) + self.process_iteration() + self.running_inference_console_write(is_no_write=is_no_write) + inst_source = self.demix_demucs(inst_mix) + inst_source = self.run_mixer(inst_raw_mix, inst_source) + self.process_iteration() + + self.running_inference_console_write(is_no_write=is_no_write) if not self.pre_proc_model else None + mix, raw_mix, samplerate = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs) + + if self.primary_model_name == self.model_basename and isinstance(self.primary_sources, np.ndarray) and self.pre_proc_model: + source = self.primary_sources + else: + source = self.demix_demucs(mix) + source = self.run_mixer(raw_mix, source) + + self.write_to_console(DONE, base_text='') + + del self.demucs + torch.cuda.empty_cache() + + if isinstance(inst_source, np.ndarray): + source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[VOCAL_STEM]], source[self.demucs_source_map[VOCAL_STEM]]) + inst_source[self.demucs_source_map[VOCAL_STEM]] = source_reshape + source = inst_source + + if isinstance(source, np.ndarray): + if len(source) == 2: + self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER + else: + self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER if len(source) == 6 else DEMUCS_4_SOURCE_MAPPER + + if len(source) == 6 and self.process_data['is_ensemble_master'] or len(source) == 6 and self.is_secondary_model: + is_no_piano_guitar = True + six_stem_other_source = list(source) + six_stem_other_source = [i for n, i in enumerate(source) if n in [self.demucs_source_map[OTHER_STEM], self.demucs_source_map[GUITAR_STEM], self.demucs_source_map[PIANO_STEM]]] + other_source = np.zeros_like(six_stem_other_source[0]) + for i in six_stem_other_source: + other_source += i + source_reshape = spec_utils.reshape_sources(source[self.demucs_source_map[OTHER_STEM]], other_source) + source[self.demucs_source_map[OTHER_STEM]] = source_reshape + + if (self.demucs_stems == ALL_STEMS and not self.process_data['is_ensemble_master']) or self.is_4_stem_ensemble: + self.cache_source(source) + + for stem_name, stem_value in self.demucs_source_map.items(): + if self.is_secondary_model_activated and not self.is_secondary_model and not stem_value >= 4: + if self.secondary_model_4_stem[stem_value]: + model_scale = self.secondary_model_4_stem_scale[stem_value] + stem_source_secondary = process_secondary_model(self.secondary_model_4_stem[stem_value], self.process_data, main_model_primary_stem_4_stem=stem_name, is_4_stem_demucs=True) + if isinstance(stem_source_secondary, np.ndarray): + stem_source_secondary = stem_source_secondary[1 if self.secondary_model_4_stem[stem_value].demucs_stem_count == 2 else stem_value] + stem_source_secondary = spec_utils.normalize(stem_source_secondary, self.is_normalization).T + elif type(stem_source_secondary) is dict: + stem_source_secondary = stem_source_secondary[stem_name] + + stem_source_secondary = None if stem_value >= 4 else stem_source_secondary + self.write_to_console(f'{SAVING_STEM[0]}{stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None + stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({stem_name}).wav') + stem_source = spec_utils.normalize(source[stem_value], self.is_normalization).T + self.write_audio(stem_path, stem_source, samplerate, secondary_model_source=stem_source_secondary, model_scale=model_scale) + + if self.is_secondary_model: + return source + else: + if self.is_secondary_model_activated: + if self.secondary_model: + self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method) + + if not self.is_secondary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav') + if not isinstance(self.primary_source, np.ndarray): + self.primary_source = spec_utils.normalize(source[self.demucs_source_map[self.primary_stem]], self.is_normalization).T + self.primary_source_map = {self.primary_stem: self.primary_source} + self.write_audio(primary_stem_path, self.primary_source, samplerate, self.secondary_source_primary) + + if not self.is_primary_stem_only: + def secondary_save(sec_stem_name, source, raw_mixture=None, is_inst_mixture=False): + secondary_source = self.secondary_source if not is_inst_mixture else None + self.write_to_console(f'{SAVING_STEM[0]}{sec_stem_name}{SAVING_STEM[1]}') if not self.is_secondary_model else None + secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({sec_stem_name}).wav') + secondary_source_secondary = None + + if not isinstance(secondary_source, np.ndarray): + if self.is_demucs_combine_stems: + source = list(source) + if is_inst_mixture: + source = [i for n, i in enumerate(source) if not n in [self.demucs_source_map[self.primary_stem], self.demucs_source_map[VOCAL_STEM]]] + else: + source.pop(self.demucs_source_map[self.primary_stem]) + + source = source[:len(source) - 2] if is_no_piano_guitar else source + secondary_source = np.zeros_like(source[0]) + for i in source: + secondary_source += i + secondary_source = spec_utils.normalize(secondary_source, self.is_normalization).T + else: + if not isinstance(raw_mixture, np.ndarray): + raw_mixture = prepare_mix(self.audio_file, self.chunks_demucs, self.margin_demucs, is_missing_mix=True) + + secondary_source, raw_mixture = spec_utils.normalize_two_stem(source[self.demucs_source_map[self.primary_stem]], raw_mixture, self.is_normalization) + + if self.is_invert_spec: + secondary_source = spec_utils.invert_stem(raw_mixture, secondary_source) + else: + raw_mixture = spec_utils.reshape_sources(secondary_source, raw_mixture) + secondary_source = (-secondary_source.T+raw_mixture.T) + + if not is_inst_mixture: + self.secondary_source = secondary_source + secondary_source_secondary = self.secondary_source_secondary + self.secondary_source_map = {self.secondary_stem: self.secondary_source} + + self.write_audio(secondary_stem_path, secondary_source, samplerate, secondary_source_secondary) + + secondary_save(self.secondary_stem, source, raw_mixture=raw_mix) + + if self.is_demucs_pre_proc_model_inst_mix and self.pre_proc_model and not self.is_4_stem_ensemble: + secondary_save(f"{self.secondary_stem} {INST_STEM}", source, raw_mixture=inst_raw_mix, is_inst_mixture=True) + + secondary_sources = {**self.primary_source_map, **self.secondary_source_map} + + self.cache_source(secondary_sources) + + if self.is_secondary_model: + return secondary_sources + + def demix_demucs(self, mix): + processed = {} + + set_progress_bar = None if self.is_chunk_demucs else self.set_progress_bar + + for nmix in mix: + self.progress_value += 1 + self.set_progress_bar(0.1, (0.8/len(mix)*self.progress_value)) if self.is_chunk_demucs else None + cmix = mix[nmix] + cmix = torch.tensor(cmix, dtype=torch.float32) + ref = cmix.mean(0) + cmix = (cmix - ref.mean()) / ref.std() + mix_infer = cmix + + with torch.no_grad(): + if self.demucs_version == DEMUCS_V1: + sources = apply_model_v1(self.demucs, + mix_infer.to(self.device), + self.shifts, + self.is_split_mode, + set_progress_bar=set_progress_bar) + elif self.demucs_version == DEMUCS_V2: + sources = apply_model_v2(self.demucs, + mix_infer.to(self.device), + self.shifts, + self.is_split_mode, + self.overlap, + set_progress_bar=set_progress_bar) + else: + sources = apply_model(self.demucs, + mix_infer[None], + self.shifts, + self.is_split_mode, + self.overlap, + static_shifts=1 if self.shifts == 0 else self.shifts, + set_progress_bar=set_progress_bar, + device=self.device)[0] + + sources = (sources * ref.std() + ref.mean()).cpu().numpy() + sources[[0,1]] = sources[[1,0]] + start = 0 if nmix == 0 else self.margin_demucs + end = None if nmix == list(mix.keys())[::-1][0] else -self.margin_demucs + if self.margin_demucs == 0: + end = None + processed[nmix] = sources[:,:,start:end].copy() + sources = list(processed.values()) + sources = np.concatenate(sources, axis=-1) + + return sources + +class SeperateVR(SeperateAttributes): + + def seperate(self): + if self.primary_model_name == self.model_basename and self.primary_sources: + self.primary_source, self.secondary_source = self.load_cached_sources() + else: + self.start_inference_console_write() + if self.is_gpu_conversion >= 0: + if OPERATING_SYSTEM == 'Darwin': + device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu') + else: + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + else: + device = torch.device('cpu') + + nn_arch_sizes = [ + 31191, # default + 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227] + vr_5_1_models = [56817, 218409] + model_size = math.ceil(os.stat(self.model_path).st_size / 1024) + nn_arch_size = min(nn_arch_sizes, key=lambda x:abs(x-model_size)) + + if nn_arch_size in vr_5_1_models or self.is_vr_51_model: + self.model_run = nets_new.CascadedNet(self.mp.param['bins'] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1]) + else: + self.model_run = nets.determine_model_capacity(self.mp.param['bins'] * 2, nn_arch_size) + + self.model_run.load_state_dict(torch.load(self.model_path, map_location=cpu)) + self.model_run.to(device) + + self.running_inference_console_write() + + y_spec, v_spec = self.inference_vr(self.loading_mix(), device, self.aggressiveness) + self.write_to_console(DONE, base_text='') + + if self.is_secondary_model_activated: + if self.secondary_model: + self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(self.secondary_model, self.process_data, main_process_method=self.process_method) + + if not self.is_secondary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.primary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + primary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.primary_stem}).wav') + if not isinstance(self.primary_source, np.ndarray): + self.primary_source = spec_utils.normalize(self.spec_to_wav(y_spec), self.is_normalization).T + if not self.model_samplerate == 44100: + self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T + + self.primary_source_map = {self.primary_stem: self.primary_source} + + self.write_audio(primary_stem_path, self.primary_source, 44100, self.secondary_source_primary) + + if not self.is_primary_stem_only: + self.write_to_console(f'{SAVING_STEM[0]}{self.secondary_stem}{SAVING_STEM[1]}') if not self.is_secondary_model else None + secondary_stem_path = os.path.join(self.export_path, f'{self.audio_file_base}_({self.secondary_stem}).wav') + if not isinstance(self.secondary_source, np.ndarray): + self.secondary_source = self.spec_to_wav(v_spec) + self.secondary_source = spec_utils.normalize(self.spec_to_wav(v_spec), self.is_normalization).T + if not self.model_samplerate == 44100: + self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T + + self.secondary_source_map = {self.secondary_stem: self.secondary_source} + + self.write_audio(secondary_stem_path, self.secondary_source, 44100, self.secondary_source_secondary) + + torch.cuda.empty_cache() + secondary_sources = {**self.primary_source_map, **self.secondary_source_map} + self.cache_source(secondary_sources) + + if self.is_secondary_model: + return secondary_sources + + def loading_mix(self): + + X_wave, X_spec_s = {}, {} + + bands_n = len(self.mp.param['band']) + + for d in range(bands_n, 0, -1): + bp = self.mp.param['band'][d] + + if OPERATING_SYSTEM == 'Darwin': + wav_resolution = 'polyphase' if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else bp['res_type'] + else: + wav_resolution = bp['res_type'] + + if d == bands_n: # high-end band + X_wave[d], _ = librosa.load(self.audio_file, bp['sr'], False, dtype=np.float32, res_type=wav_resolution) + + if not np.any(X_wave[d]) and self.audio_file.endswith('.mp3'): + X_wave[d] = rerun_mp3(self.audio_file, bp['sr']) + + if X_wave[d].ndim == 1: + X_wave[d] = np.asarray([X_wave[d], X_wave[d]]) + else: # lower bands + X_wave[d] = librosa.resample(X_wave[d+1], self.mp.param['band'][d+1]['sr'], bp['sr'], res_type=wav_resolution) + + X_spec_s[d] = spec_utils.wave_to_spectrogram_mt(X_wave[d], bp['hl'], bp['n_fft'], self.mp.param['mid_side'], + self.mp.param['mid_side_b2'], self.mp.param['reverse']) + + if d == bands_n and self.high_end_process != 'none': + self.input_high_end_h = (bp['n_fft']//2 - bp['crop_stop']) + (self.mp.param['pre_filter_stop'] - self.mp.param['pre_filter_start']) + self.input_high_end = X_spec_s[d][:, bp['n_fft']//2-self.input_high_end_h:bp['n_fft']//2, :] + + X_spec = spec_utils.combine_spectrograms(X_spec_s, self.mp) + + del X_wave, X_spec_s + + return X_spec + + def inference_vr(self, X_spec, device, aggressiveness): + def _execute(X_mag_pad, roi_size): + X_dataset = [] + patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size + total_iterations = patches//self.batch_size if not self.is_tta else (patches//self.batch_size)*2 + for i in range(patches): + start = i * roi_size + X_mag_window = X_mag_pad[:, :, start:start + self.window_size] + X_dataset.append(X_mag_window) + + X_dataset = np.asarray(X_dataset) + self.model_run.eval() + with torch.no_grad(): + mask = [] + for i in range(0, patches, self.batch_size): + self.progress_value += 1 + if self.progress_value >= total_iterations: + self.progress_value = total_iterations + self.set_progress_bar(0.1, 0.8/total_iterations*self.progress_value) + X_batch = X_dataset[i: i + self.batch_size] + X_batch = torch.from_numpy(X_batch).to(device) + pred = self.model_run.predict_mask(X_batch) + if not pred.size()[3] > 0: + raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR]) + pred = pred.detach().cpu().numpy() + pred = np.concatenate(pred, axis=2) + mask.append(pred) + if len(mask) == 0: + raise Exception(ERROR_MAPPER[WINDOW_SIZE_ERROR]) + + mask = np.concatenate(mask, axis=2) + return mask + + def postprocess(mask, X_mag, X_phase): + + is_non_accom_stem = False + for stem in NON_ACCOM_STEMS: + if stem == self.primary_stem: + is_non_accom_stem = True + + mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness) + + if self.is_post_process: + mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold) + + y_spec = mask * X_mag * np.exp(1.j * X_phase) + v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase) + + return y_spec, v_spec + X_mag, X_phase = spec_utils.preprocess(X_spec) + n_frame = X_mag.shape[2] + pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset) + X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + X_mag_pad /= X_mag_pad.max() + mask = _execute(X_mag_pad, roi_size) + + if self.is_tta: + pad_l += roi_size // 2 + pad_r += roi_size // 2 + X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant') + X_mag_pad /= X_mag_pad.max() + mask_tta = _execute(X_mag_pad, roi_size) + mask_tta = mask_tta[:, :, roi_size // 2:] + mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5 + else: + mask = mask[:, :, :n_frame] + + y_spec, v_spec = postprocess(mask, X_mag, X_phase) + + return y_spec, v_spec + + def spec_to_wav(self, spec): + + if self.high_end_process.startswith('mirroring'): + input_high_end_ = spec_utils.mirroring(self.high_end_process, spec, self.input_high_end, self.mp) + wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp, self.input_high_end_h, input_high_end_) + else: + wav = spec_utils.cmb_spectrogram_to_wave(spec, self.mp) + + return wav + +def process_secondary_model(secondary_model: ModelData, process_data, main_model_primary_stem_4_stem=None, is_4_stem_demucs=False, main_process_method=None, is_pre_proc_model=False): + + if not is_pre_proc_model: + process_iteration = process_data['process_iteration'] + process_iteration() + + if secondary_model.process_method == VR_ARCH_TYPE: + seperator = SeperateVR(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method) + if secondary_model.process_method == MDX_ARCH_TYPE: + seperator = SeperateMDX(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method) + if secondary_model.process_method == DEMUCS_ARCH_TYPE: + seperator = SeperateDemucs(secondary_model, process_data, main_model_primary_stem_4_stem=main_model_primary_stem_4_stem, main_process_method=main_process_method) + + secondary_sources = seperator.seperate() + + if type(secondary_sources) is dict and not is_4_stem_demucs and not is_pre_proc_model: + return gather_sources(secondary_model.primary_model_primary_stem, STEM_PAIR_MAPPER[secondary_model.primary_model_primary_stem], secondary_sources) + else: + return secondary_sources + +def gather_sources(primary_stem_name, secondary_stem_name, secondary_sources: dict): + + source_primary = False + source_secondary = False + + for key, value in secondary_sources.items(): + if key in primary_stem_name: + source_primary = value + if key in secondary_stem_name: + source_secondary = value + + return source_primary, source_secondary + +def prepare_mix(mix, chunk_set, margin_set, mdx_net_cut=False, is_missing_mix=False): + + audio_path = mix + samplerate = 44100 + + if not isinstance(mix, np.ndarray): + mix, samplerate = librosa.load(mix, mono=False, sr=44100) + else: + mix = mix.T + + if not np.any(mix) and audio_path.endswith('.mp3'): + mix = rerun_mp3(audio_path) + + if mix.ndim == 1: + mix = np.asfortranarray([mix,mix]) + + def get_segmented_mix(chunk_set=chunk_set): + segmented_mix = {} + + samples = mix.shape[-1] + margin = margin_set + chunk_size = chunk_set*44100 + assert not margin == 0, 'margin cannot be zero!' + + if margin > chunk_size: + margin = chunk_size + if chunk_set == 0 or samples < chunk_size: + chunk_size = samples + + counter = -1 + for skip in range(0, samples, chunk_size): + counter+=1 + s_margin = 0 if counter == 0 else margin + end = min(skip+chunk_size+margin, samples) + start = skip-s_margin + segmented_mix[skip] = mix[:,start:end].copy() + if end == samples: + break + + return segmented_mix + + if is_missing_mix: + return mix + else: + segmented_mix = get_segmented_mix() + raw_mix = get_segmented_mix(chunk_set=0) if mdx_net_cut else mix + return segmented_mix, raw_mix, samplerate + +def rerun_mp3(audio_file, sample_rate=44100): + + with audioread.audio_open(audio_file) as f: + track_length = int(f.duration) + + return librosa.load(audio_file, duration=track_length, mono=False, sr=sample_rate)[0] + +def save_format(audio_path, save_format, mp3_bit_set): + + if not save_format == WAV: + + if OPERATING_SYSTEM == 'Darwin': + FFMPEG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'ffmpeg') + pydub.AudioSegment.converter = FFMPEG_PATH + + musfile = pydub.AudioSegment.from_wav(audio_path) + + if save_format == FLAC: + audio_path_flac = audio_path.replace(".wav", ".flac") + musfile.export(audio_path_flac, format="flac") + + if save_format == MP3: + audio_path_mp3 = audio_path.replace(".wav", ".mp3") + musfile.export(audio_path_mp3, format="mp3", bitrate=mp3_bit_set) + + try: + os.remove(audio_path) + except Exception as e: + print(e) diff --git a/webUI.py b/webUI.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c7a864027899fb8c566f8fd5bd3722f9ec94eb --- /dev/null +++ b/webUI.py @@ -0,0 +1,270 @@ +import os +import json + +import librosa +import soundfile +import numpy as np + +import gradio as gr +from UVR_interface import root, UVRInterface, VR_MODELS_DIR, MDX_MODELS_DIR, DEMUCS_MODELS_DIR +from gui_data.constants import * +from typing import List, Dict, Callable, Union + + +class UVRWebUI: + def __init__(self, uvr: UVRInterface, online_data_path: str) -> None: + self.uvr = uvr + self.models_url = self.get_models_url(online_data_path) + self.define_layout() + + self.input_temp_dir = "__temp" + self.export_path = "out" + if not os.path.exists(self.input_temp_dir): + os.mkdir(self.input_temp_dir) + + def get_models_url(self, models_info_path: str) -> Dict[str, Dict]: + with open(models_info_path, "r") as f: + online_data = json.loads(f.read()) + models_url = {} + for arch, download_list_key in zip([VR_ARCH_TYPE, MDX_ARCH_TYPE], ["vr_download_list", "mdx_download_list"]): + models_url[arch] = {model: NORMAL_REPO+model_path for model, model_path in online_data[download_list_key].items()} + models_url[DEMUCS_ARCH_TYPE] = online_data["demucs_download_list"] + return models_url + + def get_local_models(self, arch: str) -> List[str]: + model_config = { + VR_ARCH_TYPE: (VR_MODELS_DIR, ".pth"), + MDX_ARCH_TYPE: (MDX_MODELS_DIR, ".onnx"), + DEMUCS_ARCH_TYPE: (DEMUCS_MODELS_DIR, ".yaml"), + } + try: + model_dir, suffix = model_config[arch] + except KeyError: + raise ValueError(f"Unkown arch type: {arch}") + return [os.path.splitext(f)[0] for f in os.listdir(model_dir) if f.endswith(suffix)] + + def set_arch_setting_value(self, arch: str, setting1, setting2): + if arch == VR_ARCH_TYPE: + root.window_size_var.set(setting1) + root.aggression_setting_var.set(setting2) + elif arch == MDX_ARCH_TYPE: + root.mdx_batch_size_var.set(setting1) + root.compensate_var.set(setting2) + elif arch == DEMUCS_ARCH_TYPE: + pass + + def arch_select_update(self, arch: str) -> List[Dict]: + choices = self.get_local_models(arch) + if arch == VR_ARCH_TYPE: + model_update = self.model_choice.update(choices=choices, value=CHOOSE_MODEL, label=SELECT_VR_MODEL_MAIN_LABEL) + setting1_update = self.arch_setting1.update(choices=VR_WINDOW, label=WINDOW_SIZE_MAIN_LABEL, value=root.window_size_var.get()) + setting2_update = self.arch_setting2.update(choices=VR_AGGRESSION, label=AGGRESSION_SETTING_MAIN_LABEL, value=root.aggression_setting_var.get()) + elif arch == MDX_ARCH_TYPE: + model_update = self.model_choice.update(choices=choices, value=CHOOSE_MODEL, label=CHOOSE_MDX_MODEL_MAIN_LABEL) + setting1_update = self.arch_setting1.update(choices=BATCH_SIZE, label=BATCHES_MDX_MAIN_LABEL, value=root.mdx_batch_size_var.get()) + setting2_update = self.arch_setting2.update(choices=VOL_COMPENSATION, label=VOL_COMP_MDX_MAIN_LABEL, value=root.compensate_var.get()) + elif arch == DEMUCS_ARCH_TYPE: + model_update = self.model_choice.update(choices=choices, value=CHOOSE_MODEL, label=CHOOSE_DEMUCS_MODEL_MAIN_LABEL) + raise gr.Error(f"{DEMUCS_ARCH_TYPE} not implempted") + else: + raise gr.Error(f"Unkown arch type: {arch}") + return [model_update, setting1_update, setting2_update] + + def model_select_update(self, arch: str, model_name: str) -> List[Union[str, Dict, None]]: + if model_name == CHOOSE_MODEL: + return [None for _ in range(4)] + model, = self.uvr.assemble_model_data(model_name, arch) + if not model.model_status: + raise gr.Error(f"Cannot get model data, model hash = {model.model_hash}") + + stem1_check_update = self.primary_stem_only.update(label=f"{model.primary_stem} Only") + stem2_check_update = self.secondary_stem_only.update(label=f"{model.secondary_stem} Only") + stem1_out_update = self.primary_stem_out.update(label=f"Output {model.primary_stem}") + stem2_out_update = self.secondary_stem_out.update(label=f"Output {model.secondary_stem}") + + return [stem1_check_update, stem2_check_update, stem1_out_update, stem2_out_update] + + def checkbox_set_root_value(self, checkbox: gr.Checkbox, root_attr: str): + checkbox.change(lambda value: root.__getattribute__(root_attr).set(value), inputs=checkbox) + + def set_checkboxes_exclusive(self, checkboxes: List[gr.Checkbox], pure_callbacks: List[Callable], exclusive_value=True): + def exclusive_onchange(i, callback_i): + def new_onchange(*check_values): + if check_values[i] == exclusive_value: + return_values = [] + for j, value_j in enumerate(check_values): + if j != i and value_j == exclusive_value: + return_values.append(not exclusive_value) + else: + return_values.append(value_j) + else: + return_values = check_values + callback_i(check_values[i]) + return return_values + return new_onchange + + for i, (checkbox, callback) in enumerate(zip(checkboxes, pure_callbacks)): + checkbox.change(exclusive_onchange(i, callback), inputs=checkboxes, outputs=checkboxes) + + def process(self, input_audio, input_filename, model_name, arch, setting1, setting2, progress=gr.Progress()): + def set_progress_func(step, inference_iterations=0): + progress_curr = step + inference_iterations + progress(progress_curr) + + sampling_rate, audio = input_audio + audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32) + if len(audio.shape) > 1: + audio = librosa.to_mono(audio.transpose(1, 0)) + input_path = os.path.join(self.input_temp_dir, input_filename) + soundfile.write(input_path, audio, sampling_rate, format="wav") + + self.set_arch_setting_value(arch, setting1, setting2) + + seperator = uvr.process( + model_name=model_name, + arch_type=arch, + audio_file=input_path, + export_path=self.export_path, + is_model_sample_mode=root.model_sample_mode_var.get(), + set_progress_func=set_progress_func, + ) + + primary_audio = None + secondary_audio = None + msg = "" + if not seperator.is_secondary_stem_only: + primary_stem_path = os.path.join(seperator.export_path, f"{seperator.audio_file_base}_({seperator.primary_stem}).wav") + audio, rate = soundfile.read(primary_stem_path) + primary_audio = (rate, audio) + msg += f"{seperator.primary_stem} saved at {primary_stem_path}\n" + if not seperator.is_primary_stem_only: + secondary_stem_path = os.path.join(seperator.export_path, f"{seperator.audio_file_base}_({seperator.secondary_stem}).wav") + audio, rate = soundfile.read(secondary_stem_path) + secondary_audio = (rate, audio) + msg += f"{seperator.secondary_stem} saved at {secondary_stem_path}\n" + + os.remove(input_path) + + return primary_audio, secondary_audio, msg + + def define_layout(self): + with gr.Blocks(title="Ultimate Vocal Remover WebUI", theme=gr.themes.Base(font=[gr.themes.GoogleFont("Noto Sans Thai"), "sans-serif"])) as app: + self.app = app + gr.Label('🗣️ Ultimate Vocal Remover WebUI 🎵', show_label=False) + gr.Markdown("- Forked from [**Ultimate-Vocal-Remover-WebUI**](https://huggingface.co/spaces/r3gm/Ultimate-Vocal-Remover-WebUI) by [***@r3gm***](https://huggingface.co/r3gm)") + with gr.Tabs(): + with gr.TabItem("UVR5"): + with gr.Row(): + self.arch_choice = gr.Dropdown( + choices=[VR_ARCH_TYPE, MDX_ARCH_TYPE], value=VR_ARCH_TYPE, # choices=[VR_ARCH_TYPE, MDX_ARCH_TYPE, DEMUCS_ARCH_TYPE], value=VR_ARCH_TYPE, + label=CHOOSE_PROC_METHOD_MAIN_LABEL, interactive=True) + self.model_choice = gr.Dropdown( + choices=self.get_local_models(VR_ARCH_TYPE), value=CHOOSE_MODEL, + label=SELECT_VR_MODEL_MAIN_LABEL, interactive=True) + with gr.Row(): + self.arch_setting1 = gr.Dropdown( + choices=VR_WINDOW, value=root.window_size_var.get(), + label=WINDOW_SIZE_MAIN_LABEL, interactive=True) + self.arch_setting2 = gr.Dropdown( + choices=VR_AGGRESSION, value=root.aggression_setting_var.get(), + label=AGGRESSION_SETTING_MAIN_LABEL, interactive=True) + with gr.Row(): + self.use_gpu = gr.Checkbox( + label='Rhythmic Transmutation Device', value=True, interactive=True) #label=GPU_CONVERSION_MAIN_LABEL, value=root.is_gpu_conversion_var.get(), interactive=True) + self.primary_stem_only = gr.Checkbox( + label=f"{PRIMARY_STEM} only", value=root.is_primary_stem_only_var.get(), interactive=True) + self.secondary_stem_only = gr.Checkbox( + label=f"{SECONDARY_STEM} only", value=root.is_secondary_stem_only_var.get(), interactive=True) + self.sample_mode = gr.Checkbox( + label=SAMPLE_MODE_CHECKBOX(root.model_sample_mode_duration_var.get()), + value=root.model_sample_mode_var.get(), interactive=True) + + with gr.Row(): + self.input_filename = gr.Textbox(label="Input filename", value="temp.wav", interactive=True) + with gr.Row(): + self.audio_in = gr.Audio(label="Input audio", interactive=True) + with gr.Row(): + self.process_submit = gr.Button(START_PROCESSING, variant="primary") + with gr.Row(): + self.primary_stem_out = gr.Audio(label=f"Output {PRIMARY_STEM}", interactive=False) + self.secondary_stem_out = gr.Audio(label=f"Output {SECONDARY_STEM}", interactive=False) + with gr.Row(): + self.out_message = gr.Textbox(label="Output Message", interactive=False, show_progress=False) + + with gr.TabItem("Settings"): + with gr.Tabs(): + with gr.TabItem("Additional Settigns"): + self.wav_type = gr.Dropdown(choices=WAV_TYPE, label="Wav Type", value="PCM_16", interactive=True) + self.mp3_rate = gr.Dropdown(choices=MP3_BIT_RATES, label="MP3 Bitrate", value="320k",interactive=True) + with gr.TabItem("Download models"): + + def md_url(url, text=None): + if text is None: + text = url + return f"[{url}]({url})" + + with gr.Row(): + vr_models = self.models_url[VR_ARCH_TYPE] + self.vr_download_choice = gr.Dropdown(choices=list(vr_models.keys()), label=f"Select {VR_ARCH_TYPE} Model", interactive=True) + self.vr_download_url = gr.Markdown() + self.vr_download_choice.change(lambda model: md_url(vr_models[model]), inputs=self.vr_download_choice, outputs=self.vr_download_url) + with gr.Row(variant="panel"): + mdx_models = self.models_url[MDX_ARCH_TYPE] + self.mdx_download_choice = gr.Dropdown(choices=list(mdx_models.keys()), label=f"Select {MDX_ARCH_TYPE} Model", interactive=True) + self.mdx_download_url = gr.Markdown() + self.mdx_download_choice.change(lambda model: md_url(mdx_models[model]), inputs=self.mdx_download_choice, outputs=self.mdx_download_url) + + self.arch_choice.change( + self.arch_select_update, inputs=self.arch_choice, + outputs=[self.model_choice, self.arch_setting1, self.arch_setting2]) + self.model_choice.change( + self.model_select_update, inputs=[self.arch_choice, self.model_choice], + outputs=[self.primary_stem_only, self.secondary_stem_only, self.primary_stem_out, self.secondary_stem_out]) + + self.checkbox_set_root_value(self.use_gpu, 'is_gpu_conversion_var') + self.checkbox_set_root_value(self.sample_mode, 'model_sample_mode_var') + self.set_checkboxes_exclusive( + [self.primary_stem_only, self.secondary_stem_only], + [lambda value: root.is_primary_stem_only_var.set(value), lambda value: root.is_secondary_stem_only_var.set(value)]) + + self.process_submit.click( + self.process, + inputs=[self.audio_in, self.input_filename, self.model_choice, self.arch_choice, self.arch_setting1, self.arch_setting2], + outputs=[self.primary_stem_out, self.secondary_stem_out, self.out_message]) + + def launch(self, **kwargs): + self.app.queue().launch(**kwargs) + + +uvr = UVRInterface() +uvr.cached_sources_clear() + +webui = UVRWebUI(uvr, online_data_path='models/download_checks.json') + + +print(webui.models_url) +model_dict = webui.models_url + +import os +import wget + +for category, models in model_dict.items(): + if category in ['VR Arc', 'MDX-Net']: + if category == 'VR Arc': + model_path = 'models/VR_Models' + elif category == 'MDX-Net': + model_path = 'models/MDX_Net_Models' + + for model_name, model_url in models.items(): + cmd = f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -j5 -x16 -s16 -k1M -c -d {model_path} -Z {model_url}" + os.system(cmd) + + print("Models downloaded successfully.") + else: + print(f"Ignoring category: {category}") + + + + +webui = UVRWebUI(uvr, online_data_path='models/download_checks.json') +webui.launch()