oItsMineZ commited on
Commit
fa4dd2b
1 Parent(s): d7296fa

Upload 75 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +7 -7
  2. UVR_interface.py +852 -0
  3. app.py +1 -0
  4. demucs/__init__.py +5 -0
  5. demucs/__main__.py +272 -0
  6. demucs/apply.py +294 -0
  7. demucs/demucs.py +459 -0
  8. demucs/filtering.py +502 -0
  9. demucs/hdemucs.py +782 -0
  10. demucs/htdemucs.py +648 -0
  11. demucs/model.py +218 -0
  12. demucs/model_v2.py +218 -0
  13. demucs/pretrained.py +180 -0
  14. demucs/repo.py +148 -0
  15. demucs/spec.py +41 -0
  16. demucs/states.py +148 -0
  17. demucs/tasnet.py +447 -0
  18. demucs/tasnet_v2.py +452 -0
  19. demucs/transformer.py +839 -0
  20. demucs/utils.py +502 -0
  21. gui_data/constants.py +1147 -0
  22. gui_data/error_handling.py +106 -0
  23. gui_data/old_data_check.py +27 -0
  24. lib_v5/mdxnet.py +140 -0
  25. lib_v5/mixer.ckpt +3 -0
  26. lib_v5/modules.py +74 -0
  27. lib_v5/pyrb.py +92 -0
  28. lib_v5/spec_utils.py +692 -0
  29. lib_v5/vr_network/__init__.py +1 -0
  30. lib_v5/vr_network/__pycache__/__init__.cpython-310.pyc +0 -0
  31. lib_v5/vr_network/__pycache__/layers.cpython-310.pyc +0 -0
  32. lib_v5/vr_network/__pycache__/layers_new.cpython-310.pyc +0 -0
  33. lib_v5/vr_network/__pycache__/model_param_init.cpython-310.pyc +0 -0
  34. lib_v5/vr_network/__pycache__/nets.cpython-310.pyc +0 -0
  35. lib_v5/vr_network/__pycache__/nets_new.cpython-310.pyc +0 -0
  36. lib_v5/vr_network/layers.py +143 -0
  37. lib_v5/vr_network/layers_new.py +126 -0
  38. lib_v5/vr_network/model_param_init.py +59 -0
  39. lib_v5/vr_network/modelparams/1band_sr16000_hl512.json +19 -0
  40. lib_v5/vr_network/modelparams/1band_sr32000_hl512.json +19 -0
  41. lib_v5/vr_network/modelparams/1band_sr33075_hl384.json +19 -0
  42. lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json +19 -0
  43. lib_v5/vr_network/modelparams/1band_sr44100_hl256.json +19 -0
  44. lib_v5/vr_network/modelparams/1band_sr44100_hl512.json +19 -0
  45. lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json +19 -0
  46. lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json +19 -0
  47. lib_v5/vr_network/modelparams/2band_32000.json +30 -0
  48. lib_v5/vr_network/modelparams/2band_44100_lofi.json +30 -0
  49. lib_v5/vr_network/modelparams/2band_48000.json +30 -0
  50. lib_v5/vr_network/modelparams/3band_44100.json +42 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
  title: Ultimate Vocal Remover WebUI
3
- emoji: 🌍
4
- colorFrom: gray
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.26.0
8
  app_file: app.py
9
  pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Ultimate Vocal Remover WebUI
3
+ emoji: 🗣️-🎵
4
+ colorFrom: purple
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.44.2
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
+ short_description: Remove Vocal and Instrument from Music!
12
+ ---
UVR_interface.py ADDED
@@ -0,0 +1,852 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import audioread
2
+ import librosa
3
+
4
+ import os
5
+ import sys
6
+ import json
7
+ import time
8
+ from tqdm import tqdm
9
+ import pickle
10
+ import hashlib
11
+ import logging
12
+ import traceback
13
+ import shutil
14
+ import soundfile as sf
15
+
16
+ import torch
17
+
18
+ from gui_data.constants import *
19
+ from gui_data.old_data_check import file_check, remove_unneeded_yamls, remove_temps
20
+ from lib_v5.vr_network.model_param_init import ModelParameters
21
+ from lib_v5 import spec_utils
22
+ from pathlib import Path
23
+ from separate import SeperateAttributes, SeperateDemucs, SeperateMDX, SeperateVR, save_format
24
+ from typing import List
25
+
26
+
27
+ logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
28
+ logging.info('UVR BEGIN')
29
+
30
+ PREVIOUS_PATCH_WIN = 'UVR_Patch_1_12_23_14_54'
31
+
32
+ is_dnd_compatible = True
33
+ banner_placement = -2
34
+
35
+ def save_data(data):
36
+ """
37
+ Saves given data as a .pkl (pickle) file
38
+
39
+ Paramters:
40
+ data(dict):
41
+ Dictionary containing all the necessary data to save
42
+ """
43
+ # Open data file, create it if it does not exist
44
+ with open('data.pkl', 'wb') as data_file:
45
+ pickle.dump(data, data_file)
46
+
47
+ def load_data() -> dict:
48
+ """
49
+ Loads saved pkl file and returns the stored data
50
+
51
+ Returns(dict):
52
+ Dictionary containing all the saved data
53
+ """
54
+ try:
55
+ with open('data.pkl', 'rb') as data_file: # Open data file
56
+ data = pickle.load(data_file)
57
+
58
+ return data
59
+ except (ValueError, FileNotFoundError):
60
+ # Data File is corrupted or not found so recreate it
61
+
62
+ save_data(data=DEFAULT_DATA)
63
+
64
+ return load_data()
65
+
66
+ def load_model_hash_data(dictionary):
67
+ '''Get the model hash dictionary'''
68
+
69
+ with open(dictionary) as d:
70
+ data = d.read()
71
+
72
+ return json.loads(data)
73
+
74
+ # Change the current working directory to the directory
75
+ # this file sits in
76
+ if getattr(sys, 'frozen', False):
77
+ # If the application is run as a bundle, the PyInstaller bootloader
78
+ # extends the sys module by a flag frozen=True and sets the app
79
+ # path into variable _MEIPASS'.
80
+ BASE_PATH = sys._MEIPASS
81
+ else:
82
+ BASE_PATH = os.path.dirname(os.path.abspath(__file__))
83
+
84
+ os.chdir(BASE_PATH) # Change the current working directory to the base path
85
+
86
+ debugger = []
87
+
88
+ #--Constants--
89
+ #Models
90
+ MODELS_DIR = os.path.join(BASE_PATH, 'models')
91
+ VR_MODELS_DIR = os.path.join(MODELS_DIR, 'VR_Models')
92
+ MDX_MODELS_DIR = os.path.join(MODELS_DIR, 'MDX_Net_Models')
93
+ DEMUCS_MODELS_DIR = os.path.join(MODELS_DIR, 'Demucs_Models')
94
+ DEMUCS_NEWER_REPO_DIR = os.path.join(DEMUCS_MODELS_DIR, 'v3_v4_repo')
95
+ MDX_MIXER_PATH = os.path.join(BASE_PATH, 'lib_v5', 'mixer.ckpt')
96
+
97
+ #Cache & Parameters
98
+ VR_HASH_DIR = os.path.join(VR_MODELS_DIR, 'model_data')
99
+ VR_HASH_JSON = os.path.join(VR_MODELS_DIR, 'model_data', 'model_data.json')
100
+ MDX_HASH_DIR = os.path.join(MDX_MODELS_DIR, 'model_data')
101
+ MDX_HASH_JSON = os.path.join(MDX_MODELS_DIR, 'model_data', 'model_data.json')
102
+ DEMUCS_MODEL_NAME_SELECT = os.path.join(DEMUCS_MODELS_DIR, 'model_data', 'model_name_mapper.json')
103
+ MDX_MODEL_NAME_SELECT = os.path.join(MDX_MODELS_DIR, 'model_data', 'model_name_mapper.json')
104
+ ENSEMBLE_CACHE_DIR = os.path.join(BASE_PATH, 'gui_data', 'saved_ensembles')
105
+ SETTINGS_CACHE_DIR = os.path.join(BASE_PATH, 'gui_data', 'saved_settings')
106
+ VR_PARAM_DIR = os.path.join(BASE_PATH, 'lib_v5', 'vr_network', 'modelparams')
107
+ SAMPLE_CLIP_PATH = os.path.join(BASE_PATH, 'temp_sample_clips')
108
+ ENSEMBLE_TEMP_PATH = os.path.join(BASE_PATH, 'ensemble_temps')
109
+
110
+ #Style
111
+ ICON_IMG_PATH = os.path.join(BASE_PATH, 'gui_data', 'img', 'GUI-Icon.ico')
112
+ FONT_PATH = os.path.join(BASE_PATH, 'gui_data', 'fonts', 'centurygothic', 'GOTHIC.TTF')#ensemble_temps
113
+
114
+ #Other
115
+ COMPLETE_CHIME = os.path.join(BASE_PATH, 'gui_data', 'complete_chime.wav')
116
+ FAIL_CHIME = os.path.join(BASE_PATH, 'gui_data', 'fail_chime.wav')
117
+ CHANGE_LOG = os.path.join(BASE_PATH, 'gui_data', 'change_log.txt')
118
+ SPLASH_DOC = os.path.join(BASE_PATH, 'tmp', 'splash.txt')
119
+
120
+ file_check(os.path.join(MODELS_DIR, 'Main_Models'), VR_MODELS_DIR)
121
+ file_check(os.path.join(DEMUCS_MODELS_DIR, 'v3_repo'), DEMUCS_NEWER_REPO_DIR)
122
+ remove_unneeded_yamls(DEMUCS_MODELS_DIR)
123
+
124
+ remove_temps(ENSEMBLE_TEMP_PATH)
125
+ remove_temps(SAMPLE_CLIP_PATH)
126
+ remove_temps(os.path.join(BASE_PATH, 'img'))
127
+
128
+ if not os.path.isdir(ENSEMBLE_TEMP_PATH):
129
+ os.mkdir(ENSEMBLE_TEMP_PATH)
130
+
131
+ if not os.path.isdir(SAMPLE_CLIP_PATH):
132
+ os.mkdir(SAMPLE_CLIP_PATH)
133
+
134
+ model_hash_table = {}
135
+ data = load_data()
136
+
137
+ class ModelData():
138
+ def __init__(self, model_name: str,
139
+ selected_process_method=ENSEMBLE_MODE,
140
+ is_secondary_model=False,
141
+ primary_model_primary_stem=None,
142
+ is_primary_model_primary_stem_only=False,
143
+ is_primary_model_secondary_stem_only=False,
144
+ is_pre_proc_model=False,
145
+ is_dry_check=False):
146
+
147
+ self.is_gpu_conversion = 0 if root.is_gpu_conversion_var.get() else -1
148
+ self.is_normalization = root.is_normalization_var.get()
149
+ self.is_primary_stem_only = root.is_primary_stem_only_var.get()
150
+ self.is_secondary_stem_only = root.is_secondary_stem_only_var.get()
151
+ self.is_denoise = root.is_denoise_var.get()
152
+ self.mdx_batch_size = 1 if root.mdx_batch_size_var.get() == DEF_OPT else int(root.mdx_batch_size_var.get())
153
+ self.is_mdx_ckpt = False
154
+ self.wav_type_set = root.wav_type_set
155
+ self.mp3_bit_set = root.mp3_bit_set_var.get()
156
+ self.save_format = root.save_format_var.get()
157
+ self.is_invert_spec = root.is_invert_spec_var.get()
158
+ self.is_mixer_mode = root.is_mixer_mode_var.get()
159
+ self.demucs_stems = root.demucs_stems_var.get()
160
+ self.demucs_source_list = []
161
+ self.demucs_stem_count = 0
162
+ self.mixer_path = MDX_MIXER_PATH
163
+ self.model_name = model_name
164
+ self.process_method = selected_process_method
165
+ self.model_status = False if self.model_name == CHOOSE_MODEL or self.model_name == NO_MODEL else True
166
+ self.primary_stem = None
167
+ self.secondary_stem = None
168
+ self.is_ensemble_mode = False
169
+ self.ensemble_primary_stem = None
170
+ self.ensemble_secondary_stem = None
171
+ self.primary_model_primary_stem = primary_model_primary_stem
172
+ self.is_secondary_model = is_secondary_model
173
+ self.secondary_model = None
174
+ self.secondary_model_scale = None
175
+ self.demucs_4_stem_added_count = 0
176
+ self.is_demucs_4_stem_secondaries = False
177
+ self.is_4_stem_ensemble = False
178
+ self.pre_proc_model = None
179
+ self.pre_proc_model_activated = False
180
+ self.is_pre_proc_model = is_pre_proc_model
181
+ self.is_dry_check = is_dry_check
182
+ self.model_samplerate = 44100
183
+ self.model_capacity = 32, 128
184
+ self.is_vr_51_model = False
185
+ self.is_demucs_pre_proc_model_inst_mix = False
186
+ self.manual_download_Button = None
187
+ self.secondary_model_4_stem = []
188
+ self.secondary_model_4_stem_scale = []
189
+ self.secondary_model_4_stem_names = []
190
+ self.secondary_model_4_stem_model_names_list = []
191
+ self.all_models = []
192
+ self.secondary_model_other = None
193
+ self.secondary_model_scale_other = None
194
+ self.secondary_model_bass = None
195
+ self.secondary_model_scale_bass = None
196
+ self.secondary_model_drums = None
197
+ self.secondary_model_scale_drums = None
198
+
199
+ if selected_process_method == ENSEMBLE_MODE:
200
+ partitioned_name = model_name.partition(ENSEMBLE_PARTITION)
201
+ self.process_method = partitioned_name[0]
202
+ self.model_name = partitioned_name[2]
203
+ self.model_and_process_tag = model_name
204
+ self.ensemble_primary_stem, self.ensemble_secondary_stem = root.return_ensemble_stems()
205
+ self.is_ensemble_mode = True if not is_secondary_model and not is_pre_proc_model else False
206
+ self.is_4_stem_ensemble = True if root.ensemble_main_stem_var.get() == FOUR_STEM_ENSEMBLE and self.is_ensemble_mode else False
207
+ self.pre_proc_model_activated = root.is_demucs_pre_proc_model_activate_var.get() if not self.ensemble_primary_stem == VOCAL_STEM else False
208
+
209
+ if self.process_method == VR_ARCH_TYPE:
210
+ self.is_secondary_model_activated = root.vr_is_secondary_model_activate_var.get() if not self.is_secondary_model else False
211
+ self.aggression_setting = float(int(root.aggression_setting_var.get())/100)
212
+ self.is_tta = root.is_tta_var.get()
213
+ self.is_post_process = root.is_post_process_var.get()
214
+ self.window_size = int(root.window_size_var.get())
215
+ self.batch_size = 1 if root.batch_size_var.get() == DEF_OPT else int(root.batch_size_var.get())
216
+ self.crop_size = int(root.crop_size_var.get())
217
+ self.is_high_end_process = 'mirroring' if root.is_high_end_process_var.get() else 'None'
218
+ self.post_process_threshold = float(root.post_process_threshold_var.get())
219
+ self.model_capacity = 32, 128
220
+ self.model_path = os.path.join(VR_MODELS_DIR, f"{self.model_name}.pth")
221
+ self.get_model_hash()
222
+ if self.model_hash:
223
+ self.model_data = self.get_model_data(VR_HASH_DIR, root.vr_hash_MAPPER) if not self.model_hash == WOOD_INST_MODEL_HASH else WOOD_INST_PARAMS
224
+ if self.model_data:
225
+ vr_model_param = os.path.join(VR_PARAM_DIR, "{}.json".format(self.model_data["vr_model_param"]))
226
+ self.primary_stem = self.model_data["primary_stem"]
227
+ self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
228
+ self.vr_model_param = ModelParameters(vr_model_param)
229
+ self.model_samplerate = self.vr_model_param.param['sr']
230
+ if "nout" in self.model_data.keys() and "nout_lstm" in self.model_data.keys():
231
+ self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
232
+ self.is_vr_51_model = True
233
+ else:
234
+ self.model_status = False
235
+
236
+ if self.process_method == MDX_ARCH_TYPE:
237
+ self.is_secondary_model_activated = root.mdx_is_secondary_model_activate_var.get() if not is_secondary_model else False
238
+ self.margin = int(root.margin_var.get())
239
+ self.chunks = root.determine_auto_chunks(root.chunks_var.get(), self.is_gpu_conversion) if root.is_chunk_mdxnet_var.get() else 0
240
+ self.get_mdx_model_path()
241
+ self.get_model_hash()
242
+ if self.model_hash:
243
+ self.model_data = self.get_model_data(MDX_HASH_DIR, root.mdx_hash_MAPPER)
244
+ if self.model_data:
245
+ self.compensate = self.model_data["compensate"] if root.compensate_var.get() == AUTO_SELECT else float(root.compensate_var.get())
246
+ self.mdx_dim_f_set = self.model_data["mdx_dim_f_set"]
247
+ self.mdx_dim_t_set = self.model_data["mdx_dim_t_set"]
248
+ self.mdx_n_fft_scale_set = self.model_data["mdx_n_fft_scale_set"]
249
+ self.primary_stem = self.model_data["primary_stem"]
250
+ self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
251
+ else:
252
+ self.model_status = False
253
+
254
+ if self.process_method == DEMUCS_ARCH_TYPE:
255
+ self.is_secondary_model_activated = root.demucs_is_secondary_model_activate_var.get() if not is_secondary_model else False
256
+ if not self.is_ensemble_mode:
257
+ self.pre_proc_model_activated = root.is_demucs_pre_proc_model_activate_var.get() if not root.demucs_stems_var.get() in [VOCAL_STEM, INST_STEM] else False
258
+ self.overlap = float(root.overlap_var.get())
259
+ self.margin_demucs = int(root.margin_demucs_var.get())
260
+ self.chunks_demucs = root.determine_auto_chunks(root.chunks_demucs_var.get(), self.is_gpu_conversion)
261
+ self.shifts = int(root.shifts_var.get())
262
+ self.is_split_mode = root.is_split_mode_var.get()
263
+ self.segment = root.segment_var.get()
264
+ self.is_chunk_demucs = root.is_chunk_demucs_var.get()
265
+ self.is_demucs_combine_stems = root.is_demucs_combine_stems_var.get()
266
+ self.is_primary_stem_only = root.is_primary_stem_only_var.get() if self.is_ensemble_mode else root.is_primary_stem_only_Demucs_var.get()
267
+ self.is_secondary_stem_only = root.is_secondary_stem_only_var.get() if self.is_ensemble_mode else root.is_secondary_stem_only_Demucs_var.get()
268
+ self.get_demucs_model_path()
269
+ self.get_demucs_model_data()
270
+
271
+ self.model_basename = os.path.splitext(os.path.basename(self.model_path))[0] if self.model_status else None
272
+ self.pre_proc_model_activated = self.pre_proc_model_activated if not self.is_secondary_model else False
273
+
274
+ self.is_primary_model_primary_stem_only = is_primary_model_primary_stem_only
275
+ self.is_primary_model_secondary_stem_only = is_primary_model_secondary_stem_only
276
+
277
+ if self.is_secondary_model_activated and self.model_status:
278
+ if (not self.is_ensemble_mode and root.demucs_stems_var.get() == ALL_STEMS and self.process_method == DEMUCS_ARCH_TYPE) or self.is_4_stem_ensemble:
279
+ for key in DEMUCS_4_SOURCE_LIST:
280
+ self.secondary_model_data(key)
281
+ self.secondary_model_4_stem.append(self.secondary_model)
282
+ self.secondary_model_4_stem_scale.append(self.secondary_model_scale)
283
+ self.secondary_model_4_stem_names.append(key)
284
+ self.demucs_4_stem_added_count = sum(i is not None for i in self.secondary_model_4_stem)
285
+ self.is_secondary_model_activated = False if all(i is None for i in self.secondary_model_4_stem) else True
286
+ self.demucs_4_stem_added_count = self.demucs_4_stem_added_count - 1 if self.is_secondary_model_activated else self.demucs_4_stem_added_count
287
+ if self.is_secondary_model_activated:
288
+ self.secondary_model_4_stem_model_names_list = [None if i is None else i.model_basename for i in self.secondary_model_4_stem]
289
+ self.is_demucs_4_stem_secondaries = True
290
+ else:
291
+ primary_stem = self.ensemble_primary_stem if self.is_ensemble_mode and self.process_method == DEMUCS_ARCH_TYPE else self.primary_stem
292
+ self.secondary_model_data(primary_stem)
293
+
294
+ if self.process_method == DEMUCS_ARCH_TYPE and not is_secondary_model:
295
+ if self.demucs_stem_count >= 3 and self.pre_proc_model_activated:
296
+ self.pre_proc_model_activated = True
297
+ self.pre_proc_model = root.process_determine_demucs_pre_proc_model(self.primary_stem)
298
+ self.is_demucs_pre_proc_model_inst_mix = root.is_demucs_pre_proc_model_inst_mix_var.get() if self.pre_proc_model else False
299
+
300
+ def secondary_model_data(self, primary_stem):
301
+ secondary_model_data = root.process_determine_secondary_model(self.process_method, primary_stem, self.is_primary_stem_only, self.is_secondary_stem_only)
302
+ self.secondary_model = secondary_model_data[0]
303
+ self.secondary_model_scale = secondary_model_data[1]
304
+ self.is_secondary_model_activated = False if not self.secondary_model else True
305
+ if self.secondary_model:
306
+ self.is_secondary_model_activated = False if self.secondary_model.model_basename == self.model_basename else True
307
+
308
+ def get_mdx_model_path(self):
309
+
310
+ if self.model_name.endswith(CKPT):
311
+ # self.chunks = 0
312
+ # self.is_mdx_batch_mode = True
313
+ self.is_mdx_ckpt = True
314
+
315
+ ext = '' if self.is_mdx_ckpt else ONNX
316
+
317
+ for file_name, chosen_mdx_model in root.mdx_name_select_MAPPER.items():
318
+ if self.model_name in chosen_mdx_model:
319
+ self.model_path = os.path.join(MDX_MODELS_DIR, f"{file_name}{ext}")
320
+ break
321
+ else:
322
+ self.model_path = os.path.join(MDX_MODELS_DIR, f"{self.model_name}{ext}")
323
+
324
+ self.mixer_path = os.path.join(MDX_MODELS_DIR, f"mixer_val.ckpt")
325
+
326
+ def get_demucs_model_path(self):
327
+
328
+ demucs_newer = [True for x in DEMUCS_NEWER_TAGS if x in self.model_name]
329
+ demucs_model_dir = DEMUCS_NEWER_REPO_DIR if demucs_newer else DEMUCS_MODELS_DIR
330
+
331
+ for file_name, chosen_model in root.demucs_name_select_MAPPER.items():
332
+ if self.model_name in chosen_model:
333
+ self.model_path = os.path.join(demucs_model_dir, file_name)
334
+ break
335
+ else:
336
+ self.model_path = os.path.join(DEMUCS_NEWER_REPO_DIR, f'{self.model_name}.yaml')
337
+
338
+ def get_demucs_model_data(self):
339
+
340
+ self.demucs_version = DEMUCS_V4
341
+
342
+ for key, value in DEMUCS_VERSION_MAPPER.items():
343
+ if value in self.model_name:
344
+ self.demucs_version = key
345
+
346
+ self.demucs_source_list = DEMUCS_2_SOURCE if DEMUCS_UVR_MODEL in self.model_name else DEMUCS_4_SOURCE
347
+ self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER if DEMUCS_UVR_MODEL in self.model_name else DEMUCS_4_SOURCE_MAPPER
348
+ self.demucs_stem_count = 2 if DEMUCS_UVR_MODEL in self.model_name else 4
349
+
350
+ if not self.is_ensemble_mode:
351
+ self.primary_stem = PRIMARY_STEM if self.demucs_stems == ALL_STEMS else self.demucs_stems
352
+ self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
353
+
354
+ def get_model_data(self, model_hash_dir, hash_mapper):
355
+ model_settings_json = os.path.join(model_hash_dir, "{}.json".format(self.model_hash))
356
+
357
+ if os.path.isfile(model_settings_json):
358
+ return json.load(open(model_settings_json))
359
+ else:
360
+ for hash, settings in hash_mapper.items():
361
+ if self.model_hash in hash:
362
+ return settings
363
+ else:
364
+ return self.get_model_data_from_popup()
365
+
366
+ def get_model_data_from_popup(self):
367
+ return None
368
+
369
+ def get_model_hash(self):
370
+ self.model_hash = None
371
+
372
+ if not os.path.isfile(self.model_path):
373
+ self.model_status = False
374
+ self.model_hash is None
375
+ else:
376
+ if model_hash_table:
377
+ for (key, value) in model_hash_table.items():
378
+ if self.model_path == key:
379
+ self.model_hash = value
380
+ break
381
+
382
+ if not self.model_hash:
383
+ try:
384
+ with open(self.model_path, 'rb') as f:
385
+ f.seek(- 10000 * 1024, 2)
386
+ self.model_hash = hashlib.md5(f.read()).hexdigest()
387
+ except:
388
+ self.model_hash = hashlib.md5(open(self.model_path,'rb').read()).hexdigest()
389
+
390
+ table_entry = {self.model_path: self.model_hash}
391
+ model_hash_table.update(table_entry)
392
+
393
+
394
+ class Ensembler():
395
+ def __init__(self, is_manual_ensemble=False):
396
+ self.is_save_all_outputs_ensemble = root.is_save_all_outputs_ensemble_var.get()
397
+ chosen_ensemble_name = '{}'.format(root.chosen_ensemble_var.get().replace(" ", "_")) if not root.chosen_ensemble_var.get() == CHOOSE_ENSEMBLE_OPTION else 'Ensembled'
398
+ ensemble_algorithm = root.ensemble_type_var.get().partition("/")
399
+ ensemble_main_stem_pair = root.ensemble_main_stem_var.get().partition("/")
400
+ time_stamp = round(time.time())
401
+ self.audio_tool = MANUAL_ENSEMBLE
402
+ self.main_export_path = Path(root.export_path_var.get())
403
+ self.chosen_ensemble = f"_{chosen_ensemble_name}" if root.is_append_ensemble_name_var.get() else ''
404
+ ensemble_folder_name = self.main_export_path if self.is_save_all_outputs_ensemble else ENSEMBLE_TEMP_PATH
405
+ self.ensemble_folder_name = os.path.join(ensemble_folder_name, '{}_Outputs_{}'.format(chosen_ensemble_name, time_stamp))
406
+ self.is_testing_audio = f"{time_stamp}_" if root.is_testing_audio_var.get() else ''
407
+ self.primary_algorithm = ensemble_algorithm[0]
408
+ self.secondary_algorithm = ensemble_algorithm[2]
409
+ self.ensemble_primary_stem = ensemble_main_stem_pair[0]
410
+ self.ensemble_secondary_stem = ensemble_main_stem_pair[2]
411
+ self.is_normalization = root.is_normalization_var.get()
412
+ self.wav_type_set = root.wav_type_set
413
+ self.mp3_bit_set = root.mp3_bit_set_var.get()
414
+ self.save_format = root.save_format_var.get()
415
+ if not is_manual_ensemble:
416
+ os.mkdir(self.ensemble_folder_name)
417
+
418
+ def ensemble_outputs(self, audio_file_base, export_path, stem, is_4_stem=False, is_inst_mix=False):
419
+ """Processes the given outputs and ensembles them with the chosen algorithm"""
420
+
421
+ if is_4_stem:
422
+ algorithm = root.ensemble_type_var.get()
423
+ stem_tag = stem
424
+ else:
425
+ if is_inst_mix:
426
+ algorithm = self.secondary_algorithm
427
+ stem_tag = f"{self.ensemble_secondary_stem} {INST_STEM}"
428
+ else:
429
+ algorithm = self.primary_algorithm if stem == PRIMARY_STEM else self.secondary_algorithm
430
+ stem_tag = self.ensemble_primary_stem if stem == PRIMARY_STEM else self.ensemble_secondary_stem
431
+
432
+ stem_outputs = self.get_files_to_ensemble(folder=export_path, prefix=audio_file_base, suffix=f"_({stem_tag}).wav")
433
+ audio_file_output = f"{self.is_testing_audio}{audio_file_base}{self.chosen_ensemble}_({stem_tag})"
434
+ stem_save_path = os.path.join('{}'.format(self.main_export_path),'{}.wav'.format(audio_file_output))
435
+
436
+ if stem_outputs:
437
+ spec_utils.ensemble_inputs(stem_outputs, algorithm, self.is_normalization, self.wav_type_set, stem_save_path)
438
+ save_format(stem_save_path, self.save_format, self.mp3_bit_set)
439
+
440
+ if self.is_save_all_outputs_ensemble:
441
+ for i in stem_outputs:
442
+ save_format(i, self.save_format, self.mp3_bit_set)
443
+ else:
444
+ for i in stem_outputs:
445
+ try:
446
+ os.remove(i)
447
+ except Exception as e:
448
+ print(e)
449
+
450
+ def ensemble_manual(self, audio_inputs, audio_file_base, is_bulk=False):
451
+ """Processes the given outputs and ensembles them with the chosen algorithm"""
452
+
453
+ is_mv_sep = True
454
+
455
+ if is_bulk:
456
+ number_list = list(set([os.path.basename(i).split("_")[0] for i in audio_inputs]))
457
+ for n in number_list:
458
+ current_list = [i for i in audio_inputs if os.path.basename(i).startswith(n)]
459
+ audio_file_base = os.path.basename(current_list[0]).split('.wav')[0]
460
+ stem_testing = "instrum" if "Instrumental" in audio_file_base else "vocals"
461
+ if is_mv_sep:
462
+ audio_file_base = audio_file_base.split("_")
463
+ audio_file_base = f"{audio_file_base[1]}_{audio_file_base[2]}_{stem_testing}"
464
+ self.ensemble_manual_process(current_list, audio_file_base, is_bulk)
465
+ else:
466
+ self.ensemble_manual_process(audio_inputs, audio_file_base, is_bulk)
467
+
468
+ def ensemble_manual_process(self, audio_inputs, audio_file_base, is_bulk):
469
+
470
+ algorithm = root.choose_algorithm_var.get()
471
+ algorithm_text = "" if is_bulk else f"_({root.choose_algorithm_var.get()})"
472
+ stem_save_path = os.path.join('{}'.format(self.main_export_path),'{}{}{}.wav'.format(self.is_testing_audio, audio_file_base, algorithm_text))
473
+ spec_utils.ensemble_inputs(audio_inputs, algorithm, self.is_normalization, self.wav_type_set, stem_save_path)
474
+ save_format(stem_save_path, self.save_format, self.mp3_bit_set)
475
+
476
+ def get_files_to_ensemble(self, folder="", prefix="", suffix=""):
477
+ """Grab all the files to be ensembled"""
478
+
479
+ return [os.path.join(folder, i) for i in os.listdir(folder) if i.startswith(prefix) and i.endswith(suffix)]
480
+
481
+
482
+ def secondary_stem(stem):
483
+ """Determines secondary stem"""
484
+
485
+ for key, value in STEM_PAIR_MAPPER.items():
486
+ if stem in key:
487
+ secondary_stem = value
488
+
489
+ return secondary_stem
490
+
491
+
492
+ class UVRInterface:
493
+ def __init__(self) -> None:
494
+ pass
495
+
496
+ def assemble_model_data(self, model=None, arch_type=ENSEMBLE_MODE, is_dry_check=False) -> List[ModelData]:
497
+ if arch_type == ENSEMBLE_STEM_CHECK:
498
+ model_data = self.model_data_table
499
+ missing_models = [model.model_status for model in model_data if not model.model_status]
500
+
501
+ if missing_models or not model_data:
502
+ model_data: List[ModelData] = [ModelData(model_name, is_dry_check=is_dry_check) for model_name in self.ensemble_model_list]
503
+ self.model_data_table = model_data
504
+
505
+ if arch_type == ENSEMBLE_MODE:
506
+ model_data: List[ModelData] = [ModelData(model_name) for model_name in self.ensemble_listbox_get_all_selected_models()]
507
+ if arch_type == ENSEMBLE_CHECK:
508
+ model_data: List[ModelData] = [ModelData(model)]
509
+ if arch_type == VR_ARCH_TYPE or arch_type == VR_ARCH_PM:
510
+ model_data: List[ModelData] = [ModelData(model, VR_ARCH_TYPE)]
511
+ if arch_type == MDX_ARCH_TYPE:
512
+ model_data: List[ModelData] = [ModelData(model, MDX_ARCH_TYPE)]
513
+ if arch_type == DEMUCS_ARCH_TYPE:
514
+ model_data: List[ModelData] = [ModelData(model, DEMUCS_ARCH_TYPE)]#
515
+
516
+ return model_data
517
+
518
+ def create_sample(self, audio_file, sample_path=SAMPLE_CLIP_PATH):
519
+ try:
520
+ with audioread.audio_open(audio_file) as f:
521
+ track_length = int(f.duration)
522
+ except Exception as e:
523
+ print('Audioread failed to get duration. Trying Librosa...')
524
+ y, sr = librosa.load(audio_file, mono=False, sr=44100)
525
+ track_length = int(librosa.get_duration(y=y, sr=sr))
526
+
527
+ clip_duration = int(root.model_sample_mode_duration_var.get())
528
+
529
+ if track_length >= clip_duration:
530
+ offset_cut = track_length//3
531
+ off_cut = offset_cut + track_length
532
+ if not off_cut >= clip_duration:
533
+ offset_cut = 0
534
+ name_apped = f'{clip_duration}_second_'
535
+ else:
536
+ offset_cut, clip_duration = 0, track_length
537
+ name_apped = ''
538
+
539
+ sample = librosa.load(audio_file, offset=offset_cut, duration=clip_duration, mono=False, sr=44100)[0].T
540
+ audio_sample = os.path.join(sample_path, f'{os.path.splitext(os.path.basename(audio_file))[0]}_{name_apped}sample.wav')
541
+ sf.write(audio_sample, sample, 44100)
542
+
543
+ return audio_sample
544
+
545
+ def verify_audio(self, audio_file, is_process=True, sample_path=None):
546
+ is_good = False
547
+ error_data = ''
548
+
549
+ if os.path.isfile(audio_file):
550
+ try:
551
+ librosa.load(audio_file, duration=3, mono=False, sr=44100) if not type(sample_path) is str else self.create_sample(audio_file, sample_path)
552
+ is_good = True
553
+ except Exception as e:
554
+ error_name = f'{type(e).__name__}'
555
+ traceback_text = ''.join(traceback.format_tb(e.__traceback__))
556
+ message = f'{error_name}: "{e}"\n{traceback_text}"'
557
+ if is_process:
558
+ audio_base_name = os.path.basename(audio_file)
559
+ self.error_log_var.set(f'Error Loading the Following File:\n\n\"{audio_base_name}\"\n\nRaw Error Details:\n\n{message}')
560
+ else:
561
+ error_data = AUDIO_VERIFICATION_CHECK(audio_file, message)
562
+
563
+ if is_process:
564
+ return is_good
565
+ else:
566
+ return is_good, error_data
567
+
568
+ def cached_sources_clear(self):
569
+ self.vr_cache_source_mapper = {}
570
+ self.mdx_cache_source_mapper = {}
571
+ self.demucs_cache_source_mapper = {}
572
+
573
+ def cached_model_source_holder(self, process_method, sources, model_name=None):
574
+ if process_method == VR_ARCH_TYPE:
575
+ self.vr_cache_source_mapper = {**self.vr_cache_source_mapper, **{model_name: sources}}
576
+ if process_method == MDX_ARCH_TYPE:
577
+ self.mdx_cache_source_mapper = {**self.mdx_cache_source_mapper, **{model_name: sources}}
578
+ if process_method == DEMUCS_ARCH_TYPE:
579
+ self.demucs_cache_source_mapper = {**self.demucs_cache_source_mapper, **{model_name: sources}}
580
+
581
+ def cached_source_callback(self, process_method, model_name=None):
582
+ model, sources = None, None
583
+
584
+ if process_method == VR_ARCH_TYPE:
585
+ mapper = self.vr_cache_source_mapper
586
+ if process_method == MDX_ARCH_TYPE:
587
+ mapper = self.mdx_cache_source_mapper
588
+ if process_method == DEMUCS_ARCH_TYPE:
589
+ mapper = self.demucs_cache_source_mapper
590
+
591
+ for key, value in mapper.items():
592
+ if model_name in key:
593
+ model = key
594
+ sources = value
595
+
596
+ return model, sources
597
+
598
+ def cached_source_model_list_check(self, model_list: List[ModelData]):
599
+ model: ModelData
600
+ primary_model_names = lambda process_method:[model.model_basename if model.process_method == process_method else None for model in model_list]
601
+ secondary_model_names = lambda process_method:[model.secondary_model.model_basename if model.is_secondary_model_activated and model.process_method == process_method else None for model in model_list]
602
+
603
+ self.vr_primary_model_names = primary_model_names(VR_ARCH_TYPE)
604
+ self.mdx_primary_model_names = primary_model_names(MDX_ARCH_TYPE)
605
+ self.demucs_primary_model_names = primary_model_names(DEMUCS_ARCH_TYPE)
606
+ self.vr_secondary_model_names = secondary_model_names(VR_ARCH_TYPE)
607
+ self.mdx_secondary_model_names = secondary_model_names(MDX_ARCH_TYPE)
608
+ self.demucs_secondary_model_names = [model.secondary_model.model_basename if model.is_secondary_model_activated and model.process_method == DEMUCS_ARCH_TYPE and not model.secondary_model is None else None for model in model_list]
609
+ self.demucs_pre_proc_model_name = [model.pre_proc_model.model_basename if model.pre_proc_model else None for model in model_list]#list(dict.fromkeys())
610
+
611
+ for model in model_list:
612
+ if model.process_method == DEMUCS_ARCH_TYPE and model.is_demucs_4_stem_secondaries:
613
+ if not model.is_4_stem_ensemble:
614
+ self.demucs_secondary_model_names = model.secondary_model_4_stem_model_names_list
615
+ break
616
+ else:
617
+ for i in model.secondary_model_4_stem_model_names_list:
618
+ self.demucs_secondary_model_names.append(i)
619
+
620
+ self.all_models = self.vr_primary_model_names + self.mdx_primary_model_names + self.demucs_primary_model_names + self.vr_secondary_model_names + self.mdx_secondary_model_names + self.demucs_secondary_model_names + self.demucs_pre_proc_model_name
621
+
622
+ def process(self, model_name, arch_type, audio_file, export_path, is_model_sample_mode=False, is_4_stem_ensemble=False, set_progress_func=None, console_write=print) -> SeperateAttributes:
623
+ stime = time.perf_counter()
624
+ time_elapsed = lambda:f'Time Elapsed: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - stime)))}'
625
+
626
+ if arch_type==ENSEMBLE_MODE:
627
+ model_list, ensemble = self.assemble_model_data(), Ensembler()
628
+ export_path = ensemble.ensemble_folder_name
629
+ is_ensemble = True
630
+ else:
631
+ model_list = self.assemble_model_data(model_name, arch_type)
632
+ is_ensemble = False
633
+ self.cached_source_model_list_check(model_list)
634
+ model = model_list[0]
635
+
636
+ if self.verify_audio(audio_file):
637
+ audio_file = self.create_sample(audio_file) if is_model_sample_mode else audio_file
638
+ else:
639
+ print(f'"{os.path.basename(audio_file)}\" is missing or currupted.\n')
640
+ exit()
641
+
642
+ audio_file_base = f"{os.path.splitext(os.path.basename(audio_file))[0]}"
643
+ audio_file_base = audio_file_base if is_ensemble else f"{round(time.time())}_{audio_file_base}"
644
+ audio_file_base = audio_file_base if not is_ensemble else f"{audio_file_base}_{model.model_basename}"
645
+ if not is_ensemble:
646
+ audio_file_base = f"{audio_file_base}_{model.model_basename}"
647
+
648
+ if not is_ensemble:
649
+ export_path = os.path.join(Path(export_path), model.model_basename, os.path.splitext(os.path.basename(audio_file))[0])
650
+ if not os.path.isdir(export_path):
651
+ os.makedirs(export_path)
652
+
653
+ if set_progress_func is None:
654
+ pbar = tqdm(total=1)
655
+ self._progress = 0
656
+ def set_progress_func(step, inference_iterations=0):
657
+ progress_curr = step + inference_iterations
658
+ pbar.update(progress_curr-self._progress)
659
+ self._progress = progress_curr
660
+
661
+ def postprocess():
662
+ pbar.close()
663
+ else:
664
+ def postprocess():
665
+ pass
666
+
667
+ process_data = {
668
+ 'model_data': model,
669
+ 'export_path': export_path,
670
+ 'audio_file_base': audio_file_base,
671
+ 'audio_file': audio_file,
672
+ 'set_progress_bar': set_progress_func,
673
+ 'write_to_console': lambda progress_text, base_text='': console_write(base_text + progress_text),
674
+ 'process_iteration': lambda:None,
675
+ 'cached_source_callback': self.cached_source_callback,
676
+ 'cached_model_source_holder': self.cached_model_source_holder,
677
+ 'list_all_models': self.all_models,
678
+ 'is_ensemble_master': is_ensemble,
679
+ 'is_4_stem_ensemble': is_ensemble and is_4_stem_ensemble
680
+ }
681
+ if model.process_method == VR_ARCH_TYPE:
682
+ seperator = SeperateVR(model, process_data)
683
+ if model.process_method == MDX_ARCH_TYPE:
684
+ seperator = SeperateMDX(model, process_data)
685
+ if model.process_method == DEMUCS_ARCH_TYPE:
686
+ seperator = SeperateDemucs(model, process_data)
687
+
688
+ seperator.seperate()
689
+ postprocess()
690
+
691
+ if is_ensemble:
692
+ audio_file_base = audio_file_base.replace(f"_{model.model_basename}", "")
693
+ console_write(ENSEMBLING_OUTPUTS)
694
+
695
+ if is_4_stem_ensemble:
696
+ for output_stem in DEMUCS_4_SOURCE_LIST:
697
+ ensemble.ensemble_outputs(audio_file_base, export_path, output_stem, is_4_stem=True)
698
+ else:
699
+ if not root.is_secondary_stem_only_var.get():
700
+ ensemble.ensemble_outputs(audio_file_base, export_path, PRIMARY_STEM)
701
+ if not root.is_primary_stem_only_var.get():
702
+ ensemble.ensemble_outputs(audio_file_base, export_path, SECONDARY_STEM)
703
+ ensemble.ensemble_outputs(audio_file_base, export_path, SECONDARY_STEM, is_inst_mix=True)
704
+
705
+ console_write(DONE)
706
+
707
+ if is_model_sample_mode:
708
+ if os.path.isfile(audio_file):
709
+ os.remove(audio_file)
710
+
711
+ torch.cuda.empty_cache()
712
+
713
+ if is_ensemble and len(os.listdir(export_path)) == 0:
714
+ shutil.rmtree(export_path)
715
+ console_write(f'Process Complete, using time: {time_elapsed()}\nOutput path: {export_path}')
716
+ self.cached_sources_clear()
717
+ return seperator
718
+
719
+
720
+ class RootWrapper:
721
+ def __init__(self, var) -> None:
722
+ self.var=var
723
+
724
+ def set(self, val):
725
+ self.var=val
726
+
727
+ def get(self):
728
+ return self.var
729
+
730
+ class FakeRoot:
731
+ def __init__(self) -> None:
732
+ self.wav_type_set = 'PCM_16'
733
+ self.vr_hash_MAPPER = load_model_hash_data(VR_HASH_JSON)
734
+ self.mdx_hash_MAPPER = load_model_hash_data(MDX_HASH_JSON)
735
+ self.mdx_name_select_MAPPER = load_model_hash_data(MDX_MODEL_NAME_SELECT)
736
+ self.demucs_name_select_MAPPER = load_model_hash_data(DEMUCS_MODEL_NAME_SELECT)
737
+
738
+ def __getattribute__(self, __name: str):
739
+ try:
740
+ return super().__getattribute__(__name)
741
+ except AttributeError:
742
+ wrapped=RootWrapper(None)
743
+ super().__setattr__(__name, wrapped)
744
+ return wrapped
745
+
746
+ def load_saved_settings(self, loaded_setting: dict, process_method=None):
747
+ """Loads user saved application settings or resets to default"""
748
+
749
+ for key, value in DEFAULT_DATA.items():
750
+ if not key in loaded_setting.keys():
751
+ loaded_setting = {**loaded_setting, **{key:value}}
752
+ loaded_setting['batch_size'] = DEF_OPT
753
+
754
+ is_ensemble = True if process_method == ENSEMBLE_MODE else False
755
+
756
+ if not process_method or process_method == VR_ARCH_PM or is_ensemble:
757
+ self.vr_model_var.set(loaded_setting['vr_model'])
758
+ self.aggression_setting_var.set(loaded_setting['aggression_setting'])
759
+ self.window_size_var.set(loaded_setting['window_size'])
760
+ self.batch_size_var.set(loaded_setting['batch_size'])
761
+ self.crop_size_var.set(loaded_setting['crop_size'])
762
+ self.is_tta_var.set(loaded_setting['is_tta'])
763
+ self.is_output_image_var.set(loaded_setting['is_output_image'])
764
+ self.is_post_process_var.set(loaded_setting['is_post_process'])
765
+ self.is_high_end_process_var.set(loaded_setting['is_high_end_process'])
766
+ self.post_process_threshold_var.set(loaded_setting['post_process_threshold'])
767
+ self.vr_voc_inst_secondary_model_var.set(loaded_setting['vr_voc_inst_secondary_model'])
768
+ self.vr_other_secondary_model_var.set(loaded_setting['vr_other_secondary_model'])
769
+ self.vr_bass_secondary_model_var.set(loaded_setting['vr_bass_secondary_model'])
770
+ self.vr_drums_secondary_model_var.set(loaded_setting['vr_drums_secondary_model'])
771
+ self.vr_is_secondary_model_activate_var.set(loaded_setting['vr_is_secondary_model_activate'])
772
+ self.vr_voc_inst_secondary_model_scale_var.set(loaded_setting['vr_voc_inst_secondary_model_scale'])
773
+ self.vr_other_secondary_model_scale_var.set(loaded_setting['vr_other_secondary_model_scale'])
774
+ self.vr_bass_secondary_model_scale_var.set(loaded_setting['vr_bass_secondary_model_scale'])
775
+ self.vr_drums_secondary_model_scale_var.set(loaded_setting['vr_drums_secondary_model_scale'])
776
+
777
+ if not process_method or process_method == DEMUCS_ARCH_TYPE or is_ensemble:
778
+ self.demucs_model_var.set(loaded_setting['demucs_model'])
779
+ self.segment_var.set(loaded_setting['segment'])
780
+ self.overlap_var.set(loaded_setting['overlap'])
781
+ self.shifts_var.set(loaded_setting['shifts'])
782
+ self.chunks_demucs_var.set(loaded_setting['chunks_demucs'])
783
+ self.margin_demucs_var.set(loaded_setting['margin_demucs'])
784
+ self.is_chunk_demucs_var.set(loaded_setting['is_chunk_demucs'])
785
+ self.is_chunk_mdxnet_var.set(loaded_setting['is_chunk_mdxnet'])
786
+ self.is_primary_stem_only_Demucs_var.set(loaded_setting['is_primary_stem_only_Demucs'])
787
+ self.is_secondary_stem_only_Demucs_var.set(loaded_setting['is_secondary_stem_only_Demucs'])
788
+ self.is_split_mode_var.set(loaded_setting['is_split_mode'])
789
+ self.is_demucs_combine_stems_var.set(loaded_setting['is_demucs_combine_stems'])
790
+ self.demucs_voc_inst_secondary_model_var.set(loaded_setting['demucs_voc_inst_secondary_model'])
791
+ self.demucs_other_secondary_model_var.set(loaded_setting['demucs_other_secondary_model'])
792
+ self.demucs_bass_secondary_model_var.set(loaded_setting['demucs_bass_secondary_model'])
793
+ self.demucs_drums_secondary_model_var.set(loaded_setting['demucs_drums_secondary_model'])
794
+ self.demucs_is_secondary_model_activate_var.set(loaded_setting['demucs_is_secondary_model_activate'])
795
+ self.demucs_voc_inst_secondary_model_scale_var.set(loaded_setting['demucs_voc_inst_secondary_model_scale'])
796
+ self.demucs_other_secondary_model_scale_var.set(loaded_setting['demucs_other_secondary_model_scale'])
797
+ self.demucs_bass_secondary_model_scale_var.set(loaded_setting['demucs_bass_secondary_model_scale'])
798
+ self.demucs_drums_secondary_model_scale_var.set(loaded_setting['demucs_drums_secondary_model_scale'])
799
+ self.demucs_stems_var.set(loaded_setting['demucs_stems'])
800
+ # self.update_stem_checkbox_labels(self.demucs_stems_var.get(), demucs=True)
801
+ self.demucs_pre_proc_model_var.set(data['demucs_pre_proc_model'])
802
+ self.is_demucs_pre_proc_model_activate_var.set(data['is_demucs_pre_proc_model_activate'])
803
+ self.is_demucs_pre_proc_model_inst_mix_var.set(data['is_demucs_pre_proc_model_inst_mix'])
804
+
805
+ if not process_method or process_method == MDX_ARCH_TYPE or is_ensemble:
806
+ self.mdx_net_model_var.set(loaded_setting['mdx_net_model'])
807
+ self.chunks_var.set(loaded_setting['chunks'])
808
+ self.margin_var.set(loaded_setting['margin'])
809
+ self.compensate_var.set(loaded_setting['compensate'])
810
+ self.is_denoise_var.set(loaded_setting['is_denoise'])
811
+ self.is_invert_spec_var.set(loaded_setting['is_invert_spec'])
812
+ self.is_mixer_mode_var.set(loaded_setting['is_mixer_mode'])
813
+ self.mdx_batch_size_var.set(loaded_setting['mdx_batch_size'])
814
+ self.mdx_voc_inst_secondary_model_var.set(loaded_setting['mdx_voc_inst_secondary_model'])
815
+ self.mdx_other_secondary_model_var.set(loaded_setting['mdx_other_secondary_model'])
816
+ self.mdx_bass_secondary_model_var.set(loaded_setting['mdx_bass_secondary_model'])
817
+ self.mdx_drums_secondary_model_var.set(loaded_setting['mdx_drums_secondary_model'])
818
+ self.mdx_is_secondary_model_activate_var.set(loaded_setting['mdx_is_secondary_model_activate'])
819
+ self.mdx_voc_inst_secondary_model_scale_var.set(loaded_setting['mdx_voc_inst_secondary_model_scale'])
820
+ self.mdx_other_secondary_model_scale_var.set(loaded_setting['mdx_other_secondary_model_scale'])
821
+ self.mdx_bass_secondary_model_scale_var.set(loaded_setting['mdx_bass_secondary_model_scale'])
822
+ self.mdx_drums_secondary_model_scale_var.set(loaded_setting['mdx_drums_secondary_model_scale'])
823
+
824
+ if not process_method or is_ensemble:
825
+ self.is_save_all_outputs_ensemble_var.set(loaded_setting['is_save_all_outputs_ensemble'])
826
+ self.is_append_ensemble_name_var.set(loaded_setting['is_append_ensemble_name'])
827
+ self.chosen_audio_tool_var.set(loaded_setting['chosen_audio_tool'])
828
+ self.choose_algorithm_var.set(loaded_setting['choose_algorithm'])
829
+ self.time_stretch_rate_var.set(loaded_setting['time_stretch_rate'])
830
+ self.pitch_rate_var.set(loaded_setting['pitch_rate'])
831
+ self.is_primary_stem_only_var.set(loaded_setting['is_primary_stem_only'])
832
+ self.is_secondary_stem_only_var.set(loaded_setting['is_secondary_stem_only'])
833
+ self.is_testing_audio_var.set(loaded_setting['is_testing_audio'])
834
+ self.is_add_model_name_var.set(loaded_setting['is_add_model_name'])
835
+ self.is_accept_any_input_var.set(loaded_setting["is_accept_any_input"])
836
+ self.is_task_complete_var.set(loaded_setting['is_task_complete'])
837
+ self.is_create_model_folder_var.set(loaded_setting['is_create_model_folder'])
838
+ self.mp3_bit_set_var.set(loaded_setting['mp3_bit_set'])
839
+ self.save_format_var.set(loaded_setting['save_format'])
840
+ self.wav_type_set_var.set(loaded_setting['wav_type_set'])
841
+ self.user_code_var.set(loaded_setting['user_code'])
842
+
843
+ self.is_gpu_conversion_var.set(loaded_setting['is_gpu_conversion'])
844
+ self.is_normalization_var.set(loaded_setting['is_normalization'])
845
+ self.help_hints_var.set(loaded_setting['help_hints_var'])
846
+
847
+ self.model_sample_mode_var.set(loaded_setting['model_sample_mode'])
848
+ self.model_sample_mode_duration_var.set(loaded_setting['model_sample_mode_duration'])
849
+
850
+
851
+ root = FakeRoot()
852
+ root.load_saved_settings(DEFAULT_DATA)
app.py CHANGED
@@ -1,2 +1,3 @@
1
  import os
 
2
  os.system("python webUI.py")
 
1
  import os
2
+
3
  os.system("python webUI.py")
demucs/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
demucs/__main__.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from fractions import Fraction
13
+
14
+ import torch as th
15
+ from torch import distributed, nn
16
+ from torch.nn.parallel.distributed import DistributedDataParallel
17
+
18
+ from .augment import FlipChannels, FlipSign, Remix, Shift
19
+ from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
20
+ from .model import Demucs
21
+ from .parser import get_name, get_parser
22
+ from .raw import Rawset
23
+ from .tasnet import ConvTasNet
24
+ from .test import evaluate
25
+ from .train import train_model, validate_model
26
+ from .utils import human_seconds, load_model, save_model, sizeof_fmt
27
+
28
+
29
+ @dataclass
30
+ class SavedState:
31
+ metrics: list = field(default_factory=list)
32
+ last_state: dict = None
33
+ best_state: dict = None
34
+ optimizer: dict = None
35
+
36
+
37
+ def main():
38
+ parser = get_parser()
39
+ args = parser.parse_args()
40
+ name = get_name(parser, args)
41
+ print(f"Experiment {name}")
42
+
43
+ if args.musdb is None and args.rank == 0:
44
+ print(
45
+ "You must provide the path to the MusDB dataset with the --musdb flag. "
46
+ "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
47
+ file=sys.stderr)
48
+ sys.exit(1)
49
+
50
+ eval_folder = args.evals / name
51
+ eval_folder.mkdir(exist_ok=True, parents=True)
52
+ args.logs.mkdir(exist_ok=True)
53
+ metrics_path = args.logs / f"{name}.json"
54
+ eval_folder.mkdir(exist_ok=True, parents=True)
55
+ args.checkpoints.mkdir(exist_ok=True, parents=True)
56
+ args.models.mkdir(exist_ok=True, parents=True)
57
+
58
+ if args.device is None:
59
+ device = "cpu"
60
+ if th.cuda.is_available():
61
+ device = "cuda"
62
+ else:
63
+ device = args.device
64
+
65
+ th.manual_seed(args.seed)
66
+ # Prevents too many threads to be started when running `museval` as it can be quite
67
+ # inefficient on NUMA architectures.
68
+ os.environ["OMP_NUM_THREADS"] = "1"
69
+
70
+ if args.world_size > 1:
71
+ if device != "cuda" and args.rank == 0:
72
+ print("Error: distributed training is only available with cuda device", file=sys.stderr)
73
+ sys.exit(1)
74
+ th.cuda.set_device(args.rank % th.cuda.device_count())
75
+ distributed.init_process_group(backend="nccl",
76
+ init_method="tcp://" + args.master,
77
+ rank=args.rank,
78
+ world_size=args.world_size)
79
+
80
+ checkpoint = args.checkpoints / f"{name}.th"
81
+ checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
82
+ if args.restart and checkpoint.exists():
83
+ checkpoint.unlink()
84
+
85
+ if args.test:
86
+ args.epochs = 1
87
+ args.repeat = 0
88
+ model = load_model(args.models / args.test)
89
+ elif args.tasnet:
90
+ model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
91
+ else:
92
+ model = Demucs(
93
+ audio_channels=args.audio_channels,
94
+ channels=args.channels,
95
+ context=args.context,
96
+ depth=args.depth,
97
+ glu=args.glu,
98
+ growth=args.growth,
99
+ kernel_size=args.kernel_size,
100
+ lstm_layers=args.lstm_layers,
101
+ rescale=args.rescale,
102
+ rewrite=args.rewrite,
103
+ sources=4,
104
+ stride=args.conv_stride,
105
+ upsample=args.upsample,
106
+ samplerate=args.samplerate
107
+ )
108
+ model.to(device)
109
+ if args.show:
110
+ print(model)
111
+ size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
112
+ print(f"Model size {size}")
113
+ return
114
+
115
+ optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
116
+
117
+ try:
118
+ saved = th.load(checkpoint, map_location='cpu')
119
+ except IOError:
120
+ saved = SavedState()
121
+ else:
122
+ model.load_state_dict(saved.last_state)
123
+ optimizer.load_state_dict(saved.optimizer)
124
+
125
+ if args.save_model:
126
+ if args.rank == 0:
127
+ model.to("cpu")
128
+ model.load_state_dict(saved.best_state)
129
+ save_model(model, args.models / f"{name}.th")
130
+ return
131
+
132
+ if args.rank == 0:
133
+ done = args.logs / f"{name}.done"
134
+ if done.exists():
135
+ done.unlink()
136
+
137
+ if args.augment:
138
+ augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride),
139
+ Remix(group_size=args.remix_group_size)).to(device)
140
+ else:
141
+ augment = Shift(args.data_stride)
142
+
143
+ if args.mse:
144
+ criterion = nn.MSELoss()
145
+ else:
146
+ criterion = nn.L1Loss()
147
+
148
+ # Setting number of samples so that all convolution windows are full.
149
+ # Prevents hard to debug mistake with the prediction being shifted compared
150
+ # to the input mixture.
151
+ samples = model.valid_length(args.samples)
152
+ print(f"Number of training samples adjusted to {samples}")
153
+
154
+ if args.raw:
155
+ train_set = Rawset(args.raw / "train",
156
+ samples=samples + args.data_stride,
157
+ channels=args.audio_channels,
158
+ streams=[0, 1, 2, 3, 4],
159
+ stride=args.data_stride)
160
+
161
+ valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
162
+ else:
163
+ if not args.metadata.is_file() and args.rank == 0:
164
+ build_musdb_metadata(args.metadata, args.musdb, args.workers)
165
+ if args.world_size > 1:
166
+ distributed.barrier()
167
+ metadata = json.load(open(args.metadata))
168
+ duration = Fraction(samples + args.data_stride, args.samplerate)
169
+ stride = Fraction(args.data_stride, args.samplerate)
170
+ train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
171
+ metadata,
172
+ duration=duration,
173
+ stride=stride,
174
+ samplerate=args.samplerate,
175
+ channels=args.audio_channels)
176
+ valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
177
+ metadata,
178
+ samplerate=args.samplerate,
179
+ channels=args.audio_channels)
180
+
181
+ best_loss = float("inf")
182
+ for epoch, metrics in enumerate(saved.metrics):
183
+ print(f"Epoch {epoch:03d}: "
184
+ f"train={metrics['train']:.8f} "
185
+ f"valid={metrics['valid']:.8f} "
186
+ f"best={metrics['best']:.4f} "
187
+ f"duration={human_seconds(metrics['duration'])}")
188
+ best_loss = metrics['best']
189
+
190
+ if args.world_size > 1:
191
+ dmodel = DistributedDataParallel(model,
192
+ device_ids=[th.cuda.current_device()],
193
+ output_device=th.cuda.current_device())
194
+ else:
195
+ dmodel = model
196
+
197
+ for epoch in range(len(saved.metrics), args.epochs):
198
+ begin = time.time()
199
+ model.train()
200
+ train_loss = train_model(epoch,
201
+ train_set,
202
+ dmodel,
203
+ criterion,
204
+ optimizer,
205
+ augment,
206
+ batch_size=args.batch_size,
207
+ device=device,
208
+ repeat=args.repeat,
209
+ seed=args.seed,
210
+ workers=args.workers,
211
+ world_size=args.world_size)
212
+ model.eval()
213
+ valid_loss = validate_model(epoch,
214
+ valid_set,
215
+ model,
216
+ criterion,
217
+ device=device,
218
+ rank=args.rank,
219
+ split=args.split_valid,
220
+ world_size=args.world_size)
221
+
222
+ duration = time.time() - begin
223
+ if valid_loss < best_loss:
224
+ best_loss = valid_loss
225
+ saved.best_state = {
226
+ key: value.to("cpu").clone()
227
+ for key, value in model.state_dict().items()
228
+ }
229
+ saved.metrics.append({
230
+ "train": train_loss,
231
+ "valid": valid_loss,
232
+ "best": best_loss,
233
+ "duration": duration
234
+ })
235
+ if args.rank == 0:
236
+ json.dump(saved.metrics, open(metrics_path, "w"))
237
+
238
+ saved.last_state = model.state_dict()
239
+ saved.optimizer = optimizer.state_dict()
240
+ if args.rank == 0 and not args.test:
241
+ th.save(saved, checkpoint_tmp)
242
+ checkpoint_tmp.rename(checkpoint)
243
+
244
+ print(f"Epoch {epoch:03d}: "
245
+ f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} "
246
+ f"duration={human_seconds(duration)}")
247
+
248
+ del dmodel
249
+ model.load_state_dict(saved.best_state)
250
+ if args.eval_cpu:
251
+ device = "cpu"
252
+ model.to(device)
253
+ model.eval()
254
+ evaluate(model,
255
+ args.musdb,
256
+ eval_folder,
257
+ rank=args.rank,
258
+ world_size=args.world_size,
259
+ device=device,
260
+ save=args.save,
261
+ split=args.split_valid,
262
+ shifts=args.shifts,
263
+ workers=args.eval_workers)
264
+ model.to("cpu")
265
+ save_model(model, args.models / f"{name}.th")
266
+ if args.rank == 0:
267
+ print("done")
268
+ done.write_text("done")
269
+
270
+
271
+ if __name__ == "__main__":
272
+ main()
demucs/apply.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Code to apply a model to a mix. It will handle chunking with overlaps and
8
+ inteprolation between chunks, as well as the "shift trick".
9
+ """
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ import random
12
+ import typing as tp
13
+ from multiprocessing import Process,Queue,Pipe
14
+
15
+ import torch as th
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ import tqdm
19
+ import tkinter as tk
20
+
21
+ from .demucs import Demucs
22
+ from .hdemucs import HDemucs
23
+ from .utils import center_trim, DummyPoolExecutor
24
+
25
+ Model = tp.Union[Demucs, HDemucs]
26
+
27
+ progress_bar_num = 0
28
+
29
+ class BagOfModels(nn.Module):
30
+ def __init__(self, models: tp.List[Model],
31
+ weights: tp.Optional[tp.List[tp.List[float]]] = None,
32
+ segment: tp.Optional[float] = None):
33
+ """
34
+ Represents a bag of models with specific weights.
35
+ You should call `apply_model` rather than calling directly the forward here for
36
+ optimal performance.
37
+
38
+ Args:
39
+ models (list[nn.Module]): list of Demucs/HDemucs models.
40
+ weights (list[list[float]]): list of weights. If None, assumed to
41
+ be all ones, otherwise it should be a list of N list (N number of models),
42
+ each containing S floats (S number of sources).
43
+ segment (None or float): overrides the `segment` attribute of each model
44
+ (this is performed inplace, be careful if you reuse the models passed).
45
+ """
46
+
47
+ super().__init__()
48
+ assert len(models) > 0
49
+ first = models[0]
50
+ for other in models:
51
+ assert other.sources == first.sources
52
+ assert other.samplerate == first.samplerate
53
+ assert other.audio_channels == first.audio_channels
54
+ if segment is not None:
55
+ other.segment = segment
56
+
57
+ self.audio_channels = first.audio_channels
58
+ self.samplerate = first.samplerate
59
+ self.sources = first.sources
60
+ self.models = nn.ModuleList(models)
61
+
62
+ if weights is None:
63
+ weights = [[1. for _ in first.sources] for _ in models]
64
+ else:
65
+ assert len(weights) == len(models)
66
+ for weight in weights:
67
+ assert len(weight) == len(first.sources)
68
+ self.weights = weights
69
+
70
+ def forward(self, x):
71
+ raise NotImplementedError("Call `apply_model` on this.")
72
+
73
+ class TensorChunk:
74
+ def __init__(self, tensor, offset=0, length=None):
75
+ total_length = tensor.shape[-1]
76
+ assert offset >= 0
77
+ assert offset < total_length
78
+
79
+ if length is None:
80
+ length = total_length - offset
81
+ else:
82
+ length = min(total_length - offset, length)
83
+
84
+ if isinstance(tensor, TensorChunk):
85
+ self.tensor = tensor.tensor
86
+ self.offset = offset + tensor.offset
87
+ else:
88
+ self.tensor = tensor
89
+ self.offset = offset
90
+ self.length = length
91
+ self.device = tensor.device
92
+
93
+ @property
94
+ def shape(self):
95
+ shape = list(self.tensor.shape)
96
+ shape[-1] = self.length
97
+ return shape
98
+
99
+ def padded(self, target_length):
100
+ delta = target_length - self.length
101
+ total_length = self.tensor.shape[-1]
102
+ assert delta >= 0
103
+
104
+ start = self.offset - delta // 2
105
+ end = start + target_length
106
+
107
+ correct_start = max(0, start)
108
+ correct_end = min(total_length, end)
109
+
110
+ pad_left = correct_start - start
111
+ pad_right = end - correct_end
112
+
113
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
114
+ assert out.shape[-1] == target_length
115
+ return out
116
+
117
+ def tensor_chunk(tensor_or_chunk):
118
+ if isinstance(tensor_or_chunk, TensorChunk):
119
+ return tensor_or_chunk
120
+ else:
121
+ assert isinstance(tensor_or_chunk, th.Tensor)
122
+ return TensorChunk(tensor_or_chunk)
123
+
124
+ def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1., static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
125
+ """
126
+ Apply model to a given mixture.
127
+
128
+ Args:
129
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
130
+ and apply the oppositve shift to the output. This is repeated `shifts` time and
131
+ all predictions are averaged. This effectively makes the model time equivariant
132
+ and improves SDR by up to 0.2 points.
133
+ split (bool): if True, the input will be broken down in 8 seconds extracts
134
+ and predictions will be performed individually on each and concatenated.
135
+ Useful for model with large memory footprint like Tasnet.
136
+ progress (bool): if True, show a progress bar (requires split=True)
137
+ device (torch.device, str, or None): if provided, device on which to
138
+ execute the computation, otherwise `mix.device` is assumed.
139
+ When `device` is different from `mix.device`, only local computations will
140
+ be on `device`, while the entire tracks will be stored on `mix.device`.
141
+ """
142
+
143
+ global fut_length
144
+ global bag_num
145
+ global prog_bar
146
+
147
+ if device is None:
148
+ device = mix.device
149
+ else:
150
+ device = th.device(device)
151
+ if pool is None:
152
+ if num_workers > 0 and device.type == 'cpu':
153
+ pool = ThreadPoolExecutor(num_workers)
154
+ else:
155
+ pool = DummyPoolExecutor()
156
+
157
+ kwargs = {
158
+ 'shifts': shifts,
159
+ 'split': split,
160
+ 'overlap': overlap,
161
+ 'transition_power': transition_power,
162
+ 'progress': progress,
163
+ 'device': device,
164
+ 'pool': pool,
165
+ 'set_progress_bar': set_progress_bar,
166
+ 'static_shifts': static_shifts,
167
+ }
168
+
169
+ if isinstance(model, BagOfModels):
170
+ # Special treatment for bag of model.
171
+ # We explicitely apply multiple times `apply_model` so that the random shifts
172
+ # are different for each model.
173
+
174
+ estimates = 0
175
+ totals = [0] * len(model.sources)
176
+ bag_num = len(model.models)
177
+ fut_length = 0
178
+ prog_bar = 0
179
+ current_model = 0 #(bag_num + 1)
180
+ for sub_model, weight in zip(model.models, model.weights):
181
+ original_model_device = next(iter(sub_model.parameters())).device
182
+ sub_model.to(device)
183
+ fut_length += fut_length
184
+ current_model += 1
185
+ out = apply_model(sub_model, mix, **kwargs)
186
+ sub_model.to(original_model_device)
187
+ for k, inst_weight in enumerate(weight):
188
+ out[:, k, :, :] *= inst_weight
189
+ totals[k] += inst_weight
190
+ estimates += out
191
+ del out
192
+
193
+ for k in range(estimates.shape[1]):
194
+ estimates[:, k, :, :] /= totals[k]
195
+ return estimates
196
+
197
+ model.to(device)
198
+ model.eval()
199
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
200
+ batch, channels, length = mix.shape
201
+
202
+ if shifts:
203
+ kwargs['shifts'] = 0
204
+ max_shift = int(0.5 * model.samplerate)
205
+ mix = tensor_chunk(mix)
206
+ padded_mix = mix.padded(length + 2 * max_shift)
207
+ out = 0
208
+ for _ in range(shifts):
209
+ offset = random.randint(0, max_shift)
210
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
211
+ shifted_out = apply_model(model, shifted, **kwargs)
212
+ out += shifted_out[..., max_shift - offset:]
213
+ out /= shifts
214
+ return out
215
+ elif split:
216
+ kwargs['split'] = False
217
+ out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
218
+ sum_weight = th.zeros(length, device=mix.device)
219
+ segment = int(model.samplerate * model.segment)
220
+ stride = int((1 - overlap) * segment)
221
+ offsets = range(0, length, stride)
222
+ scale = float(format(stride / model.samplerate, ".2f"))
223
+ # We start from a triangle shaped weight, with maximal weight in the middle
224
+ # of the segment. Then we normalize and take to the power `transition_power`.
225
+ # Large values of transition power will lead to sharper transitions.
226
+ weight = th.cat([th.arange(1, segment // 2 + 1, device=device),
227
+ th.arange(segment - segment // 2, 0, -1, device=device)])
228
+ assert len(weight) == segment
229
+ # If the overlap < 50%, this will translate to linear transition when
230
+ # transition_power is 1.
231
+ weight = (weight / weight.max())**transition_power
232
+ futures = []
233
+ for offset in offsets:
234
+ chunk = TensorChunk(mix, offset, segment)
235
+ future = pool.submit(apply_model, model, chunk, **kwargs)
236
+ futures.append((future, offset))
237
+ offset += segment
238
+ if progress:
239
+ futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
240
+ for future, offset in futures:
241
+ if set_progress_bar:
242
+ fut_length = (len(futures) * bag_num * static_shifts)
243
+ prog_bar += 1
244
+ set_progress_bar(0.1, (0.8/fut_length*prog_bar))
245
+ chunk_out = future.result()
246
+ chunk_length = chunk_out.shape[-1]
247
+ out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
248
+ sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
249
+ assert sum_weight.min() > 0
250
+ out /= sum_weight
251
+ return out
252
+ else:
253
+ if hasattr(model, 'valid_length'):
254
+ valid_length = model.valid_length(length)
255
+ else:
256
+ valid_length = length
257
+ mix = tensor_chunk(mix)
258
+ padded_mix = mix.padded(valid_length).to(device)
259
+ with th.no_grad():
260
+ out = model(padded_mix)
261
+ return center_trim(out, length)
262
+
263
+ def demucs_segments(demucs_segment, demucs_model):
264
+
265
+ if demucs_segment == 'Default':
266
+ segment = None
267
+ if isinstance(demucs_model, BagOfModels):
268
+ if segment is not None:
269
+ for sub in demucs_model.models:
270
+ sub.segment = segment
271
+ else:
272
+ if segment is not None:
273
+ sub.segment = segment
274
+ else:
275
+ try:
276
+ segment = int(demucs_segment)
277
+ if isinstance(demucs_model, BagOfModels):
278
+ if segment is not None:
279
+ for sub in demucs_model.models:
280
+ sub.segment = segment
281
+ else:
282
+ if segment is not None:
283
+ sub.segment = segment
284
+ except:
285
+ segment = None
286
+ if isinstance(demucs_model, BagOfModels):
287
+ if segment is not None:
288
+ for sub in demucs_model.models:
289
+ sub.segment = segment
290
+ else:
291
+ if segment is not None:
292
+ sub.segment = segment
293
+
294
+ return demucs_model
demucs/demucs.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import julius
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from .states import capture_init
16
+ from .utils import center_trim, unfold
17
+
18
+
19
+ class BLSTM(nn.Module):
20
+ """
21
+ BiLSTM with same hidden units as input dim.
22
+ If `max_steps` is not None, input will be splitting in overlapping
23
+ chunks and the LSTM applied separately on each chunk.
24
+ """
25
+ def __init__(self, dim, layers=1, max_steps=None, skip=False):
26
+ super().__init__()
27
+ assert max_steps is None or max_steps % 4 == 0
28
+ self.max_steps = max_steps
29
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
30
+ self.linear = nn.Linear(2 * dim, dim)
31
+ self.skip = skip
32
+
33
+ def forward(self, x):
34
+ B, C, T = x.shape
35
+ y = x
36
+ framed = False
37
+ if self.max_steps is not None and T > self.max_steps:
38
+ width = self.max_steps
39
+ stride = width // 2
40
+ frames = unfold(x, width, stride)
41
+ nframes = frames.shape[2]
42
+ framed = True
43
+ x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
44
+
45
+ x = x.permute(2, 0, 1)
46
+
47
+ x = self.lstm(x)[0]
48
+ x = self.linear(x)
49
+ x = x.permute(1, 2, 0)
50
+ if framed:
51
+ out = []
52
+ frames = x.reshape(B, -1, C, width)
53
+ limit = stride // 2
54
+ for k in range(nframes):
55
+ if k == 0:
56
+ out.append(frames[:, k, :, :-limit])
57
+ elif k == nframes - 1:
58
+ out.append(frames[:, k, :, limit:])
59
+ else:
60
+ out.append(frames[:, k, :, limit:-limit])
61
+ out = torch.cat(out, -1)
62
+ out = out[..., :T]
63
+ x = out
64
+ if self.skip:
65
+ x = x + y
66
+ return x
67
+
68
+
69
+ def rescale_conv(conv, reference):
70
+ """Rescale initial weight scale. It is unclear why it helps but it certainly does.
71
+ """
72
+ std = conv.weight.std().detach()
73
+ scale = (std / reference)**0.5
74
+ conv.weight.data /= scale
75
+ if conv.bias is not None:
76
+ conv.bias.data /= scale
77
+
78
+
79
+ def rescale_module(module, reference):
80
+ for sub in module.modules():
81
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
82
+ rescale_conv(sub, reference)
83
+
84
+
85
+ class LayerScale(nn.Module):
86
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
87
+ This rescales diagonaly residual outputs close to 0 initially, then learnt.
88
+ """
89
+ def __init__(self, channels: int, init: float = 0):
90
+ super().__init__()
91
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
92
+ self.scale.data[:] = init
93
+
94
+ def forward(self, x):
95
+ return self.scale[:, None] * x
96
+
97
+
98
+ class DConv(nn.Module):
99
+ """
100
+ New residual branches in each encoder layer.
101
+ This alternates dilated convolutions, potentially with LSTMs and attention.
102
+ Also before entering each residual branch, dimension is projected on a smaller subspace,
103
+ e.g. of dim `channels // compress`.
104
+ """
105
+ def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
106
+ norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
107
+ kernel=3, dilate=True):
108
+ """
109
+ Args:
110
+ channels: input/output channels for residual branch.
111
+ compress: amount of channel compression inside the branch.
112
+ depth: number of layers in the residual branch. Each layer has its own
113
+ projection, and potentially LSTM and attention.
114
+ init: initial scale for LayerNorm.
115
+ norm: use GroupNorm.
116
+ attn: use LocalAttention.
117
+ heads: number of heads for the LocalAttention.
118
+ ndecay: number of decay controls in the LocalAttention.
119
+ lstm: use LSTM.
120
+ gelu: Use GELU activation.
121
+ kernel: kernel size for the (dilated) convolutions.
122
+ dilate: if true, use dilation, increasing with the depth.
123
+ """
124
+
125
+ super().__init__()
126
+ assert kernel % 2 == 1
127
+ self.channels = channels
128
+ self.compress = compress
129
+ self.depth = abs(depth)
130
+ dilate = depth > 0
131
+
132
+ norm_fn: tp.Callable[[int], nn.Module]
133
+ norm_fn = lambda d: nn.Identity() # noqa
134
+ if norm:
135
+ norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
136
+
137
+ hidden = int(channels / compress)
138
+
139
+ act: tp.Type[nn.Module]
140
+ if gelu:
141
+ act = nn.GELU
142
+ else:
143
+ act = nn.ReLU
144
+
145
+ self.layers = nn.ModuleList([])
146
+ for d in range(self.depth):
147
+ dilation = 2 ** d if dilate else 1
148
+ padding = dilation * (kernel // 2)
149
+ mods = [
150
+ nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
151
+ norm_fn(hidden), act(),
152
+ nn.Conv1d(hidden, 2 * channels, 1),
153
+ norm_fn(2 * channels), nn.GLU(1),
154
+ LayerScale(channels, init),
155
+ ]
156
+ if attn:
157
+ mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
158
+ if lstm:
159
+ mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
160
+ layer = nn.Sequential(*mods)
161
+ self.layers.append(layer)
162
+
163
+ def forward(self, x):
164
+ for layer in self.layers:
165
+ x = x + layer(x)
166
+ return x
167
+
168
+
169
+ class LocalState(nn.Module):
170
+ """Local state allows to have attention based only on data (no positional embedding),
171
+ but while setting a constraint on the time window (e.g. decaying penalty term).
172
+
173
+ Also a failed experiments with trying to provide some frequency based attention.
174
+ """
175
+ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
176
+ super().__init__()
177
+ assert channels % heads == 0, (channels, heads)
178
+ self.heads = heads
179
+ self.nfreqs = nfreqs
180
+ self.ndecay = ndecay
181
+ self.content = nn.Conv1d(channels, channels, 1)
182
+ self.query = nn.Conv1d(channels, channels, 1)
183
+ self.key = nn.Conv1d(channels, channels, 1)
184
+ if nfreqs:
185
+ self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
186
+ if ndecay:
187
+ self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
188
+ # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
189
+ self.query_decay.weight.data *= 0.01
190
+ assert self.query_decay.bias is not None # stupid type checker
191
+ self.query_decay.bias.data[:] = -2
192
+ self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
193
+
194
+ def forward(self, x):
195
+ B, C, T = x.shape
196
+ heads = self.heads
197
+ indexes = torch.arange(T, device=x.device, dtype=x.dtype)
198
+ # left index are keys, right index are queries
199
+ delta = indexes[:, None] - indexes[None, :]
200
+
201
+ queries = self.query(x).view(B, heads, -1, T)
202
+ keys = self.key(x).view(B, heads, -1, T)
203
+ # t are keys, s are queries
204
+ dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
205
+ dots /= keys.shape[2]**0.5
206
+ if self.nfreqs:
207
+ periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
208
+ freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
209
+ freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
210
+ dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
211
+ if self.ndecay:
212
+ decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
213
+ decay_q = self.query_decay(x).view(B, heads, -1, T)
214
+ decay_q = torch.sigmoid(decay_q) / 2
215
+ decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
216
+ dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
217
+
218
+ # Kill self reference.
219
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
220
+ weights = torch.softmax(dots, dim=2)
221
+
222
+ content = self.content(x).view(B, heads, -1, T)
223
+ result = torch.einsum("bhts,bhct->bhcs", weights, content)
224
+ if self.nfreqs:
225
+ time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
226
+ result = torch.cat([result, time_sig], 2)
227
+ result = result.reshape(B, -1, T)
228
+ return x + self.proj(result)
229
+
230
+
231
+ class Demucs(nn.Module):
232
+ @capture_init
233
+ def __init__(self,
234
+ sources,
235
+ # Channels
236
+ audio_channels=2,
237
+ channels=64,
238
+ growth=2.,
239
+ # Main structure
240
+ depth=6,
241
+ rewrite=True,
242
+ lstm_layers=0,
243
+ # Convolutions
244
+ kernel_size=8,
245
+ stride=4,
246
+ context=1,
247
+ # Activations
248
+ gelu=True,
249
+ glu=True,
250
+ # Normalization
251
+ norm_starts=4,
252
+ norm_groups=4,
253
+ # DConv residual branch
254
+ dconv_mode=1,
255
+ dconv_depth=2,
256
+ dconv_comp=4,
257
+ dconv_attn=4,
258
+ dconv_lstm=4,
259
+ dconv_init=1e-4,
260
+ # Pre/post processing
261
+ normalize=True,
262
+ resample=True,
263
+ # Weight init
264
+ rescale=0.1,
265
+ # Metadata
266
+ samplerate=44100,
267
+ segment=4 * 10):
268
+ """
269
+ Args:
270
+ sources (list[str]): list of source names
271
+ audio_channels (int): stereo or mono
272
+ channels (int): first convolution channels
273
+ depth (int): number of encoder/decoder layers
274
+ growth (float): multiply (resp divide) number of channels by that
275
+ for each layer of the encoder (resp decoder)
276
+ depth (int): number of layers in the encoder and in the decoder.
277
+ rewrite (bool): add 1x1 convolution to each layer.
278
+ lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
279
+ by default, as this is now replaced by the smaller and faster small LSTMs
280
+ in the DConv branches.
281
+ kernel_size (int): kernel size for convolutions
282
+ stride (int): stride for convolutions
283
+ context (int): kernel size of the convolution in the
284
+ decoder before the transposed convolution. If > 1,
285
+ will provide some context from neighboring time steps.
286
+ gelu: use GELU activation function.
287
+ glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
288
+ norm_starts: layer at which group norm starts being used.
289
+ decoder layers are numbered in reverse order.
290
+ norm_groups: number of groups for group norm.
291
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
292
+ dconv_depth: depth of residual DConv branch.
293
+ dconv_comp: compression of DConv branch.
294
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
295
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
296
+ dconv_init: initial scale for the DConv branch LayerScale.
297
+ normalize (bool): normalizes the input audio on the fly, and scales back
298
+ the output by the same amount.
299
+ resample (bool): upsample x2 the input and downsample /2 the output.
300
+ rescale (int): rescale initial weights of convolutions
301
+ to get their standard deviation closer to `rescale`.
302
+ samplerate (int): stored as meta information for easing
303
+ future evaluations of the model.
304
+ segment (float): duration of the chunks of audio to ideally evaluate the model on.
305
+ This is used by `demucs.apply.apply_model`.
306
+ """
307
+
308
+ super().__init__()
309
+ self.audio_channels = audio_channels
310
+ self.sources = sources
311
+ self.kernel_size = kernel_size
312
+ self.context = context
313
+ self.stride = stride
314
+ self.depth = depth
315
+ self.resample = resample
316
+ self.channels = channels
317
+ self.normalize = normalize
318
+ self.samplerate = samplerate
319
+ self.segment = segment
320
+ self.encoder = nn.ModuleList()
321
+ self.decoder = nn.ModuleList()
322
+ self.skip_scales = nn.ModuleList()
323
+
324
+ if glu:
325
+ activation = nn.GLU(dim=1)
326
+ ch_scale = 2
327
+ else:
328
+ activation = nn.ReLU()
329
+ ch_scale = 1
330
+ if gelu:
331
+ act2 = nn.GELU
332
+ else:
333
+ act2 = nn.ReLU
334
+
335
+ in_channels = audio_channels
336
+ padding = 0
337
+ for index in range(depth):
338
+ norm_fn = lambda d: nn.Identity() # noqa
339
+ if index >= norm_starts:
340
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
341
+
342
+ encode = []
343
+ encode += [
344
+ nn.Conv1d(in_channels, channels, kernel_size, stride),
345
+ norm_fn(channels),
346
+ act2(),
347
+ ]
348
+ attn = index >= dconv_attn
349
+ lstm = index >= dconv_lstm
350
+ if dconv_mode & 1:
351
+ encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
352
+ compress=dconv_comp, attn=attn, lstm=lstm)]
353
+ if rewrite:
354
+ encode += [
355
+ nn.Conv1d(channels, ch_scale * channels, 1),
356
+ norm_fn(ch_scale * channels), activation]
357
+ self.encoder.append(nn.Sequential(*encode))
358
+
359
+ decode = []
360
+ if index > 0:
361
+ out_channels = in_channels
362
+ else:
363
+ out_channels = len(self.sources) * audio_channels
364
+ if rewrite:
365
+ decode += [
366
+ nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
367
+ norm_fn(ch_scale * channels), activation]
368
+ if dconv_mode & 2:
369
+ decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
370
+ compress=dconv_comp, attn=attn, lstm=lstm)]
371
+ decode += [nn.ConvTranspose1d(channels, out_channels,
372
+ kernel_size, stride, padding=padding)]
373
+ if index > 0:
374
+ decode += [norm_fn(out_channels), act2()]
375
+ self.decoder.insert(0, nn.Sequential(*decode))
376
+ in_channels = channels
377
+ channels = int(growth * channels)
378
+
379
+ channels = in_channels
380
+ if lstm_layers:
381
+ self.lstm = BLSTM(channels, lstm_layers)
382
+ else:
383
+ self.lstm = None
384
+
385
+ if rescale:
386
+ rescale_module(self, reference=rescale)
387
+
388
+ def valid_length(self, length):
389
+ """
390
+ Return the nearest valid length to use with the model so that
391
+ there is no time steps left over in a convolution, e.g. for all
392
+ layers, size of the input - kernel_size % stride = 0.
393
+
394
+ Note that input are automatically padded if necessary to ensure that the output
395
+ has the same length as the input.
396
+ """
397
+ if self.resample:
398
+ length *= 2
399
+
400
+ for _ in range(self.depth):
401
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
402
+ length = max(1, length)
403
+
404
+ for idx in range(self.depth):
405
+ length = (length - 1) * self.stride + self.kernel_size
406
+
407
+ if self.resample:
408
+ length = math.ceil(length / 2)
409
+ return int(length)
410
+
411
+ def forward(self, mix):
412
+ x = mix
413
+ length = x.shape[-1]
414
+
415
+ if self.normalize:
416
+ mono = mix.mean(dim=1, keepdim=True)
417
+ mean = mono.mean(dim=-1, keepdim=True)
418
+ std = mono.std(dim=-1, keepdim=True)
419
+ x = (x - mean) / (1e-5 + std)
420
+ else:
421
+ mean = 0
422
+ std = 1
423
+
424
+ delta = self.valid_length(length) - length
425
+ x = F.pad(x, (delta // 2, delta - delta // 2))
426
+
427
+ if self.resample:
428
+ x = julius.resample_frac(x, 1, 2)
429
+
430
+ saved = []
431
+ for encode in self.encoder:
432
+ x = encode(x)
433
+ saved.append(x)
434
+
435
+ if self.lstm:
436
+ x = self.lstm(x)
437
+
438
+ for decode in self.decoder:
439
+ skip = saved.pop(-1)
440
+ skip = center_trim(skip, x)
441
+ x = decode(x + skip)
442
+
443
+ if self.resample:
444
+ x = julius.resample_frac(x, 2, 1)
445
+ x = x * std + mean
446
+ x = center_trim(x, length)
447
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
448
+ return x
449
+
450
+ def load_state_dict(self, state, strict=True):
451
+ # fix a mismatch with previous generation Demucs models.
452
+ for idx in range(self.depth):
453
+ for a in ['encoder', 'decoder']:
454
+ for b in ['bias', 'weight']:
455
+ new = f'{a}.{idx}.3.{b}'
456
+ old = f'{a}.{idx}.2.{b}'
457
+ if old in state and new not in state:
458
+ state[new] = state.pop(old)
459
+ super().load_state_dict(state, strict=strict)
demucs/filtering.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+ from torch.utils.data import DataLoader
6
+
7
+ def atan2(y, x):
8
+ r"""Element-wise arctangent function of y/x.
9
+ Returns a new tensor with signed angles in radians.
10
+ It is an alternative implementation of torch.atan2
11
+
12
+ Args:
13
+ y (Tensor): First input tensor
14
+ x (Tensor): Second input tensor [shape=y.shape]
15
+
16
+ Returns:
17
+ Tensor: [shape=y.shape].
18
+ """
19
+ pi = 2 * torch.asin(torch.tensor(1.0))
20
+ x += ((x == 0) & (y == 0)) * 1.0
21
+ out = torch.atan(y / x)
22
+ out += ((y >= 0) & (x < 0)) * pi
23
+ out -= ((y < 0) & (x < 0)) * pi
24
+ out *= 1 - ((y > 0) & (x == 0)) * 1.0
25
+ out += ((y > 0) & (x == 0)) * (pi / 2)
26
+ out *= 1 - ((y < 0) & (x == 0)) * 1.0
27
+ out += ((y < 0) & (x == 0)) * (-pi / 2)
28
+ return out
29
+
30
+
31
+ # Define basic complex operations on torch.Tensor objects whose last dimension
32
+ # consists in the concatenation of the real and imaginary parts.
33
+
34
+
35
+ def _norm(x: torch.Tensor) -> torch.Tensor:
36
+ r"""Computes the norm value of a torch Tensor, assuming that it
37
+ comes as real and imaginary part in its last dimension.
38
+
39
+ Args:
40
+ x (Tensor): Input Tensor of shape [shape=(..., 2)]
41
+
42
+ Returns:
43
+ Tensor: shape as x excluding the last dimension.
44
+ """
45
+ return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
46
+
47
+
48
+ def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
49
+ """Element-wise multiplication of two complex Tensors described
50
+ through their real and imaginary parts.
51
+ The result is added to the `out` tensor"""
52
+
53
+ # check `out` and allocate it if needed
54
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
55
+ if out is None or out.shape != target_shape:
56
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
57
+ if out is a:
58
+ real_a = a[..., 0]
59
+ out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
60
+ out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
61
+ else:
62
+ out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
63
+ out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
64
+ return out
65
+
66
+
67
+ def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
68
+ """Element-wise multiplication of two complex Tensors described
69
+ through their real and imaginary parts
70
+ can work in place in case out is a only"""
71
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
72
+ if out is None or out.shape != target_shape:
73
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
74
+ if out is a:
75
+ real_a = a[..., 0]
76
+ out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
77
+ out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
78
+ else:
79
+ out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
80
+ out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
81
+ return out
82
+
83
+
84
+ def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
85
+ """Element-wise multiplicative inverse of a Tensor with complex
86
+ entries described through their real and imaginary parts.
87
+ can work in place in case out is z"""
88
+ ez = _norm(z)
89
+ if out is None or out.shape != z.shape:
90
+ out = torch.zeros_like(z)
91
+ out[..., 0] = z[..., 0] / ez
92
+ out[..., 1] = -z[..., 1] / ez
93
+ return out
94
+
95
+
96
+ def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
97
+ """Element-wise complex conjugate of a Tensor with complex entries
98
+ described through their real and imaginary parts.
99
+ can work in place in case out is z"""
100
+ if out is None or out.shape != z.shape:
101
+ out = torch.zeros_like(z)
102
+ out[..., 0] = z[..., 0]
103
+ out[..., 1] = -z[..., 1]
104
+ return out
105
+
106
+
107
+ def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
108
+ """
109
+ Invert 1x1 or 2x2 matrices
110
+
111
+ Will generate errors if the matrices are singular: user must handle this
112
+ through his own regularization schemes.
113
+
114
+ Args:
115
+ M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
116
+ matrices to invert: must be square along dimensions -3 and -2
117
+
118
+ Returns:
119
+ invM (Tensor): [shape=M.shape]
120
+ inverses of M
121
+ """
122
+ nb_channels = M.shape[-2]
123
+
124
+ if out is None or out.shape != M.shape:
125
+ out = torch.empty_like(M)
126
+
127
+ if nb_channels == 1:
128
+ # scalar case
129
+ out = _inv(M, out)
130
+ elif nb_channels == 2:
131
+ # two channels case: analytical expression
132
+
133
+ # first compute the determinent
134
+ det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
135
+ det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
136
+ # invert it
137
+ invDet = _inv(det)
138
+
139
+ # then fill out the matrix with the inverse
140
+ out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
141
+ out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
142
+ out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
143
+ out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
144
+ else:
145
+ raise Exception("Only 2 channels are supported for the torch version.")
146
+ return out
147
+
148
+
149
+ # Now define the signal-processing low-level functions used by the Separator
150
+
151
+
152
+ def expectation_maximization(
153
+ y: torch.Tensor,
154
+ x: torch.Tensor,
155
+ iterations: int = 2,
156
+ eps: float = 1e-10,
157
+ batch_size: int = 200,
158
+ ):
159
+ r"""Expectation maximization algorithm, for refining source separation
160
+ estimates.
161
+
162
+ This algorithm allows to make source separation results better by
163
+ enforcing multichannel consistency for the estimates. This usually means
164
+ a better perceptual quality in terms of spatial artifacts.
165
+
166
+ The implementation follows the details presented in [1]_, taking
167
+ inspiration from the original EM algorithm proposed in [2]_ and its
168
+ weighted refinement proposed in [3]_, [4]_.
169
+ It works by iteratively:
170
+
171
+ * Re-estimate source parameters (power spectral densities and spatial
172
+ covariance matrices) through :func:`get_local_gaussian_model`.
173
+
174
+ * Separate again the mixture with the new parameters by first computing
175
+ the new modelled mixture covariance matrices with :func:`get_mix_model`,
176
+ prepare the Wiener filters through :func:`wiener_gain` and apply them
177
+ with :func:`apply_filter``.
178
+
179
+ References
180
+ ----------
181
+ .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
182
+ N. Takahashi and Y. Mitsufuji, "Improving music source separation based
183
+ on deep neural networks through data augmentation and network
184
+ blending." 2017 IEEE International Conference on Acoustics, Speech
185
+ and Signal Processing (ICASSP). IEEE, 2017.
186
+
187
+ .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
188
+ reverberant audio source separation using a full-rank spatial
189
+ covariance model." IEEE Transactions on Audio, Speech, and Language
190
+ Processing 18.7 (2010): 1830-1840.
191
+
192
+ .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
193
+ separation with deep neural networks." IEEE/ACM Transactions on Audio,
194
+ Speech, and Language Processing 24.9 (2016): 1652-1664.
195
+
196
+ .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
197
+ separation with deep neural networks." 2016 24th European Signal
198
+ Processing Conference (EUSIPCO). IEEE, 2016.
199
+
200
+ .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
201
+ source separation." IEEE Transactions on Signal Processing
202
+ 62.16 (2014): 4298-4310.
203
+
204
+ Args:
205
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
206
+ initial estimates for the sources
207
+ x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
208
+ complex STFT of the mixture signal
209
+ iterations (int): [scalar]
210
+ number of iterations for the EM algorithm.
211
+ eps (float or None): [scalar]
212
+ The epsilon value to use for regularization and filters.
213
+
214
+ Returns:
215
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
216
+ estimated sources after iterations
217
+ v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
218
+ estimated power spectral densities
219
+ R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
220
+ estimated spatial covariance matrices
221
+
222
+ Notes:
223
+ * You need an initial estimate for the sources to apply this
224
+ algorithm. This is precisely what the :func:`wiener` function does.
225
+ * This algorithm *is not* an implementation of the "exact" EM
226
+ proposed in [1]_. In particular, it does compute the posterior
227
+ covariance matrices the same (exact) way. Instead, it uses the
228
+ simplified approximate scheme initially proposed in [5]_ and further
229
+ refined in [3]_, [4]_, that boils down to just take the empirical
230
+ covariance of the recent source estimates, followed by a weighted
231
+ average for the update of the spatial covariance matrix. It has been
232
+ empirically demonstrated that this simplified algorithm is more
233
+ robust for music separation.
234
+
235
+ Warning:
236
+ It is *very* important to make sure `x.dtype` is `torch.float64`
237
+ if you want double precision, because this function will **not**
238
+ do such conversion for you from `torch.complex32`, in case you want the
239
+ smaller RAM usage on purpose.
240
+
241
+ It is usually always better in terms of quality to have double
242
+ precision, by e.g. calling :func:`expectation_maximization`
243
+ with ``x.to(torch.float64)``.
244
+ """
245
+ # dimensions
246
+ (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
247
+ nb_sources = y.shape[-1]
248
+
249
+ regularization = torch.cat(
250
+ (
251
+ torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
252
+ torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
253
+ ),
254
+ dim=2,
255
+ )
256
+ regularization = torch.sqrt(torch.as_tensor(eps)) * (
257
+ regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
258
+ )
259
+
260
+ # allocate the spatial covariance matrices
261
+ R = [
262
+ torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
263
+ for j in range(nb_sources)
264
+ ]
265
+ weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
266
+
267
+ v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
268
+ for it in range(iterations):
269
+ # constructing the mixture covariance matrix. Doing it with a loop
270
+ # to avoid storing anytime in RAM the whole 6D tensor
271
+
272
+ # update the PSD as the average spectrogram over channels
273
+ v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
274
+
275
+ # update spatial covariance matrices (weighted update)
276
+ for j in range(nb_sources):
277
+ R[j] = torch.tensor(0.0, device=x.device)
278
+ weight = torch.tensor(eps, device=x.device)
279
+ pos: int = 0
280
+ batch_size = batch_size if batch_size else nb_frames
281
+ while pos < nb_frames:
282
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
283
+ pos = int(t[-1]) + 1
284
+
285
+ R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
286
+ weight = weight + torch.sum(v[t, ..., j], dim=0)
287
+ R[j] = R[j] / weight[..., None, None, None]
288
+ weight = torch.zeros_like(weight)
289
+
290
+ # cloning y if we track gradient, because we're going to update it
291
+ if y.requires_grad:
292
+ y = y.clone()
293
+
294
+ pos = 0
295
+ while pos < nb_frames:
296
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
297
+ pos = int(t[-1]) + 1
298
+
299
+ y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)
300
+
301
+ # compute mix covariance matrix
302
+ Cxx = regularization
303
+ for j in range(nb_sources):
304
+ Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
305
+
306
+ # invert it
307
+ inv_Cxx = _invert(Cxx)
308
+
309
+ # separate the sources
310
+ for j in range(nb_sources):
311
+
312
+ # create a wiener gain for this source
313
+ gain = torch.zeros_like(inv_Cxx)
314
+
315
+ # computes multichannel Wiener gain as v_j R_j inv_Cxx
316
+ indices = torch.cartesian_prod(
317
+ torch.arange(nb_channels),
318
+ torch.arange(nb_channels),
319
+ torch.arange(nb_channels),
320
+ )
321
+ for index in indices:
322
+ gain[:, :, index[0], index[1], :] = _mul_add(
323
+ R[j][None, :, index[0], index[2], :].clone(),
324
+ inv_Cxx[:, :, index[2], index[1], :],
325
+ gain[:, :, index[0], index[1], :],
326
+ )
327
+ gain = gain * v[t, ..., None, None, None, j]
328
+
329
+ # apply it to the mixture
330
+ for i in range(nb_channels):
331
+ y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
332
+
333
+ return y, v, R
334
+
335
+
336
+ def wiener(
337
+ targets_spectrograms: torch.Tensor,
338
+ mix_stft: torch.Tensor,
339
+ iterations: int = 1,
340
+ softmask: bool = False,
341
+ residual: bool = False,
342
+ scale_factor: float = 10.0,
343
+ eps: float = 1e-10,
344
+ ):
345
+ """Wiener-based separation for multichannel audio.
346
+
347
+ The method uses the (possibly multichannel) spectrograms of the
348
+ sources to separate the (complex) Short Term Fourier Transform of the
349
+ mix. Separation is done in a sequential way by:
350
+
351
+ * Getting an initial estimate. This can be done in two ways: either by
352
+ directly using the spectrograms with the mixture phase, or
353
+ by using a softmasking strategy. This initial phase is controlled
354
+ by the `softmask` flag.
355
+
356
+ * If required, adding an additional residual target as the mix minus
357
+ all targets.
358
+
359
+ * Refinining these initial estimates through a call to
360
+ :func:`expectation_maximization` if the number of iterations is nonzero.
361
+
362
+ This implementation also allows to specify the epsilon value used for
363
+ regularization. It is based on [1]_, [2]_, [3]_, [4]_.
364
+
365
+ References
366
+ ----------
367
+ .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
368
+ N. Takahashi and Y. Mitsufuji, "Improving music source separation based
369
+ on deep neural networks through data augmentation and network
370
+ blending." 2017 IEEE International Conference on Acoustics, Speech
371
+ and Signal Processing (ICASSP). IEEE, 2017.
372
+
373
+ .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
374
+ separation with deep neural networks." IEEE/ACM Transactions on Audio,
375
+ Speech, and Language Processing 24.9 (2016): 1652-1664.
376
+
377
+ .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
378
+ separation with deep neural networks." 2016 24th European Signal
379
+ Processing Conference (EUSIPCO). IEEE, 2016.
380
+
381
+ .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
382
+ source separation." IEEE Transactions on Signal Processing
383
+ 62.16 (2014): 4298-4310.
384
+
385
+ Args:
386
+ targets_spectrograms (Tensor): spectrograms of the sources
387
+ [shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
388
+ This is a nonnegative tensor that is
389
+ usually the output of the actual separation method of the user. The
390
+ spectrograms may be mono, but they need to be 4-dimensional in all
391
+ cases.
392
+ mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
393
+ STFT of the mixture signal.
394
+ iterations (int): [scalar]
395
+ number of iterations for the EM algorithm
396
+ softmask (bool): Describes how the initial estimates are obtained.
397
+ * if `False`, then the mixture phase will directly be used with the
398
+ spectrogram as initial estimates.
399
+ * if `True`, initial estimates are obtained by multiplying the
400
+ complex mix element-wise with the ratio of each target spectrogram
401
+ with the sum of them all. This strategy is better if the model are
402
+ not really good, and worse otherwise.
403
+ residual (bool): if `True`, an additional target is created, which is
404
+ equal to the mixture minus the other targets, before application of
405
+ expectation maximization
406
+ eps (float): Epsilon value to use for computing the separations.
407
+ This is used whenever division with a model energy is
408
+ performed, i.e. when softmasking and when iterating the EM.
409
+ It can be understood as the energy of the additional white noise
410
+ that is taken out when separating.
411
+
412
+ Returns:
413
+ Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
414
+ STFT of estimated sources
415
+
416
+ Notes:
417
+ * Be careful that you need *magnitude spectrogram estimates* for the
418
+ case `softmask==False`.
419
+ * `softmask=False` is recommended
420
+ * The epsilon value will have a huge impact on performance. If it's
421
+ large, only the parts of the signal with a significant energy will
422
+ be kept in the sources. This epsilon then directly controls the
423
+ energy of the reconstruction error.
424
+
425
+ Warning:
426
+ As in :func:`expectation_maximization`, we recommend converting the
427
+ mixture `x` to double precision `torch.float64` *before* calling
428
+ :func:`wiener`.
429
+ """
430
+ if softmask:
431
+ # if we use softmask, we compute the ratio mask for all targets and
432
+ # multiply by the mix stft
433
+ y = (
434
+ mix_stft[..., None]
435
+ * (
436
+ targets_spectrograms
437
+ / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
438
+ )[..., None, :]
439
+ )
440
+ else:
441
+ # otherwise, we just multiply the targets spectrograms with mix phase
442
+ # we tacitly assume that we have magnitude estimates.
443
+ angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
444
+ nb_sources = targets_spectrograms.shape[-1]
445
+ y = torch.zeros(
446
+ mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
447
+ )
448
+ y[..., 0, :] = targets_spectrograms * torch.cos(angle)
449
+ y[..., 1, :] = targets_spectrograms * torch.sin(angle)
450
+
451
+ if residual:
452
+ # if required, adding an additional target as the mix minus
453
+ # available targets
454
+ y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
455
+
456
+ if iterations == 0:
457
+ return y
458
+
459
+ # we need to refine the estimates. Scales down the estimates for
460
+ # numerical stability
461
+ max_abs = torch.max(
462
+ torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
463
+ torch.sqrt(_norm(mix_stft)).max() / scale_factor,
464
+ )
465
+
466
+ mix_stft = mix_stft / max_abs
467
+ y = y / max_abs
468
+
469
+ # call expectation maximization
470
+ y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
471
+
472
+ # scale estimates up again
473
+ y = y * max_abs
474
+ return y
475
+
476
+
477
+ def _covariance(y_j):
478
+ """
479
+ Compute the empirical covariance for a source.
480
+
481
+ Args:
482
+ y_j (Tensor): complex stft of the source.
483
+ [shape=(nb_frames, nb_bins, nb_channels, 2)].
484
+
485
+ Returns:
486
+ Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
487
+ just y_j * conj(y_j.T): empirical covariance for each TF bin.
488
+ """
489
+ (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
490
+ Cj = torch.zeros(
491
+ (nb_frames, nb_bins, nb_channels, nb_channels, 2),
492
+ dtype=y_j.dtype,
493
+ device=y_j.device,
494
+ )
495
+ indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
496
+ for index in indices:
497
+ Cj[:, :, index[0], index[1], :] = _mul_add(
498
+ y_j[:, :, index[0], :],
499
+ _conj(y_j[:, :, index[1], :]),
500
+ Cj[:, :, index[0], index[1], :],
501
+ )
502
+ return Cj
demucs/hdemucs.py ADDED
@@ -0,0 +1,782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ This code contains the spectrogram and Hybrid version of Demucs.
8
+ """
9
+ from copy import deepcopy
10
+ import math
11
+ import typing as tp
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+ from .filtering import wiener
16
+ from .demucs import DConv, rescale_module
17
+ from .states import capture_init
18
+ from .spec import spectro, ispectro
19
+
20
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
21
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
22
+ If this is the case, we insert extra 0 padding to the right before the reflection happen."""
23
+ x0 = x
24
+ length = x.shape[-1]
25
+ padding_left, padding_right = paddings
26
+ if mode == 'reflect':
27
+ max_pad = max(padding_left, padding_right)
28
+ if length <= max_pad:
29
+ extra_pad = max_pad - length + 1
30
+ extra_pad_right = min(padding_right, extra_pad)
31
+ extra_pad_left = extra_pad - extra_pad_right
32
+ paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
33
+ x = F.pad(x, (extra_pad_left, extra_pad_right))
34
+ out = F.pad(x, paddings, mode, value)
35
+ assert out.shape[-1] == length + padding_left + padding_right
36
+ assert (out[..., padding_left: padding_left + length] == x0).all()
37
+ return out
38
+
39
+ class ScaledEmbedding(nn.Module):
40
+ """
41
+ Boost learning rate for embeddings (with `scale`).
42
+ Also, can make embeddings continuous with `smooth`.
43
+ """
44
+ def __init__(self, num_embeddings: int, embedding_dim: int,
45
+ scale: float = 10., smooth=False):
46
+ super().__init__()
47
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
48
+ if smooth:
49
+ weight = torch.cumsum(self.embedding.weight.data, dim=0)
50
+ # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that.
51
+ weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
52
+ self.embedding.weight.data[:] = weight
53
+ self.embedding.weight.data /= scale
54
+ self.scale = scale
55
+
56
+ @property
57
+ def weight(self):
58
+ return self.embedding.weight * self.scale
59
+
60
+ def forward(self, x):
61
+ out = self.embedding(x) * self.scale
62
+ return out
63
+
64
+
65
+ class HEncLayer(nn.Module):
66
+ def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
67
+ freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
68
+ rewrite=True):
69
+ """Encoder layer. This used both by the time and the frequency branch.
70
+
71
+ Args:
72
+ chin: number of input channels.
73
+ chout: number of output channels.
74
+ norm_groups: number of groups for group norm.
75
+ empty: used to make a layer with just the first conv. this is used
76
+ before merging the time and freq. branches.
77
+ freq: this is acting on frequencies.
78
+ dconv: insert DConv residual branches.
79
+ norm: use GroupNorm.
80
+ context: context size for the 1x1 conv.
81
+ dconv_kw: list of kwargs for the DConv class.
82
+ pad: pad the input. Padding is done so that the output size is
83
+ always the input size / stride.
84
+ rewrite: add 1x1 conv at the end of the layer.
85
+ """
86
+ super().__init__()
87
+ norm_fn = lambda d: nn.Identity() # noqa
88
+ if norm:
89
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
90
+ if pad:
91
+ pad = kernel_size // 4
92
+ else:
93
+ pad = 0
94
+ klass = nn.Conv1d
95
+ self.freq = freq
96
+ self.kernel_size = kernel_size
97
+ self.stride = stride
98
+ self.empty = empty
99
+ self.norm = norm
100
+ self.pad = pad
101
+ if freq:
102
+ kernel_size = [kernel_size, 1]
103
+ stride = [stride, 1]
104
+ pad = [pad, 0]
105
+ klass = nn.Conv2d
106
+ self.conv = klass(chin, chout, kernel_size, stride, pad)
107
+ if self.empty:
108
+ return
109
+ self.norm1 = norm_fn(chout)
110
+ self.rewrite = None
111
+ if rewrite:
112
+ self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
113
+ self.norm2 = norm_fn(2 * chout)
114
+
115
+ self.dconv = None
116
+ if dconv:
117
+ self.dconv = DConv(chout, **dconv_kw)
118
+
119
+ def forward(self, x, inject=None):
120
+ """
121
+ `inject` is used to inject the result from the time branch into the frequency branch,
122
+ when both have the same stride.
123
+ """
124
+ if not self.freq and x.dim() == 4:
125
+ B, C, Fr, T = x.shape
126
+ x = x.view(B, -1, T)
127
+
128
+ if not self.freq:
129
+ le = x.shape[-1]
130
+ if not le % self.stride == 0:
131
+ x = F.pad(x, (0, self.stride - (le % self.stride)))
132
+ y = self.conv(x)
133
+ if self.empty:
134
+ return y
135
+ if inject is not None:
136
+ assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
137
+ if inject.dim() == 3 and y.dim() == 4:
138
+ inject = inject[:, :, None]
139
+ y = y + inject
140
+ y = F.gelu(self.norm1(y))
141
+ if self.dconv:
142
+ if self.freq:
143
+ B, C, Fr, T = y.shape
144
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
145
+ y = self.dconv(y)
146
+ if self.freq:
147
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
148
+ if self.rewrite:
149
+ z = self.norm2(self.rewrite(y))
150
+ z = F.glu(z, dim=1)
151
+ else:
152
+ z = y
153
+ return z
154
+
155
+
156
+ class MultiWrap(nn.Module):
157
+ """
158
+ Takes one layer and replicate it N times. each replica will act
159
+ on a frequency band. All is done so that if the N replica have the same weights,
160
+ then this is exactly equivalent to applying the original module on all frequencies.
161
+
162
+ This is a bit over-engineered to avoid edge artifacts when splitting
163
+ the frequency bands, but it is possible the naive implementation would work as well...
164
+ """
165
+ def __init__(self, layer, split_ratios):
166
+ """
167
+ Args:
168
+ layer: module to clone, must be either HEncLayer or HDecLayer.
169
+ split_ratios: list of float indicating which ratio to keep for each band.
170
+ """
171
+ super().__init__()
172
+ self.split_ratios = split_ratios
173
+ self.layers = nn.ModuleList()
174
+ self.conv = isinstance(layer, HEncLayer)
175
+ assert not layer.norm
176
+ assert layer.freq
177
+ assert layer.pad
178
+ if not self.conv:
179
+ assert not layer.context_freq
180
+ for k in range(len(split_ratios) + 1):
181
+ lay = deepcopy(layer)
182
+ if self.conv:
183
+ lay.conv.padding = (0, 0)
184
+ else:
185
+ lay.pad = False
186
+ for m in lay.modules():
187
+ if hasattr(m, 'reset_parameters'):
188
+ m.reset_parameters()
189
+ self.layers.append(lay)
190
+
191
+ def forward(self, x, skip=None, length=None):
192
+ B, C, Fr, T = x.shape
193
+
194
+ ratios = list(self.split_ratios) + [1]
195
+ start = 0
196
+ outs = []
197
+ for ratio, layer in zip(ratios, self.layers):
198
+ if self.conv:
199
+ pad = layer.kernel_size // 4
200
+ if ratio == 1:
201
+ limit = Fr
202
+ frames = -1
203
+ else:
204
+ limit = int(round(Fr * ratio))
205
+ le = limit - start
206
+ if start == 0:
207
+ le += pad
208
+ frames = round((le - layer.kernel_size) / layer.stride + 1)
209
+ limit = start + (frames - 1) * layer.stride + layer.kernel_size
210
+ if start == 0:
211
+ limit -= pad
212
+ assert limit - start > 0, (limit, start)
213
+ assert limit <= Fr, (limit, Fr)
214
+ y = x[:, :, start:limit, :]
215
+ if start == 0:
216
+ y = F.pad(y, (0, 0, pad, 0))
217
+ if ratio == 1:
218
+ y = F.pad(y, (0, 0, 0, pad))
219
+ outs.append(layer(y))
220
+ start = limit - layer.kernel_size + layer.stride
221
+ else:
222
+ if ratio == 1:
223
+ limit = Fr
224
+ else:
225
+ limit = int(round(Fr * ratio))
226
+ last = layer.last
227
+ layer.last = True
228
+
229
+ y = x[:, :, start:limit]
230
+ s = skip[:, :, start:limit]
231
+ out, _ = layer(y, s, None)
232
+ if outs:
233
+ outs[-1][:, :, -layer.stride:] += (
234
+ out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
235
+ out = out[:, :, layer.stride:]
236
+ if ratio == 1:
237
+ out = out[:, :, :-layer.stride // 2, :]
238
+ if start == 0:
239
+ out = out[:, :, layer.stride // 2:, :]
240
+ outs.append(out)
241
+ layer.last = last
242
+ start = limit
243
+ out = torch.cat(outs, dim=2)
244
+ if not self.conv and not last:
245
+ out = F.gelu(out)
246
+ if self.conv:
247
+ return out
248
+ else:
249
+ return out, None
250
+
251
+
252
+ class HDecLayer(nn.Module):
253
+ def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
254
+ freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
255
+ context_freq=True, rewrite=True):
256
+ """
257
+ Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
258
+ """
259
+ super().__init__()
260
+ norm_fn = lambda d: nn.Identity() # noqa
261
+ if norm:
262
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
263
+ if pad:
264
+ pad = kernel_size // 4
265
+ else:
266
+ pad = 0
267
+ self.pad = pad
268
+ self.last = last
269
+ self.freq = freq
270
+ self.chin = chin
271
+ self.empty = empty
272
+ self.stride = stride
273
+ self.kernel_size = kernel_size
274
+ self.norm = norm
275
+ self.context_freq = context_freq
276
+ klass = nn.Conv1d
277
+ klass_tr = nn.ConvTranspose1d
278
+ if freq:
279
+ kernel_size = [kernel_size, 1]
280
+ stride = [stride, 1]
281
+ klass = nn.Conv2d
282
+ klass_tr = nn.ConvTranspose2d
283
+ self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
284
+ self.norm2 = norm_fn(chout)
285
+ if self.empty:
286
+ return
287
+ self.rewrite = None
288
+ if rewrite:
289
+ if context_freq:
290
+ self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
291
+ else:
292
+ self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
293
+ [0, context])
294
+ self.norm1 = norm_fn(2 * chin)
295
+
296
+ self.dconv = None
297
+ if dconv:
298
+ self.dconv = DConv(chin, **dconv_kw)
299
+
300
+ def forward(self, x, skip, length):
301
+ if self.freq and x.dim() == 3:
302
+ B, C, T = x.shape
303
+ x = x.view(B, self.chin, -1, T)
304
+
305
+ if not self.empty:
306
+ x = x + skip
307
+
308
+ if self.rewrite:
309
+ y = F.glu(self.norm1(self.rewrite(x)), dim=1)
310
+ else:
311
+ y = x
312
+ if self.dconv:
313
+ if self.freq:
314
+ B, C, Fr, T = y.shape
315
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
316
+ y = self.dconv(y)
317
+ if self.freq:
318
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
319
+ else:
320
+ y = x
321
+ assert skip is None
322
+ z = self.norm2(self.conv_tr(y))
323
+ if self.freq:
324
+ if self.pad:
325
+ z = z[..., self.pad:-self.pad, :]
326
+ else:
327
+ z = z[..., self.pad:self.pad + length]
328
+ assert z.shape[-1] == length, (z.shape[-1], length)
329
+ if not self.last:
330
+ z = F.gelu(z)
331
+ return z, y
332
+
333
+
334
+ class HDemucs(nn.Module):
335
+ """
336
+ Spectrogram and hybrid Demucs model.
337
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
338
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
339
+ Frequency layers can still access information across time steps thanks to the DConv residual.
340
+
341
+ Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
342
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
343
+
344
+ Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
345
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
346
+ Open Unmix implementation [Stoter et al. 2019].
347
+
348
+ The loss is always on the temporal domain, by backpropagating through the above
349
+ output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
350
+ a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
351
+ contribution, without changing the one from the waveform, which will lead to worse performance.
352
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
353
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
354
+ hybrid models.
355
+
356
+ This model also uses frequency embeddings are used to improve efficiency on convolutions
357
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
358
+
359
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
360
+ """
361
+ @capture_init
362
+ def __init__(self,
363
+ sources,
364
+ # Channels
365
+ audio_channels=2,
366
+ channels=48,
367
+ channels_time=None,
368
+ growth=2,
369
+ # STFT
370
+ nfft=4096,
371
+ wiener_iters=0,
372
+ end_iters=0,
373
+ wiener_residual=False,
374
+ cac=True,
375
+ # Main structure
376
+ depth=6,
377
+ rewrite=True,
378
+ hybrid=True,
379
+ hybrid_old=False,
380
+ # Frequency branch
381
+ multi_freqs=None,
382
+ multi_freqs_depth=2,
383
+ freq_emb=0.2,
384
+ emb_scale=10,
385
+ emb_smooth=True,
386
+ # Convolutions
387
+ kernel_size=8,
388
+ time_stride=2,
389
+ stride=4,
390
+ context=1,
391
+ context_enc=0,
392
+ # Normalization
393
+ norm_starts=4,
394
+ norm_groups=4,
395
+ # DConv residual branch
396
+ dconv_mode=1,
397
+ dconv_depth=2,
398
+ dconv_comp=4,
399
+ dconv_attn=4,
400
+ dconv_lstm=4,
401
+ dconv_init=1e-4,
402
+ # Weight init
403
+ rescale=0.1,
404
+ # Metadata
405
+ samplerate=44100,
406
+ segment=4 * 10):
407
+
408
+ """
409
+ Args:
410
+ sources (list[str]): list of source names.
411
+ audio_channels (int): input/output audio channels.
412
+ channels (int): initial number of hidden channels.
413
+ channels_time: if not None, use a different `channels` value for the time branch.
414
+ growth: increase the number of hidden channels by this factor at each layer.
415
+ nfft: number of fft bins. Note that changing this require careful computation of
416
+ various shape parameters and will not work out of the box for hybrid models.
417
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
418
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
419
+ wiener_residual: add residual source before wiener filtering.
420
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
421
+ in input and output. no further processing is done before ISTFT.
422
+ depth (int): number of layers in the encoder and in the decoder.
423
+ rewrite (bool): add 1x1 convolution to each layer.
424
+ hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
425
+ hybrid_old: some models trained for MDX had a padding bug. This replicates
426
+ this bug to avoid retraining them.
427
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
428
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
429
+ layers will be wrapped.
430
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
431
+ the actual value controls the weight of the embedding.
432
+ emb_scale: equivalent to scaling the embedding learning rate
433
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
434
+ kernel_size: kernel_size for encoder and decoder layers.
435
+ stride: stride for encoder and decoder layers.
436
+ time_stride: stride for the final time layer, after the merge.
437
+ context: context for 1x1 conv in the decoder.
438
+ context_enc: context for 1x1 conv in the encoder.
439
+ norm_starts: layer at which group norm starts being used.
440
+ decoder layers are numbered in reverse order.
441
+ norm_groups: number of groups for group norm.
442
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
443
+ dconv_depth: depth of residual DConv branch.
444
+ dconv_comp: compression of DConv branch.
445
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
446
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
447
+ dconv_init: initial scale for the DConv branch LayerScale.
448
+ rescale: weight recaling trick
449
+
450
+ """
451
+ super().__init__()
452
+
453
+ self.cac = cac
454
+ self.wiener_residual = wiener_residual
455
+ self.audio_channels = audio_channels
456
+ self.sources = sources
457
+ self.kernel_size = kernel_size
458
+ self.context = context
459
+ self.stride = stride
460
+ self.depth = depth
461
+ self.channels = channels
462
+ self.samplerate = samplerate
463
+ self.segment = segment
464
+
465
+ self.nfft = nfft
466
+ self.hop_length = nfft // 4
467
+ self.wiener_iters = wiener_iters
468
+ self.end_iters = end_iters
469
+ self.freq_emb = None
470
+ self.hybrid = hybrid
471
+ self.hybrid_old = hybrid_old
472
+ if hybrid_old:
473
+ assert hybrid, "hybrid_old must come with hybrid=True"
474
+ if hybrid:
475
+ assert wiener_iters == end_iters
476
+
477
+ self.encoder = nn.ModuleList()
478
+ self.decoder = nn.ModuleList()
479
+
480
+ if hybrid:
481
+ self.tencoder = nn.ModuleList()
482
+ self.tdecoder = nn.ModuleList()
483
+
484
+ chin = audio_channels
485
+ chin_z = chin # number of channels for the freq branch
486
+ if self.cac:
487
+ chin_z *= 2
488
+ chout = channels_time or channels
489
+ chout_z = channels
490
+ freqs = nfft // 2
491
+
492
+ for index in range(depth):
493
+ lstm = index >= dconv_lstm
494
+ attn = index >= dconv_attn
495
+ norm = index >= norm_starts
496
+ freq = freqs > 1
497
+ stri = stride
498
+ ker = kernel_size
499
+ if not freq:
500
+ assert freqs == 1
501
+ ker = time_stride * 2
502
+ stri = time_stride
503
+
504
+ pad = True
505
+ last_freq = False
506
+ if freq and freqs <= kernel_size:
507
+ ker = freqs
508
+ pad = False
509
+ last_freq = True
510
+
511
+ kw = {
512
+ 'kernel_size': ker,
513
+ 'stride': stri,
514
+ 'freq': freq,
515
+ 'pad': pad,
516
+ 'norm': norm,
517
+ 'rewrite': rewrite,
518
+ 'norm_groups': norm_groups,
519
+ 'dconv_kw': {
520
+ 'lstm': lstm,
521
+ 'attn': attn,
522
+ 'depth': dconv_depth,
523
+ 'compress': dconv_comp,
524
+ 'init': dconv_init,
525
+ 'gelu': True,
526
+ }
527
+ }
528
+ kwt = dict(kw)
529
+ kwt['freq'] = 0
530
+ kwt['kernel_size'] = kernel_size
531
+ kwt['stride'] = stride
532
+ kwt['pad'] = True
533
+ kw_dec = dict(kw)
534
+ multi = False
535
+ if multi_freqs and index < multi_freqs_depth:
536
+ multi = True
537
+ kw_dec['context_freq'] = False
538
+
539
+ if last_freq:
540
+ chout_z = max(chout, chout_z)
541
+ chout = chout_z
542
+
543
+ enc = HEncLayer(chin_z, chout_z,
544
+ dconv=dconv_mode & 1, context=context_enc, **kw)
545
+ if hybrid and freq:
546
+ tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
547
+ empty=last_freq, **kwt)
548
+ self.tencoder.append(tenc)
549
+
550
+ if multi:
551
+ enc = MultiWrap(enc, multi_freqs)
552
+ self.encoder.append(enc)
553
+ if index == 0:
554
+ chin = self.audio_channels * len(self.sources)
555
+ chin_z = chin
556
+ if self.cac:
557
+ chin_z *= 2
558
+ dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
559
+ last=index == 0, context=context, **kw_dec)
560
+ if multi:
561
+ dec = MultiWrap(dec, multi_freqs)
562
+ if hybrid and freq:
563
+ tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
564
+ last=index == 0, context=context, **kwt)
565
+ self.tdecoder.insert(0, tdec)
566
+ self.decoder.insert(0, dec)
567
+
568
+ chin = chout
569
+ chin_z = chout_z
570
+ chout = int(growth * chout)
571
+ chout_z = int(growth * chout_z)
572
+ if freq:
573
+ if freqs <= kernel_size:
574
+ freqs = 1
575
+ else:
576
+ freqs //= stride
577
+ if index == 0 and freq_emb:
578
+ self.freq_emb = ScaledEmbedding(
579
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
580
+ self.freq_emb_scale = freq_emb
581
+
582
+ if rescale:
583
+ rescale_module(self, reference=rescale)
584
+
585
+ def _spec(self, x):
586
+ hl = self.hop_length
587
+ nfft = self.nfft
588
+ x0 = x # noqa
589
+
590
+ if self.hybrid:
591
+ # We re-pad the signal in order to keep the property
592
+ # that the size of the output is exactly the size of the input
593
+ # divided by the stride (here hop_length), when divisible.
594
+ # This is achieved by padding by 1/4th of the kernel size (here nfft).
595
+ # which is not supported by torch.stft.
596
+ # Having all convolution operations follow this convention allow to easily
597
+ # align the time and frequency branches later on.
598
+ assert hl == nfft // 4
599
+ le = int(math.ceil(x.shape[-1] / hl))
600
+ pad = hl // 2 * 3
601
+ if not self.hybrid_old:
602
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
603
+ else:
604
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))
605
+
606
+ z = spectro(x, nfft, hl)[..., :-1, :]
607
+ if self.hybrid:
608
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
609
+ z = z[..., 2:2+le]
610
+ return z
611
+
612
+ def _ispec(self, z, length=None, scale=0):
613
+ hl = self.hop_length // (4 ** scale)
614
+ z = F.pad(z, (0, 0, 0, 1))
615
+ if self.hybrid:
616
+ z = F.pad(z, (2, 2))
617
+ pad = hl // 2 * 3
618
+ if not self.hybrid_old:
619
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
620
+ else:
621
+ le = hl * int(math.ceil(length / hl))
622
+ x = ispectro(z, hl, length=le)
623
+ if not self.hybrid_old:
624
+ x = x[..., pad:pad + length]
625
+ else:
626
+ x = x[..., :length]
627
+ else:
628
+ x = ispectro(z, hl, length)
629
+ return x
630
+
631
+ def _magnitude(self, z):
632
+ # return the magnitude of the spectrogram, except when cac is True,
633
+ # in which case we just move the complex dimension to the channel one.
634
+ if self.cac:
635
+ B, C, Fr, T = z.shape
636
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
637
+ m = m.reshape(B, C * 2, Fr, T)
638
+ else:
639
+ m = z.abs()
640
+ return m
641
+
642
+ def _mask(self, z, m):
643
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
644
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
645
+ niters = self.wiener_iters
646
+ if self.cac:
647
+ B, S, C, Fr, T = m.shape
648
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
649
+ out = torch.view_as_complex(out.contiguous())
650
+ return out
651
+ if self.training:
652
+ niters = self.end_iters
653
+ if niters < 0:
654
+ z = z[:, None]
655
+ return z / (1e-8 + z.abs()) * m
656
+ else:
657
+ return self._wiener(m, z, niters)
658
+
659
+ def _wiener(self, mag_out, mix_stft, niters):
660
+ # apply wiener filtering from OpenUnmix.
661
+ init = mix_stft.dtype
662
+ wiener_win_len = 300
663
+ residual = self.wiener_residual
664
+
665
+ B, S, C, Fq, T = mag_out.shape
666
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
667
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
668
+
669
+ outs = []
670
+ for sample in range(B):
671
+ pos = 0
672
+ out = []
673
+ for pos in range(0, T, wiener_win_len):
674
+ frame = slice(pos, pos + wiener_win_len)
675
+ z_out = wiener(
676
+ mag_out[sample, frame], mix_stft[sample, frame], niters,
677
+ residual=residual)
678
+ out.append(z_out.transpose(-1, -2))
679
+ outs.append(torch.cat(out, dim=0))
680
+ out = torch.view_as_complex(torch.stack(outs, 0))
681
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
682
+ if residual:
683
+ out = out[:, :-1]
684
+ assert list(out.shape) == [B, S, C, Fq, T]
685
+ return out.to(init)
686
+
687
+ def forward(self, mix):
688
+ x = mix
689
+ length = x.shape[-1]
690
+
691
+ z = self._spec(mix)
692
+ mag = self._magnitude(z)
693
+ x = mag
694
+
695
+ B, C, Fq, T = x.shape
696
+
697
+ # unlike previous Demucs, we always normalize because it is easier.
698
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
699
+ std = x.std(dim=(1, 2, 3), keepdim=True)
700
+ x = (x - mean) / (1e-5 + std)
701
+ # x will be the freq. branch input.
702
+
703
+ if self.hybrid:
704
+ # Prepare the time branch input.
705
+ xt = mix
706
+ meant = xt.mean(dim=(1, 2), keepdim=True)
707
+ stdt = xt.std(dim=(1, 2), keepdim=True)
708
+ xt = (xt - meant) / (1e-5 + stdt)
709
+
710
+ # okay, this is a giant mess I know...
711
+ saved = [] # skip connections, freq.
712
+ saved_t = [] # skip connections, time.
713
+ lengths = [] # saved lengths to properly remove padding, freq branch.
714
+ lengths_t = [] # saved lengths for time branch.
715
+ for idx, encode in enumerate(self.encoder):
716
+ lengths.append(x.shape[-1])
717
+ inject = None
718
+ if self.hybrid and idx < len(self.tencoder):
719
+ # we have not yet merged branches.
720
+ lengths_t.append(xt.shape[-1])
721
+ tenc = self.tencoder[idx]
722
+ xt = tenc(xt)
723
+ if not tenc.empty:
724
+ # save for skip connection
725
+ saved_t.append(xt)
726
+ else:
727
+ # tenc contains just the first conv., so that now time and freq.
728
+ # branches have the same shape and can be merged.
729
+ inject = xt
730
+ x = encode(x, inject)
731
+ if idx == 0 and self.freq_emb is not None:
732
+ # add frequency embedding to allow for non equivariant convolutions
733
+ # over the frequency axis.
734
+ frs = torch.arange(x.shape[-2], device=x.device)
735
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
736
+ x = x + self.freq_emb_scale * emb
737
+
738
+ saved.append(x)
739
+
740
+ x = torch.zeros_like(x)
741
+ if self.hybrid:
742
+ xt = torch.zeros_like(x)
743
+ # initialize everything to zero (signal will go through u-net skips).
744
+
745
+ for idx, decode in enumerate(self.decoder):
746
+ skip = saved.pop(-1)
747
+ x, pre = decode(x, skip, lengths.pop(-1))
748
+ # `pre` contains the output just before final transposed convolution,
749
+ # which is used when the freq. and time branch separate.
750
+
751
+ if self.hybrid:
752
+ offset = self.depth - len(self.tdecoder)
753
+ if self.hybrid and idx >= offset:
754
+ tdec = self.tdecoder[idx - offset]
755
+ length_t = lengths_t.pop(-1)
756
+ if tdec.empty:
757
+ assert pre.shape[2] == 1, pre.shape
758
+ pre = pre[:, :, 0]
759
+ xt, _ = tdec(pre, None, length_t)
760
+ else:
761
+ skip = saved_t.pop(-1)
762
+ xt, _ = tdec(xt, skip, length_t)
763
+
764
+ # Let's make sure we used all stored skip connections.
765
+ assert len(saved) == 0
766
+ assert len(lengths_t) == 0
767
+ assert len(saved_t) == 0
768
+
769
+ S = len(self.sources)
770
+ x = x.view(B, S, -1, Fq, T)
771
+ x = x * std[:, None] + mean[:, None]
772
+
773
+ zout = self._mask(z, x)
774
+ x = self._ispec(zout, length)
775
+
776
+ if self.hybrid:
777
+ xt = xt.view(B, S, -1, length)
778
+ xt = xt * stdt[:, None] + meant[:, None]
779
+ x = xt + x
780
+ return x
781
+
782
+
demucs/htdemucs.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+ """
8
+ This code contains the spectrogram and Hybrid version of Demucs.
9
+ """
10
+ import math
11
+
12
+ from .filtering import wiener
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from fractions import Fraction
17
+ from einops import rearrange
18
+
19
+ from .transformer import CrossTransformerEncoder
20
+
21
+ from .demucs import rescale_module
22
+ from .states import capture_init
23
+ from .spec import spectro, ispectro
24
+ from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
25
+
26
+
27
+ class HTDemucs(nn.Module):
28
+ """
29
+ Spectrogram and hybrid Demucs model.
30
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
31
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
32
+ Frequency layers can still access information across time steps thanks to the DConv residual.
33
+
34
+ Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
35
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
36
+
37
+ Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
38
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
39
+ Open Unmix implementation [Stoter et al. 2019].
40
+
41
+ The loss is always on the temporal domain, by backpropagating through the above
42
+ output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
43
+ a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
44
+ contribution, without changing the one from the waveform, which will lead to worse performance.
45
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
46
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
47
+ hybrid models.
48
+
49
+ This model also uses frequency embeddings are used to improve efficiency on convolutions
50
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
51
+
52
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
53
+ """
54
+
55
+ @capture_init
56
+ def __init__(
57
+ self,
58
+ sources,
59
+ # Channels
60
+ audio_channels=2,
61
+ channels=48,
62
+ channels_time=None,
63
+ growth=2,
64
+ # STFT
65
+ nfft=4096,
66
+ wiener_iters=0,
67
+ end_iters=0,
68
+ wiener_residual=False,
69
+ cac=True,
70
+ # Main structure
71
+ depth=4,
72
+ rewrite=True,
73
+ # Frequency branch
74
+ multi_freqs=None,
75
+ multi_freqs_depth=3,
76
+ freq_emb=0.2,
77
+ emb_scale=10,
78
+ emb_smooth=True,
79
+ # Convolutions
80
+ kernel_size=8,
81
+ time_stride=2,
82
+ stride=4,
83
+ context=1,
84
+ context_enc=0,
85
+ # Normalization
86
+ norm_starts=4,
87
+ norm_groups=4,
88
+ # DConv residual branch
89
+ dconv_mode=1,
90
+ dconv_depth=2,
91
+ dconv_comp=8,
92
+ dconv_init=1e-3,
93
+ # Before the Transformer
94
+ bottom_channels=0,
95
+ # Transformer
96
+ t_layers=5,
97
+ t_emb="sin",
98
+ t_hidden_scale=4.0,
99
+ t_heads=8,
100
+ t_dropout=0.0,
101
+ t_max_positions=10000,
102
+ t_norm_in=True,
103
+ t_norm_in_group=False,
104
+ t_group_norm=False,
105
+ t_norm_first=True,
106
+ t_norm_out=True,
107
+ t_max_period=10000.0,
108
+ t_weight_decay=0.0,
109
+ t_lr=None,
110
+ t_layer_scale=True,
111
+ t_gelu=True,
112
+ t_weight_pos_embed=1.0,
113
+ t_sin_random_shift=0,
114
+ t_cape_mean_normalize=True,
115
+ t_cape_augment=True,
116
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
117
+ t_sparse_self_attn=False,
118
+ t_sparse_cross_attn=False,
119
+ t_mask_type="diag",
120
+ t_mask_random_seed=42,
121
+ t_sparse_attn_window=500,
122
+ t_global_window=100,
123
+ t_sparsity=0.95,
124
+ t_auto_sparsity=False,
125
+ # ------ Particuliar parameters
126
+ t_cross_first=False,
127
+ # Weight init
128
+ rescale=0.1,
129
+ # Metadata
130
+ samplerate=44100,
131
+ segment=10,
132
+ use_train_segment=True,
133
+ ):
134
+ """
135
+ Args:
136
+ sources (list[str]): list of source names.
137
+ audio_channels (int): input/output audio channels.
138
+ channels (int): initial number of hidden channels.
139
+ channels_time: if not None, use a different `channels` value for the time branch.
140
+ growth: increase the number of hidden channels by this factor at each layer.
141
+ nfft: number of fft bins. Note that changing this require careful computation of
142
+ various shape parameters and will not work out of the box for hybrid models.
143
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
144
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
145
+ wiener_residual: add residual source before wiener filtering.
146
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
147
+ in input and output. no further processing is done before ISTFT.
148
+ depth (int): number of layers in the encoder and in the decoder.
149
+ rewrite (bool): add 1x1 convolution to each layer.
150
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
151
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
152
+ layers will be wrapped.
153
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
154
+ the actual value controls the weight of the embedding.
155
+ emb_scale: equivalent to scaling the embedding learning rate
156
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
157
+ kernel_size: kernel_size for encoder and decoder layers.
158
+ stride: stride for encoder and decoder layers.
159
+ time_stride: stride for the final time layer, after the merge.
160
+ context: context for 1x1 conv in the decoder.
161
+ context_enc: context for 1x1 conv in the encoder.
162
+ norm_starts: layer at which group norm starts being used.
163
+ decoder layers are numbered in reverse order.
164
+ norm_groups: number of groups for group norm.
165
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
166
+ dconv_depth: depth of residual DConv branch.
167
+ dconv_comp: compression of DConv branch.
168
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
169
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
170
+ dconv_init: initial scale for the DConv branch LayerScale.
171
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
172
+ transformer in order to change the number of channels
173
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
174
+ t_emb: "sin", "cape" or "scaled"
175
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
176
+ for instance if C = 384 (the number of channels in the transformer) and
177
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
178
+ 384 * 4 = 1536
179
+ t_heads: number of heads for the transformer
180
+ t_dropout: dropout in the transformer
181
+ t_max_positions: max_positions for the "scaled" positional embedding, only
182
+ useful if t_emb="scaled"
183
+ t_norm_in: (bool) norm before addinf positional embedding and getting into the
184
+ transformer layers
185
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
186
+ timesteps (GroupNorm with group=1)
187
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
188
+ timesteps (GroupNorm with group=1)
189
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
190
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
191
+ t_max_period: (float) denominator in the sinusoidal embedding expression
192
+ t_weight_decay: (float) weight decay for the transformer
193
+ t_lr: (float) specific learning rate for the transformer
194
+ t_layer_scale: (bool) Layer Scale for the transformer
195
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
196
+ t_weight_pos_embed: (float) weighting of the positional embedding
197
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
198
+ see: https://arxiv.org/abs/2106.03143
199
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
200
+ during the inference, see: https://arxiv.org/abs/2106.03143
201
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
202
+ see: https://arxiv.org/abs/2106.03143
203
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
204
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
205
+ unless you designed really specific masks)
206
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
207
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
208
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
209
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
210
+ that generated the random part of the mask
211
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
212
+ a key (j), the mask is True id |i-j|<=t_sparse_attn_window
213
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
214
+ and mask[:, :t_global_window] will be True
215
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
216
+ level of the random part of the mask.
217
+ t_cross_first: (bool) if True cross attention is the first layer of the
218
+ transformer (False seems to be better)
219
+ rescale: weight rescaling trick
220
+ use_train_segment: (bool) if True, the actual size that is used during the
221
+ training is used during inference.
222
+ """
223
+ super().__init__()
224
+ self.cac = cac
225
+ self.wiener_residual = wiener_residual
226
+ self.audio_channels = audio_channels
227
+ self.sources = sources
228
+ self.kernel_size = kernel_size
229
+ self.context = context
230
+ self.stride = stride
231
+ self.depth = depth
232
+ self.bottom_channels = bottom_channels
233
+ self.channels = channels
234
+ self.samplerate = samplerate
235
+ self.segment = segment
236
+ self.use_train_segment = use_train_segment
237
+ self.nfft = nfft
238
+ self.hop_length = nfft // 4
239
+ self.wiener_iters = wiener_iters
240
+ self.end_iters = end_iters
241
+ self.freq_emb = None
242
+ assert wiener_iters == end_iters
243
+
244
+ self.encoder = nn.ModuleList()
245
+ self.decoder = nn.ModuleList()
246
+
247
+ self.tencoder = nn.ModuleList()
248
+ self.tdecoder = nn.ModuleList()
249
+
250
+ chin = audio_channels
251
+ chin_z = chin # number of channels for the freq branch
252
+ if self.cac:
253
+ chin_z *= 2
254
+ chout = channels_time or channels
255
+ chout_z = channels
256
+ freqs = nfft // 2
257
+
258
+ for index in range(depth):
259
+ norm = index >= norm_starts
260
+ freq = freqs > 1
261
+ stri = stride
262
+ ker = kernel_size
263
+ if not freq:
264
+ assert freqs == 1
265
+ ker = time_stride * 2
266
+ stri = time_stride
267
+
268
+ pad = True
269
+ last_freq = False
270
+ if freq and freqs <= kernel_size:
271
+ ker = freqs
272
+ pad = False
273
+ last_freq = True
274
+
275
+ kw = {
276
+ "kernel_size": ker,
277
+ "stride": stri,
278
+ "freq": freq,
279
+ "pad": pad,
280
+ "norm": norm,
281
+ "rewrite": rewrite,
282
+ "norm_groups": norm_groups,
283
+ "dconv_kw": {
284
+ "depth": dconv_depth,
285
+ "compress": dconv_comp,
286
+ "init": dconv_init,
287
+ "gelu": True,
288
+ },
289
+ }
290
+ kwt = dict(kw)
291
+ kwt["freq"] = 0
292
+ kwt["kernel_size"] = kernel_size
293
+ kwt["stride"] = stride
294
+ kwt["pad"] = True
295
+ kw_dec = dict(kw)
296
+ multi = False
297
+ if multi_freqs and index < multi_freqs_depth:
298
+ multi = True
299
+ kw_dec["context_freq"] = False
300
+
301
+ if last_freq:
302
+ chout_z = max(chout, chout_z)
303
+ chout = chout_z
304
+
305
+ enc = HEncLayer(
306
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
307
+ )
308
+ if freq:
309
+ tenc = HEncLayer(
310
+ chin,
311
+ chout,
312
+ dconv=dconv_mode & 1,
313
+ context=context_enc,
314
+ empty=last_freq,
315
+ **kwt
316
+ )
317
+ self.tencoder.append(tenc)
318
+
319
+ if multi:
320
+ enc = MultiWrap(enc, multi_freqs)
321
+ self.encoder.append(enc)
322
+ if index == 0:
323
+ chin = self.audio_channels * len(self.sources)
324
+ chin_z = chin
325
+ if self.cac:
326
+ chin_z *= 2
327
+ dec = HDecLayer(
328
+ chout_z,
329
+ chin_z,
330
+ dconv=dconv_mode & 2,
331
+ last=index == 0,
332
+ context=context,
333
+ **kw_dec
334
+ )
335
+ if multi:
336
+ dec = MultiWrap(dec, multi_freqs)
337
+ if freq:
338
+ tdec = HDecLayer(
339
+ chout,
340
+ chin,
341
+ dconv=dconv_mode & 2,
342
+ empty=last_freq,
343
+ last=index == 0,
344
+ context=context,
345
+ **kwt
346
+ )
347
+ self.tdecoder.insert(0, tdec)
348
+ self.decoder.insert(0, dec)
349
+
350
+ chin = chout
351
+ chin_z = chout_z
352
+ chout = int(growth * chout)
353
+ chout_z = int(growth * chout_z)
354
+ if freq:
355
+ if freqs <= kernel_size:
356
+ freqs = 1
357
+ else:
358
+ freqs //= stride
359
+ if index == 0 and freq_emb:
360
+ self.freq_emb = ScaledEmbedding(
361
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
362
+ )
363
+ self.freq_emb_scale = freq_emb
364
+
365
+ if rescale:
366
+ rescale_module(self, reference=rescale)
367
+
368
+ transformer_channels = channels * growth ** (depth - 1)
369
+ if bottom_channels:
370
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
371
+ self.channel_downsampler = nn.Conv1d(
372
+ bottom_channels, transformer_channels, 1
373
+ )
374
+ self.channel_upsampler_t = nn.Conv1d(
375
+ transformer_channels, bottom_channels, 1
376
+ )
377
+ self.channel_downsampler_t = nn.Conv1d(
378
+ bottom_channels, transformer_channels, 1
379
+ )
380
+
381
+ transformer_channels = bottom_channels
382
+
383
+ if t_layers > 0:
384
+ self.crosstransformer = CrossTransformerEncoder(
385
+ dim=transformer_channels,
386
+ emb=t_emb,
387
+ hidden_scale=t_hidden_scale,
388
+ num_heads=t_heads,
389
+ num_layers=t_layers,
390
+ cross_first=t_cross_first,
391
+ dropout=t_dropout,
392
+ max_positions=t_max_positions,
393
+ norm_in=t_norm_in,
394
+ norm_in_group=t_norm_in_group,
395
+ group_norm=t_group_norm,
396
+ norm_first=t_norm_first,
397
+ norm_out=t_norm_out,
398
+ max_period=t_max_period,
399
+ weight_decay=t_weight_decay,
400
+ lr=t_lr,
401
+ layer_scale=t_layer_scale,
402
+ gelu=t_gelu,
403
+ sin_random_shift=t_sin_random_shift,
404
+ weight_pos_embed=t_weight_pos_embed,
405
+ cape_mean_normalize=t_cape_mean_normalize,
406
+ cape_augment=t_cape_augment,
407
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
408
+ sparse_self_attn=t_sparse_self_attn,
409
+ sparse_cross_attn=t_sparse_cross_attn,
410
+ mask_type=t_mask_type,
411
+ mask_random_seed=t_mask_random_seed,
412
+ sparse_attn_window=t_sparse_attn_window,
413
+ global_window=t_global_window,
414
+ sparsity=t_sparsity,
415
+ auto_sparsity=t_auto_sparsity,
416
+ )
417
+ else:
418
+ self.crosstransformer = None
419
+
420
+ def _spec(self, x):
421
+ hl = self.hop_length
422
+ nfft = self.nfft
423
+ x0 = x # noqa
424
+
425
+ # We re-pad the signal in order to keep the property
426
+ # that the size of the output is exactly the size of the input
427
+ # divided by the stride (here hop_length), when divisible.
428
+ # This is achieved by padding by 1/4th of the kernel size (here nfft).
429
+ # which is not supported by torch.stft.
430
+ # Having all convolution operations follow this convention allow to easily
431
+ # align the time and frequency branches later on.
432
+ assert hl == nfft // 4
433
+ le = int(math.ceil(x.shape[-1] / hl))
434
+ pad = hl // 2 * 3
435
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
436
+
437
+ z = spectro(x, nfft, hl)[..., :-1, :]
438
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
439
+ z = z[..., 2: 2 + le]
440
+ return z
441
+
442
+ def _ispec(self, z, length=None, scale=0):
443
+ hl = self.hop_length // (4**scale)
444
+ z = F.pad(z, (0, 0, 0, 1))
445
+ z = F.pad(z, (2, 2))
446
+ pad = hl // 2 * 3
447
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
448
+ x = ispectro(z, hl, length=le)
449
+ x = x[..., pad: pad + length]
450
+ return x
451
+
452
+ def _magnitude(self, z):
453
+ # return the magnitude of the spectrogram, except when cac is True,
454
+ # in which case we just move the complex dimension to the channel one.
455
+ if self.cac:
456
+ B, C, Fr, T = z.shape
457
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
458
+ m = m.reshape(B, C * 2, Fr, T)
459
+ else:
460
+ m = z.abs()
461
+ return m
462
+
463
+ def _mask(self, z, m):
464
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
465
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
466
+ niters = self.wiener_iters
467
+ if self.cac:
468
+ B, S, C, Fr, T = m.shape
469
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
470
+ out = torch.view_as_complex(out.contiguous())
471
+ return out
472
+ if self.training:
473
+ niters = self.end_iters
474
+ if niters < 0:
475
+ z = z[:, None]
476
+ return z / (1e-8 + z.abs()) * m
477
+ else:
478
+ return self._wiener(m, z, niters)
479
+
480
+ def _wiener(self, mag_out, mix_stft, niters):
481
+ # apply wiener filtering from OpenUnmix.
482
+ init = mix_stft.dtype
483
+ wiener_win_len = 300
484
+ residual = self.wiener_residual
485
+
486
+ B, S, C, Fq, T = mag_out.shape
487
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
488
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
489
+
490
+ outs = []
491
+ for sample in range(B):
492
+ pos = 0
493
+ out = []
494
+ for pos in range(0, T, wiener_win_len):
495
+ frame = slice(pos, pos + wiener_win_len)
496
+ z_out = wiener(
497
+ mag_out[sample, frame],
498
+ mix_stft[sample, frame],
499
+ niters,
500
+ residual=residual,
501
+ )
502
+ out.append(z_out.transpose(-1, -2))
503
+ outs.append(torch.cat(out, dim=0))
504
+ out = torch.view_as_complex(torch.stack(outs, 0))
505
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
506
+ if residual:
507
+ out = out[:, :-1]
508
+ assert list(out.shape) == [B, S, C, Fq, T]
509
+ return out.to(init)
510
+
511
+ def valid_length(self, length: int):
512
+ """
513
+ Return a length that is appropriate for evaluation.
514
+ In our case, always return the training length, unless
515
+ it is smaller than the given length, in which case this
516
+ raises an error.
517
+ """
518
+ if not self.use_train_segment:
519
+ return length
520
+ training_length = int(self.segment * self.samplerate)
521
+ if training_length < length:
522
+ raise ValueError(
523
+ f"Given length {length} is longer than "
524
+ f"training length {training_length}")
525
+ return training_length
526
+
527
+ def forward(self, mix):
528
+ length = mix.shape[-1]
529
+ length_pre_pad = None
530
+ if self.use_train_segment:
531
+ if self.training:
532
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
533
+ else:
534
+ training_length = int(self.segment * self.samplerate)
535
+ if mix.shape[-1] < training_length:
536
+ length_pre_pad = mix.shape[-1]
537
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
538
+ z = self._spec(mix)
539
+ mag = self._magnitude(z)
540
+ x = mag
541
+
542
+ B, C, Fq, T = x.shape
543
+
544
+ # unlike previous Demucs, we always normalize because it is easier.
545
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
546
+ std = x.std(dim=(1, 2, 3), keepdim=True)
547
+ x = (x - mean) / (1e-5 + std)
548
+ # x will be the freq. branch input.
549
+
550
+ # Prepare the time branch input.
551
+ xt = mix
552
+ meant = xt.mean(dim=(1, 2), keepdim=True)
553
+ stdt = xt.std(dim=(1, 2), keepdim=True)
554
+ xt = (xt - meant) / (1e-5 + stdt)
555
+
556
+ # okay, this is a giant mess I know...
557
+ saved = [] # skip connections, freq.
558
+ saved_t = [] # skip connections, time.
559
+ lengths = [] # saved lengths to properly remove padding, freq branch.
560
+ lengths_t = [] # saved lengths for time branch.
561
+ for idx, encode in enumerate(self.encoder):
562
+ lengths.append(x.shape[-1])
563
+ inject = None
564
+ if idx < len(self.tencoder):
565
+ # we have not yet merged branches.
566
+ lengths_t.append(xt.shape[-1])
567
+ tenc = self.tencoder[idx]
568
+ xt = tenc(xt)
569
+ if not tenc.empty:
570
+ # save for skip connection
571
+ saved_t.append(xt)
572
+ else:
573
+ # tenc contains just the first conv., so that now time and freq.
574
+ # branches have the same shape and can be merged.
575
+ inject = xt
576
+ x = encode(x, inject)
577
+ if idx == 0 and self.freq_emb is not None:
578
+ # add frequency embedding to allow for non equivariant convolutions
579
+ # over the frequency axis.
580
+ frs = torch.arange(x.shape[-2], device=x.device)
581
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
582
+ x = x + self.freq_emb_scale * emb
583
+
584
+ saved.append(x)
585
+ if self.crosstransformer:
586
+ if self.bottom_channels:
587
+ b, c, f, t = x.shape
588
+ x = rearrange(x, "b c f t-> b c (f t)")
589
+ x = self.channel_upsampler(x)
590
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
591
+ xt = self.channel_upsampler_t(xt)
592
+
593
+ x, xt = self.crosstransformer(x, xt)
594
+
595
+ if self.bottom_channels:
596
+ x = rearrange(x, "b c f t-> b c (f t)")
597
+ x = self.channel_downsampler(x)
598
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
599
+ xt = self.channel_downsampler_t(xt)
600
+
601
+ for idx, decode in enumerate(self.decoder):
602
+ skip = saved.pop(-1)
603
+ x, pre = decode(x, skip, lengths.pop(-1))
604
+ # `pre` contains the output just before final transposed convolution,
605
+ # which is used when the freq. and time branch separate.
606
+
607
+ offset = self.depth - len(self.tdecoder)
608
+ if idx >= offset:
609
+ tdec = self.tdecoder[idx - offset]
610
+ length_t = lengths_t.pop(-1)
611
+ if tdec.empty:
612
+ assert pre.shape[2] == 1, pre.shape
613
+ pre = pre[:, :, 0]
614
+ xt, _ = tdec(pre, None, length_t)
615
+ else:
616
+ skip = saved_t.pop(-1)
617
+ xt, _ = tdec(xt, skip, length_t)
618
+
619
+ # Let's make sure we used all stored skip connections.
620
+ assert len(saved) == 0
621
+ assert len(lengths_t) == 0
622
+ assert len(saved_t) == 0
623
+
624
+ S = len(self.sources)
625
+ x = x.view(B, S, -1, Fq, T)
626
+ x = x * std[:, None] + mean[:, None]
627
+
628
+ zout = self._mask(z, x)
629
+ if self.use_train_segment:
630
+ if self.training:
631
+ x = self._ispec(zout, length)
632
+ else:
633
+ x = self._ispec(zout, training_length)
634
+ else:
635
+ x = self._ispec(zout, length)
636
+
637
+ if self.use_train_segment:
638
+ if self.training:
639
+ xt = xt.view(B, S, -1, length)
640
+ else:
641
+ xt = xt.view(B, S, -1, training_length)
642
+ else:
643
+ xt = xt.view(B, S, -1, length)
644
+ xt = xt * stdt[:, None] + meant[:, None]
645
+ x = xt + x
646
+ if length_pre_pad:
647
+ x = x[..., :length_pre_pad]
648
+ return x
demucs/model.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch as th
10
+ from torch import nn
11
+
12
+ from .utils import capture_init, center_trim
13
+
14
+
15
+ class BLSTM(nn.Module):
16
+ def __init__(self, dim, layers=1):
17
+ super().__init__()
18
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
19
+ self.linear = nn.Linear(2 * dim, dim)
20
+
21
+ def forward(self, x):
22
+ x = x.permute(2, 0, 1)
23
+ x = self.lstm(x)[0]
24
+ x = self.linear(x)
25
+ x = x.permute(1, 2, 0)
26
+ return x
27
+
28
+
29
+ def rescale_conv(conv, reference):
30
+ std = conv.weight.std().detach()
31
+ scale = (std / reference)**0.5
32
+ conv.weight.data /= scale
33
+ if conv.bias is not None:
34
+ conv.bias.data /= scale
35
+
36
+
37
+ def rescale_module(module, reference):
38
+ for sub in module.modules():
39
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
40
+ rescale_conv(sub, reference)
41
+
42
+
43
+ def upsample(x, stride):
44
+ """
45
+ Linear upsampling, the output will be `stride` times longer.
46
+ """
47
+ batch, channels, time = x.size()
48
+ weight = th.arange(stride, device=x.device, dtype=th.float) / stride
49
+ x = x.view(batch, channels, time, 1)
50
+ out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight
51
+ return out.reshape(batch, channels, -1)
52
+
53
+
54
+ def downsample(x, stride):
55
+ """
56
+ Downsample x by decimation.
57
+ """
58
+ return x[:, :, ::stride]
59
+
60
+
61
+ class Demucs(nn.Module):
62
+ @capture_init
63
+ def __init__(self,
64
+ sources=4,
65
+ audio_channels=2,
66
+ channels=64,
67
+ depth=6,
68
+ rewrite=True,
69
+ glu=True,
70
+ upsample=False,
71
+ rescale=0.1,
72
+ kernel_size=8,
73
+ stride=4,
74
+ growth=2.,
75
+ lstm_layers=2,
76
+ context=3,
77
+ samplerate=44100):
78
+ """
79
+ Args:
80
+ sources (int): number of sources to separate
81
+ audio_channels (int): stereo or mono
82
+ channels (int): first convolution channels
83
+ depth (int): number of encoder/decoder layers
84
+ rewrite (bool): add 1x1 convolution to each encoder layer
85
+ and a convolution to each decoder layer.
86
+ For the decoder layer, `context` gives the kernel size.
87
+ glu (bool): use glu instead of ReLU
88
+ upsample (bool): use linear upsampling with convolutions
89
+ Wave-U-Net style, instead of transposed convolutions
90
+ rescale (int): rescale initial weights of convolutions
91
+ to get their standard deviation closer to `rescale`
92
+ kernel_size (int): kernel size for convolutions
93
+ stride (int): stride for convolutions
94
+ growth (float): multiply (resp divide) number of channels by that
95
+ for each layer of the encoder (resp decoder)
96
+ lstm_layers (int): number of lstm layers, 0 = no lstm
97
+ context (int): kernel size of the convolution in the
98
+ decoder before the transposed convolution. If > 1,
99
+ will provide some context from neighboring time
100
+ steps.
101
+ """
102
+
103
+ super().__init__()
104
+ self.audio_channels = audio_channels
105
+ self.sources = sources
106
+ self.kernel_size = kernel_size
107
+ self.context = context
108
+ self.stride = stride
109
+ self.depth = depth
110
+ self.upsample = upsample
111
+ self.channels = channels
112
+ self.samplerate = samplerate
113
+
114
+ self.encoder = nn.ModuleList()
115
+ self.decoder = nn.ModuleList()
116
+
117
+ self.final = None
118
+ if upsample:
119
+ self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
120
+ stride = 1
121
+
122
+ if glu:
123
+ activation = nn.GLU(dim=1)
124
+ ch_scale = 2
125
+ else:
126
+ activation = nn.ReLU()
127
+ ch_scale = 1
128
+ in_channels = audio_channels
129
+ for index in range(depth):
130
+ encode = []
131
+ encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
132
+ if rewrite:
133
+ encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
134
+ self.encoder.append(nn.Sequential(*encode))
135
+
136
+ decode = []
137
+ if index > 0:
138
+ out_channels = in_channels
139
+ else:
140
+ if upsample:
141
+ out_channels = channels
142
+ else:
143
+ out_channels = sources * audio_channels
144
+ if rewrite:
145
+ decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
146
+ if upsample:
147
+ decode += [
148
+ nn.Conv1d(channels, out_channels, kernel_size, stride=1),
149
+ ]
150
+ else:
151
+ decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
152
+ if index > 0:
153
+ decode.append(nn.ReLU())
154
+ self.decoder.insert(0, nn.Sequential(*decode))
155
+ in_channels = channels
156
+ channels = int(growth * channels)
157
+
158
+ channels = in_channels
159
+
160
+ if lstm_layers:
161
+ self.lstm = BLSTM(channels, lstm_layers)
162
+ else:
163
+ self.lstm = None
164
+
165
+ if rescale:
166
+ rescale_module(self, reference=rescale)
167
+
168
+ def valid_length(self, length):
169
+ """
170
+ Return the nearest valid length to use with the model so that
171
+ there is no time steps left over in a convolutions, e.g. for all
172
+ layers, size of the input - kernel_size % stride = 0.
173
+
174
+ If the mixture has a valid length, the estimated sources
175
+ will have exactly the same length when context = 1. If context > 1,
176
+ the two signals can be center trimmed to match.
177
+
178
+ For training, extracts should have a valid length.For evaluation
179
+ on full tracks we recommend passing `pad = True` to :method:`forward`.
180
+ """
181
+ for _ in range(self.depth):
182
+ if self.upsample:
183
+ length = math.ceil(length / self.stride) + self.kernel_size - 1
184
+ else:
185
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
186
+ length = max(1, length)
187
+ length += self.context - 1
188
+ for _ in range(self.depth):
189
+ if self.upsample:
190
+ length = length * self.stride + self.kernel_size - 1
191
+ else:
192
+ length = (length - 1) * self.stride + self.kernel_size
193
+
194
+ return int(length)
195
+
196
+ def forward(self, mix):
197
+ x = mix
198
+ saved = [x]
199
+ for encode in self.encoder:
200
+ x = encode(x)
201
+ saved.append(x)
202
+ if self.upsample:
203
+ x = downsample(x, self.stride)
204
+ if self.lstm:
205
+ x = self.lstm(x)
206
+ for decode in self.decoder:
207
+ if self.upsample:
208
+ x = upsample(x, stride=self.stride)
209
+ skip = center_trim(saved.pop(-1), x)
210
+ x = x + skip
211
+ x = decode(x)
212
+ if self.final:
213
+ skip = center_trim(saved.pop(-1), x)
214
+ x = th.cat([x, skip], dim=1)
215
+ x = self.final(x)
216
+
217
+ x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))
218
+ return x
demucs/model_v2.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import julius
10
+ from torch import nn
11
+ from .tasnet_v2 import ConvTasNet
12
+
13
+ from .utils import capture_init, center_trim
14
+
15
+
16
+ class BLSTM(nn.Module):
17
+ def __init__(self, dim, layers=1):
18
+ super().__init__()
19
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
20
+ self.linear = nn.Linear(2 * dim, dim)
21
+
22
+ def forward(self, x):
23
+ x = x.permute(2, 0, 1)
24
+ x = self.lstm(x)[0]
25
+ x = self.linear(x)
26
+ x = x.permute(1, 2, 0)
27
+ return x
28
+
29
+
30
+ def rescale_conv(conv, reference):
31
+ std = conv.weight.std().detach()
32
+ scale = (std / reference)**0.5
33
+ conv.weight.data /= scale
34
+ if conv.bias is not None:
35
+ conv.bias.data /= scale
36
+
37
+
38
+ def rescale_module(module, reference):
39
+ for sub in module.modules():
40
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
41
+ rescale_conv(sub, reference)
42
+
43
+ def auto_load_demucs_model_v2(sources, demucs_model_name):
44
+
45
+ if '48' in demucs_model_name:
46
+ channels=48
47
+ elif 'unittest' in demucs_model_name:
48
+ channels=4
49
+ else:
50
+ channels=64
51
+
52
+ if 'tasnet' in demucs_model_name:
53
+ init_demucs_model = ConvTasNet(sources, X=10)
54
+ else:
55
+ init_demucs_model = Demucs(sources, channels=channels)
56
+
57
+ return init_demucs_model
58
+
59
+ class Demucs(nn.Module):
60
+ @capture_init
61
+ def __init__(self,
62
+ sources,
63
+ audio_channels=2,
64
+ channels=64,
65
+ depth=6,
66
+ rewrite=True,
67
+ glu=True,
68
+ rescale=0.1,
69
+ resample=True,
70
+ kernel_size=8,
71
+ stride=4,
72
+ growth=2.,
73
+ lstm_layers=2,
74
+ context=3,
75
+ normalize=False,
76
+ samplerate=44100,
77
+ segment_length=4 * 10 * 44100):
78
+ """
79
+ Args:
80
+ sources (list[str]): list of source names
81
+ audio_channels (int): stereo or mono
82
+ channels (int): first convolution channels
83
+ depth (int): number of encoder/decoder layers
84
+ rewrite (bool): add 1x1 convolution to each encoder layer
85
+ and a convolution to each decoder layer.
86
+ For the decoder layer, `context` gives the kernel size.
87
+ glu (bool): use glu instead of ReLU
88
+ resample_input (bool): upsample x2 the input and downsample /2 the output.
89
+ rescale (int): rescale initial weights of convolutions
90
+ to get their standard deviation closer to `rescale`
91
+ kernel_size (int): kernel size for convolutions
92
+ stride (int): stride for convolutions
93
+ growth (float): multiply (resp divide) number of channels by that
94
+ for each layer of the encoder (resp decoder)
95
+ lstm_layers (int): number of lstm layers, 0 = no lstm
96
+ context (int): kernel size of the convolution in the
97
+ decoder before the transposed convolution. If > 1,
98
+ will provide some context from neighboring time
99
+ steps.
100
+ samplerate (int): stored as meta information for easing
101
+ future evaluations of the model.
102
+ segment_length (int): stored as meta information for easing
103
+ future evaluations of the model. Length of the segments on which
104
+ the model was trained.
105
+ """
106
+
107
+ super().__init__()
108
+ self.audio_channels = audio_channels
109
+ self.sources = sources
110
+ self.kernel_size = kernel_size
111
+ self.context = context
112
+ self.stride = stride
113
+ self.depth = depth
114
+ self.resample = resample
115
+ self.channels = channels
116
+ self.normalize = normalize
117
+ self.samplerate = samplerate
118
+ self.segment_length = segment_length
119
+
120
+ self.encoder = nn.ModuleList()
121
+ self.decoder = nn.ModuleList()
122
+
123
+ if glu:
124
+ activation = nn.GLU(dim=1)
125
+ ch_scale = 2
126
+ else:
127
+ activation = nn.ReLU()
128
+ ch_scale = 1
129
+ in_channels = audio_channels
130
+ for index in range(depth):
131
+ encode = []
132
+ encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
133
+ if rewrite:
134
+ encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
135
+ self.encoder.append(nn.Sequential(*encode))
136
+
137
+ decode = []
138
+ if index > 0:
139
+ out_channels = in_channels
140
+ else:
141
+ out_channels = len(self.sources) * audio_channels
142
+ if rewrite:
143
+ decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
144
+ decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
145
+ if index > 0:
146
+ decode.append(nn.ReLU())
147
+ self.decoder.insert(0, nn.Sequential(*decode))
148
+ in_channels = channels
149
+ channels = int(growth * channels)
150
+
151
+ channels = in_channels
152
+
153
+ if lstm_layers:
154
+ self.lstm = BLSTM(channels, lstm_layers)
155
+ else:
156
+ self.lstm = None
157
+
158
+ if rescale:
159
+ rescale_module(self, reference=rescale)
160
+
161
+ def valid_length(self, length):
162
+ """
163
+ Return the nearest valid length to use with the model so that
164
+ there is no time steps left over in a convolutions, e.g. for all
165
+ layers, size of the input - kernel_size % stride = 0.
166
+
167
+ If the mixture has a valid length, the estimated sources
168
+ will have exactly the same length when context = 1. If context > 1,
169
+ the two signals can be center trimmed to match.
170
+
171
+ For training, extracts should have a valid length.For evaluation
172
+ on full tracks we recommend passing `pad = True` to :method:`forward`.
173
+ """
174
+ if self.resample:
175
+ length *= 2
176
+ for _ in range(self.depth):
177
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
178
+ length = max(1, length)
179
+ length += self.context - 1
180
+ for _ in range(self.depth):
181
+ length = (length - 1) * self.stride + self.kernel_size
182
+
183
+ if self.resample:
184
+ length = math.ceil(length / 2)
185
+ return int(length)
186
+
187
+ def forward(self, mix):
188
+ x = mix
189
+
190
+ if self.normalize:
191
+ mono = mix.mean(dim=1, keepdim=True)
192
+ mean = mono.mean(dim=-1, keepdim=True)
193
+ std = mono.std(dim=-1, keepdim=True)
194
+ else:
195
+ mean = 0
196
+ std = 1
197
+
198
+ x = (x - mean) / (1e-5 + std)
199
+
200
+ if self.resample:
201
+ x = julius.resample_frac(x, 1, 2)
202
+
203
+ saved = []
204
+ for encode in self.encoder:
205
+ x = encode(x)
206
+ saved.append(x)
207
+ if self.lstm:
208
+ x = self.lstm(x)
209
+ for decode in self.decoder:
210
+ skip = center_trim(saved.pop(-1), x)
211
+ x = x + skip
212
+ x = decode(x)
213
+
214
+ if self.resample:
215
+ x = julius.resample_frac(x, 2, 1)
216
+ x = x * std + mean
217
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
218
+ return x
demucs/pretrained.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Loading pretrained models.
7
+ """
8
+
9
+ import logging
10
+ from pathlib import Path
11
+ import typing as tp
12
+
13
+ #from dora.log import fatal
14
+
15
+ import logging
16
+
17
+ from diffq import DiffQuantizer
18
+ import torch.hub
19
+
20
+ from .model import Demucs
21
+ from .tasnet_v2 import ConvTasNet
22
+ from .utils import set_state
23
+
24
+ from .hdemucs import HDemucs
25
+ from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
26
+
27
+ logger = logging.getLogger(__name__)
28
+ ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
29
+ REMOTE_ROOT = Path(__file__).parent / 'remote'
30
+
31
+ SOURCES = ["drums", "bass", "other", "vocals"]
32
+
33
+
34
+ def demucs_unittest():
35
+ model = HDemucs(channels=4, sources=SOURCES)
36
+ return model
37
+
38
+
39
+ def add_model_flags(parser):
40
+ group = parser.add_mutually_exclusive_group(required=False)
41
+ group.add_argument("-s", "--sig", help="Locally trained XP signature.")
42
+ group.add_argument("-n", "--name", default="mdx_extra_q",
43
+ help="Pretrained model name or signature. Default is mdx_extra_q.")
44
+ parser.add_argument("--repo", type=Path,
45
+ help="Folder containing all pre-trained models for use with -n.")
46
+
47
+
48
+ def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
49
+ root: str = ''
50
+ models: tp.Dict[str, str] = {}
51
+ for line in remote_file_list.read_text().split('\n'):
52
+ line = line.strip()
53
+ if line.startswith('#'):
54
+ continue
55
+ elif line.startswith('root:'):
56
+ root = line.split(':', 1)[1].strip()
57
+ else:
58
+ sig = line.split('-', 1)[0]
59
+ assert sig not in models
60
+ models[sig] = ROOT_URL + root + line
61
+ return models
62
+
63
+ def get_model(name: str,
64
+ repo: tp.Optional[Path] = None):
65
+ """`name` must be a bag of models name or a pretrained signature
66
+ from the remote AWS model repo or the specified local repo if `repo` is not None.
67
+ """
68
+ if name == 'demucs_unittest':
69
+ return demucs_unittest()
70
+ model_repo: ModelOnlyRepo
71
+ if repo is None:
72
+ models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
73
+ model_repo = RemoteRepo(models)
74
+ bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
75
+ else:
76
+ if not repo.is_dir():
77
+ fatal(f"{repo} must exist and be a directory.")
78
+ model_repo = LocalRepo(repo)
79
+ bag_repo = BagOnlyRepo(repo, model_repo)
80
+ any_repo = AnyModelRepo(model_repo, bag_repo)
81
+ model = any_repo.get_model(name)
82
+ model.eval()
83
+ return model
84
+
85
+ def get_model_from_args(args):
86
+ """
87
+ Load local model package or pre-trained model.
88
+ """
89
+ return get_model(name=args.name, repo=args.repo)
90
+
91
+ logger = logging.getLogger(__name__)
92
+ ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
93
+
94
+ PRETRAINED_MODELS = {
95
+ 'demucs': 'e07c671f',
96
+ 'demucs48_hq': '28a1282c',
97
+ 'demucs_extra': '3646af93',
98
+ 'demucs_quantized': '07afea75',
99
+ 'tasnet': 'beb46fac',
100
+ 'tasnet_extra': 'df3777b2',
101
+ 'demucs_unittest': '09ebc15f',
102
+ }
103
+
104
+ SOURCES = ["drums", "bass", "other", "vocals"]
105
+
106
+
107
+ def get_url(name):
108
+ sig = PRETRAINED_MODELS[name]
109
+ return ROOT + name + "-" + sig[:8] + ".th"
110
+
111
+ def is_pretrained(name):
112
+ return name in PRETRAINED_MODELS
113
+
114
+
115
+ def load_pretrained(name):
116
+ if name == "demucs":
117
+ return demucs(pretrained=True)
118
+ elif name == "demucs48_hq":
119
+ return demucs(pretrained=True, hq=True, channels=48)
120
+ elif name == "demucs_extra":
121
+ return demucs(pretrained=True, extra=True)
122
+ elif name == "demucs_quantized":
123
+ return demucs(pretrained=True, quantized=True)
124
+ elif name == "demucs_unittest":
125
+ return demucs_unittest(pretrained=True)
126
+ elif name == "tasnet":
127
+ return tasnet(pretrained=True)
128
+ elif name == "tasnet_extra":
129
+ return tasnet(pretrained=True, extra=True)
130
+ else:
131
+ raise ValueError(f"Invalid pretrained name {name}")
132
+
133
+
134
+ def _load_state(name, model, quantizer=None):
135
+ url = get_url(name)
136
+ state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
137
+ set_state(model, quantizer, state)
138
+ if quantizer:
139
+ quantizer.detach()
140
+
141
+
142
+ def demucs_unittest(pretrained=True):
143
+ model = Demucs(channels=4, sources=SOURCES)
144
+ if pretrained:
145
+ _load_state('demucs_unittest', model)
146
+ return model
147
+
148
+
149
+ def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
150
+ if not pretrained and (extra or quantized or hq):
151
+ raise ValueError("if extra or quantized is True, pretrained must be True.")
152
+ model = Demucs(sources=SOURCES, channels=channels)
153
+ if pretrained:
154
+ name = 'demucs'
155
+ if channels != 64:
156
+ name += str(channels)
157
+ quantizer = None
158
+ if sum([extra, quantized, hq]) > 1:
159
+ raise ValueError("Only one of extra, quantized, hq, can be True.")
160
+ if quantized:
161
+ quantizer = DiffQuantizer(model, group_size=8, min_size=1)
162
+ name += '_quantized'
163
+ if extra:
164
+ name += '_extra'
165
+ if hq:
166
+ name += '_hq'
167
+ _load_state(name, model, quantizer)
168
+ return model
169
+
170
+
171
+ def tasnet(pretrained=True, extra=False):
172
+ if not pretrained and extra:
173
+ raise ValueError("if extra is True, pretrained must be True.")
174
+ model = ConvTasNet(X=10, sources=SOURCES)
175
+ if pretrained:
176
+ name = 'tasnet'
177
+ if extra:
178
+ name = 'tasnet_extra'
179
+ _load_state(name, model)
180
+ return model
demucs/repo.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Represents a model repository, including pre-trained models and bags of models.
7
+ A repo can either be the main remote repository stored in AWS, or a local repository
8
+ with your own models.
9
+ """
10
+
11
+ from hashlib import sha256
12
+ from pathlib import Path
13
+ import typing as tp
14
+
15
+ import torch
16
+ import yaml
17
+
18
+ from .apply import BagOfModels, Model
19
+ from .states import load_model
20
+
21
+
22
+ AnyModel = tp.Union[Model, BagOfModels]
23
+
24
+
25
+ class ModelLoadingError(RuntimeError):
26
+ pass
27
+
28
+
29
+ def check_checksum(path: Path, checksum: str):
30
+ sha = sha256()
31
+ with open(path, 'rb') as file:
32
+ while True:
33
+ buf = file.read(2**20)
34
+ if not buf:
35
+ break
36
+ sha.update(buf)
37
+ actual_checksum = sha.hexdigest()[:len(checksum)]
38
+ if actual_checksum != checksum:
39
+ raise ModelLoadingError(f'Invalid checksum for file {path}, '
40
+ f'expected {checksum} but got {actual_checksum}')
41
+
42
+ class ModelOnlyRepo:
43
+ """Base class for all model only repos.
44
+ """
45
+ def has_model(self, sig: str) -> bool:
46
+ raise NotImplementedError()
47
+
48
+ def get_model(self, sig: str) -> Model:
49
+ raise NotImplementedError()
50
+
51
+
52
+ class RemoteRepo(ModelOnlyRepo):
53
+ def __init__(self, models: tp.Dict[str, str]):
54
+ self._models = models
55
+
56
+ def has_model(self, sig: str) -> bool:
57
+ return sig in self._models
58
+
59
+ def get_model(self, sig: str) -> Model:
60
+ try:
61
+ url = self._models[sig]
62
+ except KeyError:
63
+ raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.')
64
+ pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
65
+ return load_model(pkg)
66
+
67
+
68
+ class LocalRepo(ModelOnlyRepo):
69
+ def __init__(self, root: Path):
70
+ self.root = root
71
+ self.scan()
72
+
73
+ def scan(self):
74
+ self._models = {}
75
+ self._checksums = {}
76
+ for file in self.root.iterdir():
77
+ if file.suffix == '.th':
78
+ if '-' in file.stem:
79
+ xp_sig, checksum = file.stem.split('-')
80
+ self._checksums[xp_sig] = checksum
81
+ else:
82
+ xp_sig = file.stem
83
+ if xp_sig in self._models:
84
+ print('Whats xp? ', xp_sig)
85
+ raise ModelLoadingError(
86
+ f'Duplicate pre-trained model exist for signature {xp_sig}. '
87
+ 'Please delete all but one.')
88
+ self._models[xp_sig] = file
89
+
90
+ def has_model(self, sig: str) -> bool:
91
+ return sig in self._models
92
+
93
+ def get_model(self, sig: str) -> Model:
94
+ try:
95
+ file = self._models[sig]
96
+ except KeyError:
97
+ raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.')
98
+ if sig in self._checksums:
99
+ check_checksum(file, self._checksums[sig])
100
+ return load_model(file)
101
+
102
+
103
+ class BagOnlyRepo:
104
+ """Handles only YAML files containing bag of models, leaving the actual
105
+ model loading to some Repo.
106
+ """
107
+ def __init__(self, root: Path, model_repo: ModelOnlyRepo):
108
+ self.root = root
109
+ self.model_repo = model_repo
110
+ self.scan()
111
+
112
+ def scan(self):
113
+ self._bags = {}
114
+ for file in self.root.iterdir():
115
+ if file.suffix == '.yaml':
116
+ self._bags[file.stem] = file
117
+
118
+ def has_model(self, name: str) -> bool:
119
+ return name in self._bags
120
+
121
+ def get_model(self, name: str) -> BagOfModels:
122
+ try:
123
+ yaml_file = self._bags[name]
124
+ except KeyError:
125
+ raise ModelLoadingError(f'{name} is neither a single pre-trained model or '
126
+ 'a bag of models.')
127
+ bag = yaml.safe_load(open(yaml_file))
128
+ signatures = bag['models']
129
+ models = [self.model_repo.get_model(sig) for sig in signatures]
130
+ weights = bag.get('weights')
131
+ segment = bag.get('segment')
132
+ return BagOfModels(models, weights, segment)
133
+
134
+
135
+ class AnyModelRepo:
136
+ def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
137
+ self.model_repo = model_repo
138
+ self.bag_repo = bag_repo
139
+
140
+ def has_model(self, name_or_sig: str) -> bool:
141
+ return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)
142
+
143
+ def get_model(self, name_or_sig: str) -> AnyModel:
144
+ print('name_or_sig: ', name_or_sig)
145
+ if self.model_repo.has_model(name_or_sig):
146
+ return self.model_repo.get_model(name_or_sig)
147
+ else:
148
+ return self.bag_repo.get_model(name_or_sig)
demucs/spec.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Conveniance wrapper to perform STFT and iSTFT"""
7
+
8
+ import torch as th
9
+
10
+
11
+ def spectro(x, n_fft=512, hop_length=None, pad=0):
12
+ *other, length = x.shape
13
+ x = x.reshape(-1, length)
14
+ z = th.stft(x,
15
+ n_fft * (1 + pad),
16
+ hop_length or n_fft // 4,
17
+ window=th.hann_window(n_fft).to(x),
18
+ win_length=n_fft,
19
+ normalized=True,
20
+ center=True,
21
+ return_complex=True,
22
+ pad_mode='reflect')
23
+ _, freqs, frame = z.shape
24
+ return z.view(*other, freqs, frame)
25
+
26
+
27
+ def ispectro(z, hop_length=None, length=None, pad=0):
28
+ *other, freqs, frames = z.shape
29
+ n_fft = 2 * freqs - 2
30
+ z = z.view(-1, freqs, frames)
31
+ win_length = n_fft // (1 + pad)
32
+ x = th.istft(z,
33
+ n_fft,
34
+ hop_length,
35
+ window=th.hann_window(win_length).to(z.real),
36
+ win_length=win_length,
37
+ normalized=True,
38
+ length=length,
39
+ center=True)
40
+ _, length = x.shape
41
+ return x.view(*other, length)
demucs/states.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Utilities to save and load models.
8
+ """
9
+ from contextlib import contextmanager
10
+
11
+ import functools
12
+ import hashlib
13
+ import inspect
14
+ import io
15
+ from pathlib import Path
16
+ import warnings
17
+
18
+ from omegaconf import OmegaConf
19
+ from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
20
+ import torch
21
+
22
+
23
+ def get_quantizer(model, args, optimizer=None):
24
+ """Return the quantizer given the XP quantization args."""
25
+ quantizer = None
26
+ if args.diffq:
27
+ quantizer = DiffQuantizer(
28
+ model, min_size=args.min_size, group_size=args.group_size)
29
+ if optimizer is not None:
30
+ quantizer.setup_optimizer(optimizer)
31
+ elif args.qat:
32
+ quantizer = UniformQuantizer(
33
+ model, bits=args.qat, min_size=args.min_size)
34
+ return quantizer
35
+
36
+
37
+ def load_model(path_or_package, strict=False):
38
+ """Load a model from the given serialized model, either given as a dict (already loaded)
39
+ or a path to a file on disk."""
40
+ if isinstance(path_or_package, dict):
41
+ package = path_or_package
42
+ elif isinstance(path_or_package, (str, Path)):
43
+ with warnings.catch_warnings():
44
+ warnings.simplefilter("ignore")
45
+ path = path_or_package
46
+ package = torch.load(path, 'cpu')
47
+ else:
48
+ raise ValueError(f"Invalid type for {path_or_package}.")
49
+
50
+ klass = package["klass"]
51
+ args = package["args"]
52
+ kwargs = package["kwargs"]
53
+
54
+ if strict:
55
+ model = klass(*args, **kwargs)
56
+ else:
57
+ sig = inspect.signature(klass)
58
+ for key in list(kwargs):
59
+ if key not in sig.parameters:
60
+ warnings.warn("Dropping inexistant parameter " + key)
61
+ del kwargs[key]
62
+ model = klass(*args, **kwargs)
63
+
64
+ state = package["state"]
65
+
66
+ set_state(model, state)
67
+ return model
68
+
69
+
70
+ def get_state(model, quantizer, half=False):
71
+ """Get the state from a model, potentially with quantization applied.
72
+ If `half` is True, model are stored as half precision, which shouldn't impact performance
73
+ but half the state size."""
74
+ if quantizer is None:
75
+ dtype = torch.half if half else None
76
+ state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
77
+ else:
78
+ state = quantizer.get_quantized_state()
79
+ state['__quantized'] = True
80
+ return state
81
+
82
+
83
+ def set_state(model, state, quantizer=None):
84
+ """Set the state on a given model."""
85
+ if state.get('__quantized'):
86
+ if quantizer is not None:
87
+ quantizer.restore_quantized_state(model, state['quantized'])
88
+ else:
89
+ restore_quantized_state(model, state)
90
+ else:
91
+ model.load_state_dict(state)
92
+ return state
93
+
94
+
95
+ def save_with_checksum(content, path):
96
+ """Save the given value on disk, along with a sha256 hash.
97
+ Should be used with the output of either `serialize_model` or `get_state`."""
98
+ buf = io.BytesIO()
99
+ torch.save(content, buf)
100
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
101
+
102
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
103
+ path.write_bytes(buf.getvalue())
104
+
105
+
106
+ def serialize_model(model, training_args, quantizer=None, half=True):
107
+ args, kwargs = model._init_args_kwargs
108
+ klass = model.__class__
109
+
110
+ state = get_state(model, quantizer, half)
111
+ return {
112
+ 'klass': klass,
113
+ 'args': args,
114
+ 'kwargs': kwargs,
115
+ 'state': state,
116
+ 'training_args': OmegaConf.to_container(training_args, resolve=True),
117
+ }
118
+
119
+
120
+ def copy_state(state):
121
+ return {k: v.cpu().clone() for k, v in state.items()}
122
+
123
+
124
+ @contextmanager
125
+ def swap_state(model, state):
126
+ """
127
+ Context manager that swaps the state of a model, e.g:
128
+
129
+ # model is in old state
130
+ with swap_state(model, new_state):
131
+ # model in new state
132
+ # model back to old state
133
+ """
134
+ old_state = copy_state(model.state_dict())
135
+ model.load_state_dict(state, strict=False)
136
+ try:
137
+ yield
138
+ finally:
139
+ model.load_state_dict(old_state)
140
+
141
+
142
+ def capture_init(init):
143
+ @functools.wraps(init)
144
+ def __init__(self, *args, **kwargs):
145
+ self._init_args_kwargs = (args, kwargs)
146
+ init(self, *args, **kwargs)
147
+
148
+ return __init__
demucs/tasnet.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Created on 2018/12
8
+ # Author: Kaituo XU
9
+ # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
10
+ # Here is the original license:
11
+ # The MIT License (MIT)
12
+ #
13
+ # Copyright (c) 2018 Kaituo XU
14
+ #
15
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
16
+ # of this software and associated documentation files (the "Software"), to deal
17
+ # in the Software without restriction, including without limitation the rights
18
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19
+ # copies of the Software, and to permit persons to whom the Software is
20
+ # furnished to do so, subject to the following conditions:
21
+ #
22
+ # The above copyright notice and this permission notice shall be included in all
23
+ # copies or substantial portions of the Software.
24
+ #
25
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31
+ # SOFTWARE.
32
+
33
+ import math
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from .utils import capture_init
40
+
41
+ EPS = 1e-8
42
+
43
+
44
+ def overlap_and_add(signal, frame_step):
45
+ outer_dimensions = signal.size()[:-2]
46
+ frames, frame_length = signal.size()[-2:]
47
+
48
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
49
+ subframe_step = frame_step // subframe_length
50
+ subframes_per_frame = frame_length // subframe_length
51
+ output_size = frame_step * (frames - 1) + frame_length
52
+ output_subframes = output_size // subframe_length
53
+
54
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
55
+
56
+ frame = torch.arange(0, output_subframes,
57
+ device=signal.device).unfold(0, subframes_per_frame, subframe_step)
58
+ frame = frame.long() # signal may in GPU or CPU
59
+ frame = frame.contiguous().view(-1)
60
+
61
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
62
+ result.index_add_(-2, frame, subframe_signal)
63
+ result = result.view(*outer_dimensions, -1)
64
+ return result
65
+
66
+
67
+ class ConvTasNet(nn.Module):
68
+ @capture_init
69
+ def __init__(self,
70
+ N=256,
71
+ L=20,
72
+ B=256,
73
+ H=512,
74
+ P=3,
75
+ X=8,
76
+ R=4,
77
+ C=4,
78
+ audio_channels=1,
79
+ samplerate=44100,
80
+ norm_type="gLN",
81
+ causal=False,
82
+ mask_nonlinear='relu'):
83
+ """
84
+ Args:
85
+ N: Number of filters in autoencoder
86
+ L: Length of the filters (in samples)
87
+ B: Number of channels in bottleneck 1 × 1-conv block
88
+ H: Number of channels in convolutional blocks
89
+ P: Kernel size in convolutional blocks
90
+ X: Number of convolutional blocks in each repeat
91
+ R: Number of repeats
92
+ C: Number of speakers
93
+ norm_type: BN, gLN, cLN
94
+ causal: causal or non-causal
95
+ mask_nonlinear: use which non-linear function to generate mask
96
+ """
97
+ super(ConvTasNet, self).__init__()
98
+ # Hyper-parameter
99
+ self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
100
+ self.norm_type = norm_type
101
+ self.causal = causal
102
+ self.mask_nonlinear = mask_nonlinear
103
+ self.audio_channels = audio_channels
104
+ self.samplerate = samplerate
105
+ # Components
106
+ self.encoder = Encoder(L, N, audio_channels)
107
+ self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
108
+ self.decoder = Decoder(N, L, audio_channels)
109
+ # init
110
+ for p in self.parameters():
111
+ if p.dim() > 1:
112
+ nn.init.xavier_normal_(p)
113
+
114
+ def valid_length(self, length):
115
+ return length
116
+
117
+ def forward(self, mixture):
118
+ """
119
+ Args:
120
+ mixture: [M, T], M is batch size, T is #samples
121
+ Returns:
122
+ est_source: [M, C, T]
123
+ """
124
+ mixture_w = self.encoder(mixture)
125
+ est_mask = self.separator(mixture_w)
126
+ est_source = self.decoder(mixture_w, est_mask)
127
+
128
+ # T changed after conv1d in encoder, fix it here
129
+ T_origin = mixture.size(-1)
130
+ T_conv = est_source.size(-1)
131
+ est_source = F.pad(est_source, (0, T_origin - T_conv))
132
+ return est_source
133
+
134
+
135
+ class Encoder(nn.Module):
136
+ """Estimation of the nonnegative mixture weight by a 1-D conv layer.
137
+ """
138
+ def __init__(self, L, N, audio_channels):
139
+ super(Encoder, self).__init__()
140
+ # Hyper-parameter
141
+ self.L, self.N = L, N
142
+ # Components
143
+ # 50% overlap
144
+ self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
145
+
146
+ def forward(self, mixture):
147
+ """
148
+ Args:
149
+ mixture: [M, T], M is batch size, T is #samples
150
+ Returns:
151
+ mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
152
+ """
153
+ mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
154
+ return mixture_w
155
+
156
+
157
+ class Decoder(nn.Module):
158
+ def __init__(self, N, L, audio_channels):
159
+ super(Decoder, self).__init__()
160
+ # Hyper-parameter
161
+ self.N, self.L = N, L
162
+ self.audio_channels = audio_channels
163
+ # Components
164
+ self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
165
+
166
+ def forward(self, mixture_w, est_mask):
167
+ """
168
+ Args:
169
+ mixture_w: [M, N, K]
170
+ est_mask: [M, C, N, K]
171
+ Returns:
172
+ est_source: [M, C, T]
173
+ """
174
+ # D = W * M
175
+ source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
176
+ source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
177
+ # S = DV
178
+ est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
179
+ m, c, k, _ = est_source.size()
180
+ est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
181
+ est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
182
+ return est_source
183
+
184
+
185
+ class TemporalConvNet(nn.Module):
186
+ def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
187
+ """
188
+ Args:
189
+ N: Number of filters in autoencoder
190
+ B: Number of channels in bottleneck 1 × 1-conv block
191
+ H: Number of channels in convolutional blocks
192
+ P: Kernel size in convolutional blocks
193
+ X: Number of convolutional blocks in each repeat
194
+ R: Number of repeats
195
+ C: Number of speakers
196
+ norm_type: BN, gLN, cLN
197
+ causal: causal or non-causal
198
+ mask_nonlinear: use which non-linear function to generate mask
199
+ """
200
+ super(TemporalConvNet, self).__init__()
201
+ # Hyper-parameter
202
+ self.C = C
203
+ self.mask_nonlinear = mask_nonlinear
204
+ # Components
205
+ # [M, N, K] -> [M, N, K]
206
+ layer_norm = ChannelwiseLayerNorm(N)
207
+ # [M, N, K] -> [M, B, K]
208
+ bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
209
+ # [M, B, K] -> [M, B, K]
210
+ repeats = []
211
+ for r in range(R):
212
+ blocks = []
213
+ for x in range(X):
214
+ dilation = 2**x
215
+ padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
216
+ blocks += [
217
+ TemporalBlock(B,
218
+ H,
219
+ P,
220
+ stride=1,
221
+ padding=padding,
222
+ dilation=dilation,
223
+ norm_type=norm_type,
224
+ causal=causal)
225
+ ]
226
+ repeats += [nn.Sequential(*blocks)]
227
+ temporal_conv_net = nn.Sequential(*repeats)
228
+ # [M, B, K] -> [M, C*N, K]
229
+ mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
230
+ # Put together
231
+ self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
232
+ mask_conv1x1)
233
+
234
+ def forward(self, mixture_w):
235
+ """
236
+ Keep this API same with TasNet
237
+ Args:
238
+ mixture_w: [M, N, K], M is batch size
239
+ returns:
240
+ est_mask: [M, C, N, K]
241
+ """
242
+ M, N, K = mixture_w.size()
243
+ score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
244
+ score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
245
+ if self.mask_nonlinear == 'softmax':
246
+ est_mask = F.softmax(score, dim=1)
247
+ elif self.mask_nonlinear == 'relu':
248
+ est_mask = F.relu(score)
249
+ else:
250
+ raise ValueError("Unsupported mask non-linear function")
251
+ return est_mask
252
+
253
+
254
+ class TemporalBlock(nn.Module):
255
+ def __init__(self,
256
+ in_channels,
257
+ out_channels,
258
+ kernel_size,
259
+ stride,
260
+ padding,
261
+ dilation,
262
+ norm_type="gLN",
263
+ causal=False):
264
+ super(TemporalBlock, self).__init__()
265
+ # [M, B, K] -> [M, H, K]
266
+ conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
267
+ prelu = nn.PReLU()
268
+ norm = chose_norm(norm_type, out_channels)
269
+ # [M, H, K] -> [M, B, K]
270
+ dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
271
+ dilation, norm_type, causal)
272
+ # Put together
273
+ self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
274
+
275
+ def forward(self, x):
276
+ """
277
+ Args:
278
+ x: [M, B, K]
279
+ Returns:
280
+ [M, B, K]
281
+ """
282
+ residual = x
283
+ out = self.net(x)
284
+ # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
285
+ return out + residual # look like w/o F.relu is better than w/ F.relu
286
+ # return F.relu(out + residual)
287
+
288
+
289
+ class DepthwiseSeparableConv(nn.Module):
290
+ def __init__(self,
291
+ in_channels,
292
+ out_channels,
293
+ kernel_size,
294
+ stride,
295
+ padding,
296
+ dilation,
297
+ norm_type="gLN",
298
+ causal=False):
299
+ super(DepthwiseSeparableConv, self).__init__()
300
+ # Use `groups` option to implement depthwise convolution
301
+ # [M, H, K] -> [M, H, K]
302
+ depthwise_conv = nn.Conv1d(in_channels,
303
+ in_channels,
304
+ kernel_size,
305
+ stride=stride,
306
+ padding=padding,
307
+ dilation=dilation,
308
+ groups=in_channels,
309
+ bias=False)
310
+ if causal:
311
+ chomp = Chomp1d(padding)
312
+ prelu = nn.PReLU()
313
+ norm = chose_norm(norm_type, in_channels)
314
+ # [M, H, K] -> [M, B, K]
315
+ pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
316
+ # Put together
317
+ if causal:
318
+ self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
319
+ else:
320
+ self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
321
+
322
+ def forward(self, x):
323
+ """
324
+ Args:
325
+ x: [M, H, K]
326
+ Returns:
327
+ result: [M, B, K]
328
+ """
329
+ return self.net(x)
330
+
331
+
332
+ class Chomp1d(nn.Module):
333
+ """To ensure the output length is the same as the input.
334
+ """
335
+ def __init__(self, chomp_size):
336
+ super(Chomp1d, self).__init__()
337
+ self.chomp_size = chomp_size
338
+
339
+ def forward(self, x):
340
+ """
341
+ Args:
342
+ x: [M, H, Kpad]
343
+ Returns:
344
+ [M, H, K]
345
+ """
346
+ return x[:, :, :-self.chomp_size].contiguous()
347
+
348
+
349
+ def chose_norm(norm_type, channel_size):
350
+ """The input of normlization will be (M, C, K), where M is batch size,
351
+ C is channel size and K is sequence length.
352
+ """
353
+ if norm_type == "gLN":
354
+ return GlobalLayerNorm(channel_size)
355
+ elif norm_type == "cLN":
356
+ return ChannelwiseLayerNorm(channel_size)
357
+ elif norm_type == "id":
358
+ return nn.Identity()
359
+ else: # norm_type == "BN":
360
+ # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
361
+ # along M and K, so this BN usage is right.
362
+ return nn.BatchNorm1d(channel_size)
363
+
364
+
365
+ # TODO: Use nn.LayerNorm to impl cLN to speed up
366
+ class ChannelwiseLayerNorm(nn.Module):
367
+ """Channel-wise Layer Normalization (cLN)"""
368
+ def __init__(self, channel_size):
369
+ super(ChannelwiseLayerNorm, self).__init__()
370
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
371
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
372
+ self.reset_parameters()
373
+
374
+ def reset_parameters(self):
375
+ self.gamma.data.fill_(1)
376
+ self.beta.data.zero_()
377
+
378
+ def forward(self, y):
379
+ """
380
+ Args:
381
+ y: [M, N, K], M is batch size, N is channel size, K is length
382
+ Returns:
383
+ cLN_y: [M, N, K]
384
+ """
385
+ mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
386
+ var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
387
+ cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
388
+ return cLN_y
389
+
390
+
391
+ class GlobalLayerNorm(nn.Module):
392
+ """Global Layer Normalization (gLN)"""
393
+ def __init__(self, channel_size):
394
+ super(GlobalLayerNorm, self).__init__()
395
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
396
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
397
+ self.reset_parameters()
398
+
399
+ def reset_parameters(self):
400
+ self.gamma.data.fill_(1)
401
+ self.beta.data.zero_()
402
+
403
+ def forward(self, y):
404
+ """
405
+ Args:
406
+ y: [M, N, K], M is batch size, N is channel size, K is length
407
+ Returns:
408
+ gLN_y: [M, N, K]
409
+ """
410
+ # TODO: in torch 1.0, torch.mean() support dim list
411
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
412
+ var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
413
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
414
+ return gLN_y
415
+
416
+
417
+ if __name__ == "__main__":
418
+ torch.manual_seed(123)
419
+ M, N, L, T = 2, 3, 4, 12
420
+ K = 2 * T // L - 1
421
+ B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
422
+ mixture = torch.randint(3, (M, T))
423
+ # test Encoder
424
+ encoder = Encoder(L, N)
425
+ encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
426
+ mixture_w = encoder(mixture)
427
+ print('mixture', mixture)
428
+ print('U', encoder.conv1d_U.weight)
429
+ print('mixture_w', mixture_w)
430
+ print('mixture_w size', mixture_w.size())
431
+
432
+ # test TemporalConvNet
433
+ separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
434
+ est_mask = separator(mixture_w)
435
+ print('est_mask', est_mask)
436
+
437
+ # test Decoder
438
+ decoder = Decoder(N, L)
439
+ est_mask = torch.randint(2, (B, K, C, N))
440
+ est_source = decoder(mixture_w, est_mask)
441
+ print('est_source', est_source)
442
+
443
+ # test Conv-TasNet
444
+ conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
445
+ est_source = conv_tasnet(mixture)
446
+ print('est_source', est_source)
447
+ print('est_source size', est_source.size())
demucs/tasnet_v2.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Created on 2018/12
8
+ # Author: Kaituo XU
9
+ # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
10
+ # Here is the original license:
11
+ # The MIT License (MIT)
12
+ #
13
+ # Copyright (c) 2018 Kaituo XU
14
+ #
15
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
16
+ # of this software and associated documentation files (the "Software"), to deal
17
+ # in the Software without restriction, including without limitation the rights
18
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19
+ # copies of the Software, and to permit persons to whom the Software is
20
+ # furnished to do so, subject to the following conditions:
21
+ #
22
+ # The above copyright notice and this permission notice shall be included in all
23
+ # copies or substantial portions of the Software.
24
+ #
25
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31
+ # SOFTWARE.
32
+
33
+ import math
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from .utils import capture_init
40
+
41
+ EPS = 1e-8
42
+
43
+
44
+ def overlap_and_add(signal, frame_step):
45
+ outer_dimensions = signal.size()[:-2]
46
+ frames, frame_length = signal.size()[-2:]
47
+
48
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
49
+ subframe_step = frame_step // subframe_length
50
+ subframes_per_frame = frame_length // subframe_length
51
+ output_size = frame_step * (frames - 1) + frame_length
52
+ output_subframes = output_size // subframe_length
53
+
54
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
55
+
56
+ frame = torch.arange(0, output_subframes,
57
+ device=signal.device).unfold(0, subframes_per_frame, subframe_step)
58
+ frame = frame.long() # signal may in GPU or CPU
59
+ frame = frame.contiguous().view(-1)
60
+
61
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
62
+ result.index_add_(-2, frame, subframe_signal)
63
+ result = result.view(*outer_dimensions, -1)
64
+ return result
65
+
66
+
67
+ class ConvTasNet(nn.Module):
68
+ @capture_init
69
+ def __init__(self,
70
+ sources,
71
+ N=256,
72
+ L=20,
73
+ B=256,
74
+ H=512,
75
+ P=3,
76
+ X=8,
77
+ R=4,
78
+ audio_channels=2,
79
+ norm_type="gLN",
80
+ causal=False,
81
+ mask_nonlinear='relu',
82
+ samplerate=44100,
83
+ segment_length=44100 * 2 * 4):
84
+ """
85
+ Args:
86
+ sources: list of sources
87
+ N: Number of filters in autoencoder
88
+ L: Length of the filters (in samples)
89
+ B: Number of channels in bottleneck 1 × 1-conv block
90
+ H: Number of channels in convolutional blocks
91
+ P: Kernel size in convolutional blocks
92
+ X: Number of convolutional blocks in each repeat
93
+ R: Number of repeats
94
+ norm_type: BN, gLN, cLN
95
+ causal: causal or non-causal
96
+ mask_nonlinear: use which non-linear function to generate mask
97
+ """
98
+ super(ConvTasNet, self).__init__()
99
+ # Hyper-parameter
100
+ self.sources = sources
101
+ self.C = len(sources)
102
+ self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
103
+ self.norm_type = norm_type
104
+ self.causal = causal
105
+ self.mask_nonlinear = mask_nonlinear
106
+ self.audio_channels = audio_channels
107
+ self.samplerate = samplerate
108
+ self.segment_length = segment_length
109
+ # Components
110
+ self.encoder = Encoder(L, N, audio_channels)
111
+ self.separator = TemporalConvNet(
112
+ N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
113
+ self.decoder = Decoder(N, L, audio_channels)
114
+ # init
115
+ for p in self.parameters():
116
+ if p.dim() > 1:
117
+ nn.init.xavier_normal_(p)
118
+
119
+ def valid_length(self, length):
120
+ return length
121
+
122
+ def forward(self, mixture):
123
+ """
124
+ Args:
125
+ mixture: [M, T], M is batch size, T is #samples
126
+ Returns:
127
+ est_source: [M, C, T]
128
+ """
129
+ mixture_w = self.encoder(mixture)
130
+ est_mask = self.separator(mixture_w)
131
+ est_source = self.decoder(mixture_w, est_mask)
132
+
133
+ # T changed after conv1d in encoder, fix it here
134
+ T_origin = mixture.size(-1)
135
+ T_conv = est_source.size(-1)
136
+ est_source = F.pad(est_source, (0, T_origin - T_conv))
137
+ return est_source
138
+
139
+
140
+ class Encoder(nn.Module):
141
+ """Estimation of the nonnegative mixture weight by a 1-D conv layer.
142
+ """
143
+ def __init__(self, L, N, audio_channels):
144
+ super(Encoder, self).__init__()
145
+ # Hyper-parameter
146
+ self.L, self.N = L, N
147
+ # Components
148
+ # 50% overlap
149
+ self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
150
+
151
+ def forward(self, mixture):
152
+ """
153
+ Args:
154
+ mixture: [M, T], M is batch size, T is #samples
155
+ Returns:
156
+ mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
157
+ """
158
+ mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
159
+ return mixture_w
160
+
161
+
162
+ class Decoder(nn.Module):
163
+ def __init__(self, N, L, audio_channels):
164
+ super(Decoder, self).__init__()
165
+ # Hyper-parameter
166
+ self.N, self.L = N, L
167
+ self.audio_channels = audio_channels
168
+ # Components
169
+ self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
170
+
171
+ def forward(self, mixture_w, est_mask):
172
+ """
173
+ Args:
174
+ mixture_w: [M, N, K]
175
+ est_mask: [M, C, N, K]
176
+ Returns:
177
+ est_source: [M, C, T]
178
+ """
179
+ # D = W * M
180
+ source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
181
+ source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
182
+ # S = DV
183
+ est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
184
+ m, c, k, _ = est_source.size()
185
+ est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
186
+ est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
187
+ return est_source
188
+
189
+
190
+ class TemporalConvNet(nn.Module):
191
+ def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
192
+ """
193
+ Args:
194
+ N: Number of filters in autoencoder
195
+ B: Number of channels in bottleneck 1 × 1-conv block
196
+ H: Number of channels in convolutional blocks
197
+ P: Kernel size in convolutional blocks
198
+ X: Number of convolutional blocks in each repeat
199
+ R: Number of repeats
200
+ C: Number of speakers
201
+ norm_type: BN, gLN, cLN
202
+ causal: causal or non-causal
203
+ mask_nonlinear: use which non-linear function to generate mask
204
+ """
205
+ super(TemporalConvNet, self).__init__()
206
+ # Hyper-parameter
207
+ self.C = C
208
+ self.mask_nonlinear = mask_nonlinear
209
+ # Components
210
+ # [M, N, K] -> [M, N, K]
211
+ layer_norm = ChannelwiseLayerNorm(N)
212
+ # [M, N, K] -> [M, B, K]
213
+ bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
214
+ # [M, B, K] -> [M, B, K]
215
+ repeats = []
216
+ for r in range(R):
217
+ blocks = []
218
+ for x in range(X):
219
+ dilation = 2**x
220
+ padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
221
+ blocks += [
222
+ TemporalBlock(B,
223
+ H,
224
+ P,
225
+ stride=1,
226
+ padding=padding,
227
+ dilation=dilation,
228
+ norm_type=norm_type,
229
+ causal=causal)
230
+ ]
231
+ repeats += [nn.Sequential(*blocks)]
232
+ temporal_conv_net = nn.Sequential(*repeats)
233
+ # [M, B, K] -> [M, C*N, K]
234
+ mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
235
+ # Put together
236
+ self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
237
+ mask_conv1x1)
238
+
239
+ def forward(self, mixture_w):
240
+ """
241
+ Keep this API same with TasNet
242
+ Args:
243
+ mixture_w: [M, N, K], M is batch size
244
+ returns:
245
+ est_mask: [M, C, N, K]
246
+ """
247
+ M, N, K = mixture_w.size()
248
+ score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
249
+ score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
250
+ if self.mask_nonlinear == 'softmax':
251
+ est_mask = F.softmax(score, dim=1)
252
+ elif self.mask_nonlinear == 'relu':
253
+ est_mask = F.relu(score)
254
+ else:
255
+ raise ValueError("Unsupported mask non-linear function")
256
+ return est_mask
257
+
258
+
259
+ class TemporalBlock(nn.Module):
260
+ def __init__(self,
261
+ in_channels,
262
+ out_channels,
263
+ kernel_size,
264
+ stride,
265
+ padding,
266
+ dilation,
267
+ norm_type="gLN",
268
+ causal=False):
269
+ super(TemporalBlock, self).__init__()
270
+ # [M, B, K] -> [M, H, K]
271
+ conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
272
+ prelu = nn.PReLU()
273
+ norm = chose_norm(norm_type, out_channels)
274
+ # [M, H, K] -> [M, B, K]
275
+ dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
276
+ dilation, norm_type, causal)
277
+ # Put together
278
+ self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
279
+
280
+ def forward(self, x):
281
+ """
282
+ Args:
283
+ x: [M, B, K]
284
+ Returns:
285
+ [M, B, K]
286
+ """
287
+ residual = x
288
+ out = self.net(x)
289
+ # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
290
+ return out + residual # look like w/o F.relu is better than w/ F.relu
291
+ # return F.relu(out + residual)
292
+
293
+
294
+ class DepthwiseSeparableConv(nn.Module):
295
+ def __init__(self,
296
+ in_channels,
297
+ out_channels,
298
+ kernel_size,
299
+ stride,
300
+ padding,
301
+ dilation,
302
+ norm_type="gLN",
303
+ causal=False):
304
+ super(DepthwiseSeparableConv, self).__init__()
305
+ # Use `groups` option to implement depthwise convolution
306
+ # [M, H, K] -> [M, H, K]
307
+ depthwise_conv = nn.Conv1d(in_channels,
308
+ in_channels,
309
+ kernel_size,
310
+ stride=stride,
311
+ padding=padding,
312
+ dilation=dilation,
313
+ groups=in_channels,
314
+ bias=False)
315
+ if causal:
316
+ chomp = Chomp1d(padding)
317
+ prelu = nn.PReLU()
318
+ norm = chose_norm(norm_type, in_channels)
319
+ # [M, H, K] -> [M, B, K]
320
+ pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
321
+ # Put together
322
+ if causal:
323
+ self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
324
+ else:
325
+ self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
326
+
327
+ def forward(self, x):
328
+ """
329
+ Args:
330
+ x: [M, H, K]
331
+ Returns:
332
+ result: [M, B, K]
333
+ """
334
+ return self.net(x)
335
+
336
+
337
+ class Chomp1d(nn.Module):
338
+ """To ensure the output length is the same as the input.
339
+ """
340
+ def __init__(self, chomp_size):
341
+ super(Chomp1d, self).__init__()
342
+ self.chomp_size = chomp_size
343
+
344
+ def forward(self, x):
345
+ """
346
+ Args:
347
+ x: [M, H, Kpad]
348
+ Returns:
349
+ [M, H, K]
350
+ """
351
+ return x[:, :, :-self.chomp_size].contiguous()
352
+
353
+
354
+ def chose_norm(norm_type, channel_size):
355
+ """The input of normlization will be (M, C, K), where M is batch size,
356
+ C is channel size and K is sequence length.
357
+ """
358
+ if norm_type == "gLN":
359
+ return GlobalLayerNorm(channel_size)
360
+ elif norm_type == "cLN":
361
+ return ChannelwiseLayerNorm(channel_size)
362
+ elif norm_type == "id":
363
+ return nn.Identity()
364
+ else: # norm_type == "BN":
365
+ # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
366
+ # along M and K, so this BN usage is right.
367
+ return nn.BatchNorm1d(channel_size)
368
+
369
+
370
+ # TODO: Use nn.LayerNorm to impl cLN to speed up
371
+ class ChannelwiseLayerNorm(nn.Module):
372
+ """Channel-wise Layer Normalization (cLN)"""
373
+ def __init__(self, channel_size):
374
+ super(ChannelwiseLayerNorm, self).__init__()
375
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
376
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
377
+ self.reset_parameters()
378
+
379
+ def reset_parameters(self):
380
+ self.gamma.data.fill_(1)
381
+ self.beta.data.zero_()
382
+
383
+ def forward(self, y):
384
+ """
385
+ Args:
386
+ y: [M, N, K], M is batch size, N is channel size, K is length
387
+ Returns:
388
+ cLN_y: [M, N, K]
389
+ """
390
+ mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
391
+ var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
392
+ cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
393
+ return cLN_y
394
+
395
+
396
+ class GlobalLayerNorm(nn.Module):
397
+ """Global Layer Normalization (gLN)"""
398
+ def __init__(self, channel_size):
399
+ super(GlobalLayerNorm, self).__init__()
400
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
401
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
402
+ self.reset_parameters()
403
+
404
+ def reset_parameters(self):
405
+ self.gamma.data.fill_(1)
406
+ self.beta.data.zero_()
407
+
408
+ def forward(self, y):
409
+ """
410
+ Args:
411
+ y: [M, N, K], M is batch size, N is channel size, K is length
412
+ Returns:
413
+ gLN_y: [M, N, K]
414
+ """
415
+ # TODO: in torch 1.0, torch.mean() support dim list
416
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
417
+ var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
418
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
419
+ return gLN_y
420
+
421
+
422
+ if __name__ == "__main__":
423
+ torch.manual_seed(123)
424
+ M, N, L, T = 2, 3, 4, 12
425
+ K = 2 * T // L - 1
426
+ B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
427
+ mixture = torch.randint(3, (M, T))
428
+ # test Encoder
429
+ encoder = Encoder(L, N)
430
+ encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
431
+ mixture_w = encoder(mixture)
432
+ print('mixture', mixture)
433
+ print('U', encoder.conv1d_U.weight)
434
+ print('mixture_w', mixture_w)
435
+ print('mixture_w size', mixture_w.size())
436
+
437
+ # test TemporalConvNet
438
+ separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
439
+ est_mask = separator(mixture_w)
440
+ print('est_mask', est_mask)
441
+
442
+ # test Decoder
443
+ decoder = Decoder(N, L)
444
+ est_mask = torch.randint(2, (B, K, C, N))
445
+ est_source = decoder(mixture_w, est_mask)
446
+ print('est_source', est_source)
447
+
448
+ # test Conv-TasNet
449
+ conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
450
+ est_source = conv_tasnet(mixture)
451
+ print('est_source', est_source)
452
+ print('est_source size', est_source.size())
demucs/transformer.py ADDED
@@ -0,0 +1,839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019-present, Meta, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+
8
+ import random
9
+ import typing as tp
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import math
16
+ from einops import rearrange
17
+
18
+
19
+ def create_sin_embedding(
20
+ length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
21
+ ):
22
+ # We aim for TBC format
23
+ assert dim % 2 == 0
24
+ pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
25
+ half_dim = dim // 2
26
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
27
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
28
+ return torch.cat(
29
+ [
30
+ torch.cos(phase),
31
+ torch.sin(phase),
32
+ ],
33
+ dim=-1,
34
+ )
35
+
36
+
37
+ def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
38
+ """
39
+ :param d_model: dimension of the model
40
+ :param height: height of the positions
41
+ :param width: width of the positions
42
+ :return: d_model*height*width position matrix
43
+ """
44
+ if d_model % 4 != 0:
45
+ raise ValueError(
46
+ "Cannot use sin/cos positional encoding with "
47
+ "odd dimension (got dim={:d})".format(d_model)
48
+ )
49
+ pe = torch.zeros(d_model, height, width)
50
+ # Each dimension use half of d_model
51
+ d_model = int(d_model / 2)
52
+ div_term = torch.exp(
53
+ torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
54
+ )
55
+ pos_w = torch.arange(0.0, width).unsqueeze(1)
56
+ pos_h = torch.arange(0.0, height).unsqueeze(1)
57
+ pe[0:d_model:2, :, :] = (
58
+ torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
59
+ )
60
+ pe[1:d_model:2, :, :] = (
61
+ torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
62
+ )
63
+ pe[d_model::2, :, :] = (
64
+ torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
65
+ )
66
+ pe[d_model + 1:: 2, :, :] = (
67
+ torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
68
+ )
69
+
70
+ return pe[None, :].to(device)
71
+
72
+
73
+ def create_sin_embedding_cape(
74
+ length: int,
75
+ dim: int,
76
+ batch_size: int,
77
+ mean_normalize: bool,
78
+ augment: bool, # True during training
79
+ max_global_shift: float = 0.0, # delta max
80
+ max_local_shift: float = 0.0, # epsilon max
81
+ max_scale: float = 1.0,
82
+ device: str = "cpu",
83
+ max_period: float = 10000.0,
84
+ ):
85
+ # We aim for TBC format
86
+ assert dim % 2 == 0
87
+ pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
88
+ pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
89
+ if mean_normalize:
90
+ pos -= torch.nanmean(pos, dim=0, keepdim=True)
91
+
92
+ if augment:
93
+ delta = np.random.uniform(
94
+ -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
95
+ )
96
+ delta_local = np.random.uniform(
97
+ -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
98
+ )
99
+ log_lambdas = np.random.uniform(
100
+ -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
101
+ )
102
+ pos = (pos + delta + delta_local) * np.exp(log_lambdas)
103
+
104
+ pos = pos.to(device)
105
+
106
+ half_dim = dim // 2
107
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
108
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
109
+ return torch.cat(
110
+ [
111
+ torch.cos(phase),
112
+ torch.sin(phase),
113
+ ],
114
+ dim=-1,
115
+ ).float()
116
+
117
+
118
+ def get_causal_mask(length):
119
+ pos = torch.arange(length)
120
+ return pos > pos[:, None]
121
+
122
+
123
+ def get_elementary_mask(
124
+ T1,
125
+ T2,
126
+ mask_type,
127
+ sparse_attn_window,
128
+ global_window,
129
+ mask_random_seed,
130
+ sparsity,
131
+ device,
132
+ ):
133
+ """
134
+ When the input of the Decoder has length T1 and the output T2
135
+ The mask matrix has shape (T2, T1)
136
+ """
137
+ assert mask_type in ["diag", "jmask", "random", "global"]
138
+
139
+ if mask_type == "global":
140
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
141
+ mask[:, :global_window] = True
142
+ line_window = int(global_window * T2 / T1)
143
+ mask[:line_window, :] = True
144
+
145
+ if mask_type == "diag":
146
+
147
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
148
+ rows = torch.arange(T2)[:, None]
149
+ cols = (
150
+ (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
151
+ .long()
152
+ .clamp(0, T1 - 1)
153
+ )
154
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
155
+
156
+ elif mask_type == "jmask":
157
+ mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
158
+ rows = torch.arange(T2 + 2)[:, None]
159
+ t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
160
+ t = (t * (t + 1) / 2).int()
161
+ t = torch.cat([-t.flip(0)[:-1], t])
162
+ cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
163
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
164
+ mask = mask[1:-1, 1:-1]
165
+
166
+ elif mask_type == "random":
167
+ gene = torch.Generator(device=device)
168
+ gene.manual_seed(mask_random_seed)
169
+ mask = (
170
+ torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1)
171
+ > sparsity
172
+ )
173
+
174
+ mask = mask.to(device)
175
+ return mask
176
+
177
+
178
+ def get_mask(
179
+ T1,
180
+ T2,
181
+ mask_type,
182
+ sparse_attn_window,
183
+ global_window,
184
+ mask_random_seed,
185
+ sparsity,
186
+ device,
187
+ ):
188
+ """
189
+ Return a SparseCSRTensor mask that is a combination of elementary masks
190
+ mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
191
+ """
192
+ from xformers.sparse import SparseCSRTensor
193
+ # create a list
194
+ mask_types = mask_type.split("_")
195
+
196
+ all_masks = [
197
+ get_elementary_mask(
198
+ T1,
199
+ T2,
200
+ mask,
201
+ sparse_attn_window,
202
+ global_window,
203
+ mask_random_seed,
204
+ sparsity,
205
+ device,
206
+ )
207
+ for mask in mask_types
208
+ ]
209
+
210
+ final_mask = torch.stack(all_masks).sum(axis=0) > 0
211
+
212
+ return SparseCSRTensor.from_dense(final_mask[None])
213
+
214
+
215
+ class ScaledEmbedding(nn.Module):
216
+ def __init__(
217
+ self,
218
+ num_embeddings: int,
219
+ embedding_dim: int,
220
+ scale: float = 1.0,
221
+ boost: float = 3.0,
222
+ ):
223
+ super().__init__()
224
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
225
+ self.embedding.weight.data *= scale / boost
226
+ self.boost = boost
227
+
228
+ @property
229
+ def weight(self):
230
+ return self.embedding.weight * self.boost
231
+
232
+ def forward(self, x):
233
+ return self.embedding(x) * self.boost
234
+
235
+
236
+ class LayerScale(nn.Module):
237
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
238
+ This rescales diagonaly residual outputs close to 0 initially, then learnt.
239
+ """
240
+
241
+ def __init__(self, channels: int, init: float = 0, channel_last=False):
242
+ """
243
+ channel_last = False corresponds to (B, C, T) tensors
244
+ channel_last = True corresponds to (T, B, C) tensors
245
+ """
246
+ super().__init__()
247
+ self.channel_last = channel_last
248
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
249
+ self.scale.data[:] = init
250
+
251
+ def forward(self, x):
252
+ if self.channel_last:
253
+ return self.scale * x
254
+ else:
255
+ return self.scale[:, None] * x
256
+
257
+
258
+ class MyGroupNorm(nn.GroupNorm):
259
+ def __init__(self, *args, **kwargs):
260
+ super().__init__(*args, **kwargs)
261
+
262
+ def forward(self, x):
263
+ """
264
+ x: (B, T, C)
265
+ if num_groups=1: Normalisation on all T and C together for each B
266
+ """
267
+ x = x.transpose(1, 2)
268
+ return super().forward(x).transpose(1, 2)
269
+
270
+
271
+ class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
272
+ def __init__(
273
+ self,
274
+ d_model,
275
+ nhead,
276
+ dim_feedforward=2048,
277
+ dropout=0.1,
278
+ activation=F.relu,
279
+ group_norm=0,
280
+ norm_first=False,
281
+ norm_out=False,
282
+ layer_norm_eps=1e-5,
283
+ layer_scale=False,
284
+ init_values=1e-4,
285
+ device=None,
286
+ dtype=None,
287
+ sparse=False,
288
+ mask_type="diag",
289
+ mask_random_seed=42,
290
+ sparse_attn_window=500,
291
+ global_window=50,
292
+ auto_sparsity=False,
293
+ sparsity=0.95,
294
+ batch_first=False,
295
+ ):
296
+ factory_kwargs = {"device": device, "dtype": dtype}
297
+ super().__init__(
298
+ d_model=d_model,
299
+ nhead=nhead,
300
+ dim_feedforward=dim_feedforward,
301
+ dropout=dropout,
302
+ activation=activation,
303
+ layer_norm_eps=layer_norm_eps,
304
+ batch_first=batch_first,
305
+ norm_first=norm_first,
306
+ device=device,
307
+ dtype=dtype,
308
+ )
309
+ self.sparse = sparse
310
+ self.auto_sparsity = auto_sparsity
311
+ if sparse:
312
+ if not auto_sparsity:
313
+ self.mask_type = mask_type
314
+ self.sparse_attn_window = sparse_attn_window
315
+ self.global_window = global_window
316
+ self.sparsity = sparsity
317
+ if group_norm:
318
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
319
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
320
+
321
+ self.norm_out = None
322
+ if self.norm_first & norm_out:
323
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
324
+ self.gamma_1 = (
325
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
326
+ )
327
+ self.gamma_2 = (
328
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
329
+ )
330
+
331
+ if sparse:
332
+ self.self_attn = MultiheadAttention(
333
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
334
+ auto_sparsity=sparsity if auto_sparsity else 0,
335
+ )
336
+ self.__setattr__("src_mask", torch.zeros(1, 1))
337
+ self.mask_random_seed = mask_random_seed
338
+
339
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
340
+ """
341
+ if batch_first = False, src shape is (T, B, C)
342
+ the case where batch_first=True is not covered
343
+ """
344
+ device = src.device
345
+ x = src
346
+ T, B, C = x.shape
347
+ if self.sparse and not self.auto_sparsity:
348
+ assert src_mask is None
349
+ src_mask = self.src_mask
350
+ if src_mask.shape[-1] != T:
351
+ src_mask = get_mask(
352
+ T,
353
+ T,
354
+ self.mask_type,
355
+ self.sparse_attn_window,
356
+ self.global_window,
357
+ self.mask_random_seed,
358
+ self.sparsity,
359
+ device,
360
+ )
361
+ self.__setattr__("src_mask", src_mask)
362
+
363
+ if self.norm_first:
364
+ x = x + self.gamma_1(
365
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
366
+ )
367
+ x = x + self.gamma_2(self._ff_block(self.norm2(x)))
368
+
369
+ if self.norm_out:
370
+ x = self.norm_out(x)
371
+ else:
372
+ x = self.norm1(
373
+ x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
374
+ )
375
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
376
+
377
+ return x
378
+
379
+
380
+ class CrossTransformerEncoderLayer(nn.Module):
381
+ def __init__(
382
+ self,
383
+ d_model: int,
384
+ nhead: int,
385
+ dim_feedforward: int = 2048,
386
+ dropout: float = 0.1,
387
+ activation=F.relu,
388
+ layer_norm_eps: float = 1e-5,
389
+ layer_scale: bool = False,
390
+ init_values: float = 1e-4,
391
+ norm_first: bool = False,
392
+ group_norm: bool = False,
393
+ norm_out: bool = False,
394
+ sparse=False,
395
+ mask_type="diag",
396
+ mask_random_seed=42,
397
+ sparse_attn_window=500,
398
+ global_window=50,
399
+ sparsity=0.95,
400
+ auto_sparsity=None,
401
+ device=None,
402
+ dtype=None,
403
+ batch_first=False,
404
+ ):
405
+ factory_kwargs = {"device": device, "dtype": dtype}
406
+ super().__init__()
407
+
408
+ self.sparse = sparse
409
+ self.auto_sparsity = auto_sparsity
410
+ if sparse:
411
+ if not auto_sparsity:
412
+ self.mask_type = mask_type
413
+ self.sparse_attn_window = sparse_attn_window
414
+ self.global_window = global_window
415
+ self.sparsity = sparsity
416
+
417
+ self.cross_attn: nn.Module
418
+ self.cross_attn = nn.MultiheadAttention(
419
+ d_model, nhead, dropout=dropout, batch_first=batch_first)
420
+ # Implementation of Feedforward model
421
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
422
+ self.dropout = nn.Dropout(dropout)
423
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
424
+
425
+ self.norm_first = norm_first
426
+ self.norm1: nn.Module
427
+ self.norm2: nn.Module
428
+ self.norm3: nn.Module
429
+ if group_norm:
430
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
431
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
432
+ self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
435
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
436
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
437
+
438
+ self.norm_out = None
439
+ if self.norm_first & norm_out:
440
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
441
+
442
+ self.gamma_1 = (
443
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
444
+ )
445
+ self.gamma_2 = (
446
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
447
+ )
448
+
449
+ self.dropout1 = nn.Dropout(dropout)
450
+ self.dropout2 = nn.Dropout(dropout)
451
+
452
+ # Legacy string support for activation function.
453
+ if isinstance(activation, str):
454
+ self.activation = self._get_activation_fn(activation)
455
+ else:
456
+ self.activation = activation
457
+
458
+ if sparse:
459
+ self.cross_attn = MultiheadAttention(
460
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
461
+ auto_sparsity=sparsity if auto_sparsity else 0)
462
+ if not auto_sparsity:
463
+ self.__setattr__("mask", torch.zeros(1, 1))
464
+ self.mask_random_seed = mask_random_seed
465
+
466
+ def forward(self, q, k, mask=None):
467
+ """
468
+ Args:
469
+ q: tensor of shape (T, B, C)
470
+ k: tensor of shape (S, B, C)
471
+ mask: tensor of shape (T, S)
472
+
473
+ """
474
+ device = q.device
475
+ T, B, C = q.shape
476
+ S, B, C = k.shape
477
+ if self.sparse and not self.auto_sparsity:
478
+ assert mask is None
479
+ mask = self.mask
480
+ if mask.shape[-1] != S or mask.shape[-2] != T:
481
+ mask = get_mask(
482
+ S,
483
+ T,
484
+ self.mask_type,
485
+ self.sparse_attn_window,
486
+ self.global_window,
487
+ self.mask_random_seed,
488
+ self.sparsity,
489
+ device,
490
+ )
491
+ self.__setattr__("mask", mask)
492
+
493
+ if self.norm_first:
494
+ x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
495
+ x = x + self.gamma_2(self._ff_block(self.norm3(x)))
496
+ if self.norm_out:
497
+ x = self.norm_out(x)
498
+ else:
499
+ x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
500
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
501
+
502
+ return x
503
+
504
+ # self-attention block
505
+ def _ca_block(self, q, k, attn_mask=None):
506
+ x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
507
+ return self.dropout1(x)
508
+
509
+ # feed forward block
510
+ def _ff_block(self, x):
511
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
512
+ return self.dropout2(x)
513
+
514
+ def _get_activation_fn(self, activation):
515
+ if activation == "relu":
516
+ return F.relu
517
+ elif activation == "gelu":
518
+ return F.gelu
519
+
520
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
521
+
522
+
523
+ # ----------------- MULTI-BLOCKS MODELS: -----------------------
524
+
525
+
526
+ class CrossTransformerEncoder(nn.Module):
527
+ def __init__(
528
+ self,
529
+ dim: int,
530
+ emb: str = "sin",
531
+ hidden_scale: float = 4.0,
532
+ num_heads: int = 8,
533
+ num_layers: int = 6,
534
+ cross_first: bool = False,
535
+ dropout: float = 0.0,
536
+ max_positions: int = 1000,
537
+ norm_in: bool = True,
538
+ norm_in_group: bool = False,
539
+ group_norm: int = False,
540
+ norm_first: bool = False,
541
+ norm_out: bool = False,
542
+ max_period: float = 10000.0,
543
+ weight_decay: float = 0.0,
544
+ lr: tp.Optional[float] = None,
545
+ layer_scale: bool = False,
546
+ gelu: bool = True,
547
+ sin_random_shift: int = 0,
548
+ weight_pos_embed: float = 1.0,
549
+ cape_mean_normalize: bool = True,
550
+ cape_augment: bool = True,
551
+ cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
552
+ sparse_self_attn: bool = False,
553
+ sparse_cross_attn: bool = False,
554
+ mask_type: str = "diag",
555
+ mask_random_seed: int = 42,
556
+ sparse_attn_window: int = 500,
557
+ global_window: int = 50,
558
+ auto_sparsity: bool = False,
559
+ sparsity: float = 0.95,
560
+ ):
561
+ super().__init__()
562
+ """
563
+ """
564
+ assert dim % num_heads == 0
565
+
566
+ hidden_dim = int(dim * hidden_scale)
567
+
568
+ self.num_layers = num_layers
569
+ # classic parity = 1 means that if idx%2 == 1 there is a
570
+ # classical encoder else there is a cross encoder
571
+ self.classic_parity = 1 if cross_first else 0
572
+ self.emb = emb
573
+ self.max_period = max_period
574
+ self.weight_decay = weight_decay
575
+ self.weight_pos_embed = weight_pos_embed
576
+ self.sin_random_shift = sin_random_shift
577
+ if emb == "cape":
578
+ self.cape_mean_normalize = cape_mean_normalize
579
+ self.cape_augment = cape_augment
580
+ self.cape_glob_loc_scale = cape_glob_loc_scale
581
+ if emb == "scaled":
582
+ self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
583
+
584
+ self.lr = lr
585
+
586
+ activation: tp.Any = F.gelu if gelu else F.relu
587
+
588
+ self.norm_in: nn.Module
589
+ self.norm_in_t: nn.Module
590
+ if norm_in:
591
+ self.norm_in = nn.LayerNorm(dim)
592
+ self.norm_in_t = nn.LayerNorm(dim)
593
+ elif norm_in_group:
594
+ self.norm_in = MyGroupNorm(int(norm_in_group), dim)
595
+ self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
596
+ else:
597
+ self.norm_in = nn.Identity()
598
+ self.norm_in_t = nn.Identity()
599
+
600
+ # spectrogram layers
601
+ self.layers = nn.ModuleList()
602
+ # temporal layers
603
+ self.layers_t = nn.ModuleList()
604
+
605
+ kwargs_common = {
606
+ "d_model": dim,
607
+ "nhead": num_heads,
608
+ "dim_feedforward": hidden_dim,
609
+ "dropout": dropout,
610
+ "activation": activation,
611
+ "group_norm": group_norm,
612
+ "norm_first": norm_first,
613
+ "norm_out": norm_out,
614
+ "layer_scale": layer_scale,
615
+ "mask_type": mask_type,
616
+ "mask_random_seed": mask_random_seed,
617
+ "sparse_attn_window": sparse_attn_window,
618
+ "global_window": global_window,
619
+ "sparsity": sparsity,
620
+ "auto_sparsity": auto_sparsity,
621
+ "batch_first": True,
622
+ }
623
+
624
+ kwargs_classic_encoder = dict(kwargs_common)
625
+ kwargs_classic_encoder.update({
626
+ "sparse": sparse_self_attn,
627
+ })
628
+ kwargs_cross_encoder = dict(kwargs_common)
629
+ kwargs_cross_encoder.update({
630
+ "sparse": sparse_cross_attn,
631
+ })
632
+
633
+ for idx in range(num_layers):
634
+ if idx % 2 == self.classic_parity:
635
+
636
+ self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
637
+ self.layers_t.append(
638
+ MyTransformerEncoderLayer(**kwargs_classic_encoder)
639
+ )
640
+
641
+ else:
642
+ self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
643
+
644
+ self.layers_t.append(
645
+ CrossTransformerEncoderLayer(**kwargs_cross_encoder)
646
+ )
647
+
648
+ def forward(self, x, xt):
649
+ B, C, Fr, T1 = x.shape
650
+ pos_emb_2d = create_2d_sin_embedding(
651
+ C, Fr, T1, x.device, self.max_period
652
+ ) # (1, C, Fr, T1)
653
+ pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
654
+ x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
655
+ x = self.norm_in(x)
656
+ x = x + self.weight_pos_embed * pos_emb_2d
657
+
658
+ B, C, T2 = xt.shape
659
+ xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C
660
+ pos_emb = self._get_pos_embedding(T2, B, C, x.device)
661
+ pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
662
+ xt = self.norm_in_t(xt)
663
+ xt = xt + self.weight_pos_embed * pos_emb
664
+
665
+ for idx in range(self.num_layers):
666
+ if idx % 2 == self.classic_parity:
667
+ x = self.layers[idx](x)
668
+ xt = self.layers_t[idx](xt)
669
+ else:
670
+ old_x = x
671
+ x = self.layers[idx](x, xt)
672
+ xt = self.layers_t[idx](xt, old_x)
673
+
674
+ x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
675
+ xt = rearrange(xt, "b t2 c -> b c t2")
676
+ return x, xt
677
+
678
+ def _get_pos_embedding(self, T, B, C, device):
679
+ if self.emb == "sin":
680
+ shift = random.randrange(self.sin_random_shift + 1)
681
+ pos_emb = create_sin_embedding(
682
+ T, C, shift=shift, device=device, max_period=self.max_period
683
+ )
684
+ elif self.emb == "cape":
685
+ if self.training:
686
+ pos_emb = create_sin_embedding_cape(
687
+ T,
688
+ C,
689
+ B,
690
+ device=device,
691
+ max_period=self.max_period,
692
+ mean_normalize=self.cape_mean_normalize,
693
+ augment=self.cape_augment,
694
+ max_global_shift=self.cape_glob_loc_scale[0],
695
+ max_local_shift=self.cape_glob_loc_scale[1],
696
+ max_scale=self.cape_glob_loc_scale[2],
697
+ )
698
+ else:
699
+ pos_emb = create_sin_embedding_cape(
700
+ T,
701
+ C,
702
+ B,
703
+ device=device,
704
+ max_period=self.max_period,
705
+ mean_normalize=self.cape_mean_normalize,
706
+ augment=False,
707
+ )
708
+
709
+ elif self.emb == "scaled":
710
+ pos = torch.arange(T, device=device)
711
+ pos_emb = self.position_embeddings(pos)[:, None]
712
+
713
+ return pos_emb
714
+
715
+ def make_optim_group(self):
716
+ group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
717
+ if self.lr is not None:
718
+ group["lr"] = self.lr
719
+ return group
720
+
721
+
722
+ # Attention Modules
723
+
724
+
725
+ class MultiheadAttention(nn.Module):
726
+ def __init__(
727
+ self,
728
+ embed_dim,
729
+ num_heads,
730
+ dropout=0.0,
731
+ bias=True,
732
+ add_bias_kv=False,
733
+ add_zero_attn=False,
734
+ kdim=None,
735
+ vdim=None,
736
+ batch_first=False,
737
+ auto_sparsity=None,
738
+ ):
739
+ super().__init__()
740
+ assert auto_sparsity is not None, "sanity check"
741
+ self.num_heads = num_heads
742
+ self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
743
+ self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
744
+ self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
745
+ self.attn_drop = torch.nn.Dropout(dropout)
746
+ self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
747
+ self.proj_drop = torch.nn.Dropout(dropout)
748
+ self.batch_first = batch_first
749
+ self.auto_sparsity = auto_sparsity
750
+
751
+ def forward(
752
+ self,
753
+ query,
754
+ key,
755
+ value,
756
+ key_padding_mask=None,
757
+ need_weights=True,
758
+ attn_mask=None,
759
+ average_attn_weights=True,
760
+ ):
761
+
762
+ if not self.batch_first: # N, B, C
763
+ query = query.permute(1, 0, 2) # B, N_q, C
764
+ key = key.permute(1, 0, 2) # B, N_k, C
765
+ value = value.permute(1, 0, 2) # B, N_k, C
766
+ B, N_q, C = query.shape
767
+ B, N_k, C = key.shape
768
+
769
+ q = (
770
+ self.q(query)
771
+ .reshape(B, N_q, self.num_heads, C // self.num_heads)
772
+ .permute(0, 2, 1, 3)
773
+ )
774
+ q = q.flatten(0, 1)
775
+ k = (
776
+ self.k(key)
777
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
778
+ .permute(0, 2, 1, 3)
779
+ )
780
+ k = k.flatten(0, 1)
781
+ v = (
782
+ self.v(value)
783
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
784
+ .permute(0, 2, 1, 3)
785
+ )
786
+ v = v.flatten(0, 1)
787
+
788
+ if self.auto_sparsity:
789
+ assert attn_mask is None
790
+ x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
791
+ else:
792
+ x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
793
+ x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)
794
+
795
+ x = x.transpose(1, 2).reshape(B, N_q, C)
796
+ x = self.proj(x)
797
+ x = self.proj_drop(x)
798
+ if not self.batch_first:
799
+ x = x.permute(1, 0, 2)
800
+ return x, None
801
+
802
+
803
+ def scaled_query_key_softmax(q, k, att_mask):
804
+ from xformers.ops import masked_matmul
805
+ q = q / (k.size(-1)) ** 0.5
806
+ att = masked_matmul(q, k.transpose(-2, -1), att_mask)
807
+ att = torch.nn.functional.softmax(att, -1)
808
+ return att
809
+
810
+
811
+ def scaled_dot_product_attention(q, k, v, att_mask, dropout):
812
+ att = scaled_query_key_softmax(q, k, att_mask=att_mask)
813
+ att = dropout(att)
814
+ y = att @ v
815
+ return y
816
+
817
+
818
+ def _compute_buckets(x, R):
819
+ qq = torch.einsum('btf,bfhi->bhti', x, R)
820
+ qq = torch.cat([qq, -qq], dim=-1)
821
+ buckets = qq.argmax(dim=-1)
822
+
823
+ return buckets.permute(0, 2, 1).byte().contiguous()
824
+
825
+
826
+ def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
827
+ # assert False, "The code for the custom sparse kernel is not ready for release yet."
828
+ from xformers.ops import find_locations, sparse_memory_efficient_attention
829
+ n_hashes = 32
830
+ proj_size = 4
831
+ query, key, value = [x.contiguous() for x in [query, key, value]]
832
+ with torch.no_grad():
833
+ R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
834
+ bucket_query = _compute_buckets(query, R)
835
+ bucket_key = _compute_buckets(key, R)
836
+ row_offsets, column_indices = find_locations(
837
+ bucket_query, bucket_key, sparsity, infer_sparsity)
838
+ return sparse_memory_efficient_attention(
839
+ query, key, value, row_offsets, column_indices, attn_bias)
demucs/utils.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from contextlib import contextmanager
9
+ import math
10
+ import os
11
+ import tempfile
12
+ import typing as tp
13
+
14
+ import errno
15
+ import functools
16
+ import hashlib
17
+ import inspect
18
+ import io
19
+ import os
20
+ import random
21
+ import socket
22
+ import tempfile
23
+ import warnings
24
+ import zlib
25
+ import tkinter as tk
26
+
27
+ from diffq import UniformQuantizer, DiffQuantizer
28
+ import torch as th
29
+ import tqdm
30
+ from torch import distributed
31
+ from torch.nn import functional as F
32
+
33
+ import torch
34
+
35
+ def unfold(a, kernel_size, stride):
36
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
37
+ with K the kernel size, by extracting frames with the given stride.
38
+
39
+ This will pad the input so that `F = ceil(T / K)`.
40
+
41
+ see https://github.com/pytorch/pytorch/issues/60466
42
+ """
43
+ *shape, length = a.shape
44
+ n_frames = math.ceil(length / stride)
45
+ tgt_length = (n_frames - 1) * stride + kernel_size
46
+ a = F.pad(a, (0, tgt_length - length))
47
+ strides = list(a.stride())
48
+ assert strides[-1] == 1, 'data should be contiguous'
49
+ strides = strides[:-1] + [stride, 1]
50
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
51
+
52
+
53
+ def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
54
+ """
55
+ Center trim `tensor` with respect to `reference`, along the last dimension.
56
+ `reference` can also be a number, representing the length to trim to.
57
+ If the size difference != 0 mod 2, the extra sample is removed on the right side.
58
+ """
59
+ ref_size: int
60
+ if isinstance(reference, torch.Tensor):
61
+ ref_size = reference.size(-1)
62
+ else:
63
+ ref_size = reference
64
+ delta = tensor.size(-1) - ref_size
65
+ if delta < 0:
66
+ raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
67
+ if delta:
68
+ tensor = tensor[..., delta // 2:-(delta - delta // 2)]
69
+ return tensor
70
+
71
+
72
+ def pull_metric(history: tp.List[dict], name: str):
73
+ out = []
74
+ for metrics in history:
75
+ metric = metrics
76
+ for part in name.split("."):
77
+ metric = metric[part]
78
+ out.append(metric)
79
+ return out
80
+
81
+
82
+ def EMA(beta: float = 1):
83
+ """
84
+ Exponential Moving Average callback.
85
+ Returns a single function that can be called to repeatidly update the EMA
86
+ with a dict of metrics. The callback will return
87
+ the new averaged dict of metrics.
88
+
89
+ Note that for `beta=1`, this is just plain averaging.
90
+ """
91
+ fix: tp.Dict[str, float] = defaultdict(float)
92
+ total: tp.Dict[str, float] = defaultdict(float)
93
+
94
+ def _update(metrics: dict, weight: float = 1) -> dict:
95
+ nonlocal total, fix
96
+ for key, value in metrics.items():
97
+ total[key] = total[key] * beta + weight * float(value)
98
+ fix[key] = fix[key] * beta + weight
99
+ return {key: tot / fix[key] for key, tot in total.items()}
100
+ return _update
101
+
102
+
103
+ def sizeof_fmt(num: float, suffix: str = 'B'):
104
+ """
105
+ Given `num` bytes, return human readable size.
106
+ Taken from https://stackoverflow.com/a/1094933
107
+ """
108
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
109
+ if abs(num) < 1024.0:
110
+ return "%3.1f%s%s" % (num, unit, suffix)
111
+ num /= 1024.0
112
+ return "%.1f%s%s" % (num, 'Yi', suffix)
113
+
114
+
115
+ @contextmanager
116
+ def temp_filenames(count: int, delete=True):
117
+ names = []
118
+ try:
119
+ for _ in range(count):
120
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
121
+ yield names
122
+ finally:
123
+ if delete:
124
+ for name in names:
125
+ os.unlink(name)
126
+
127
+ def average_metric(metric, count=1.):
128
+ """
129
+ Average `metric` which should be a float across all hosts. `count` should be
130
+ the weight for this particular host (i.e. number of examples).
131
+ """
132
+ metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
133
+ distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
134
+ return metric[1].item() / metric[0].item()
135
+
136
+
137
+ def free_port(host='', low=20000, high=40000):
138
+ """
139
+ Return a port number that is most likely free.
140
+ This could suffer from a race condition although
141
+ it should be quite rare.
142
+ """
143
+ sock = socket.socket()
144
+ while True:
145
+ port = random.randint(low, high)
146
+ try:
147
+ sock.bind((host, port))
148
+ except OSError as error:
149
+ if error.errno == errno.EADDRINUSE:
150
+ continue
151
+ raise
152
+ return port
153
+
154
+
155
+ def sizeof_fmt(num, suffix='B'):
156
+ """
157
+ Given `num` bytes, return human readable size.
158
+ Taken from https://stackoverflow.com/a/1094933
159
+ """
160
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
161
+ if abs(num) < 1024.0:
162
+ return "%3.1f%s%s" % (num, unit, suffix)
163
+ num /= 1024.0
164
+ return "%.1f%s%s" % (num, 'Yi', suffix)
165
+
166
+
167
+ def human_seconds(seconds, display='.2f'):
168
+ """
169
+ Given `seconds` seconds, return human readable duration.
170
+ """
171
+ value = seconds * 1e6
172
+ ratios = [1e3, 1e3, 60, 60, 24]
173
+ names = ['us', 'ms', 's', 'min', 'hrs', 'days']
174
+ last = names.pop(0)
175
+ for name, ratio in zip(names, ratios):
176
+ if value / ratio < 0.3:
177
+ break
178
+ value /= ratio
179
+ last = name
180
+ return f"{format(value, display)} {last}"
181
+
182
+
183
+ class TensorChunk:
184
+ def __init__(self, tensor, offset=0, length=None):
185
+ total_length = tensor.shape[-1]
186
+ assert offset >= 0
187
+ assert offset < total_length
188
+
189
+ if length is None:
190
+ length = total_length - offset
191
+ else:
192
+ length = min(total_length - offset, length)
193
+
194
+ self.tensor = tensor
195
+ self.offset = offset
196
+ self.length = length
197
+ self.device = tensor.device
198
+
199
+ @property
200
+ def shape(self):
201
+ shape = list(self.tensor.shape)
202
+ shape[-1] = self.length
203
+ return shape
204
+
205
+ def padded(self, target_length):
206
+ delta = target_length - self.length
207
+ total_length = self.tensor.shape[-1]
208
+ assert delta >= 0
209
+
210
+ start = self.offset - delta // 2
211
+ end = start + target_length
212
+
213
+ correct_start = max(0, start)
214
+ correct_end = min(total_length, end)
215
+
216
+ pad_left = correct_start - start
217
+ pad_right = end - correct_end
218
+
219
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
220
+ assert out.shape[-1] == target_length
221
+ return out
222
+
223
+
224
+ def tensor_chunk(tensor_or_chunk):
225
+ if isinstance(tensor_or_chunk, TensorChunk):
226
+ return tensor_or_chunk
227
+ else:
228
+ assert isinstance(tensor_or_chunk, th.Tensor)
229
+ return TensorChunk(tensor_or_chunk)
230
+
231
+
232
+ def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
233
+ """
234
+ Apply model to a given mixture.
235
+
236
+ Args:
237
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
238
+ and apply the oppositve shift to the output. This is repeated `shifts` time and
239
+ all predictions are averaged. This effectively makes the model time equivariant
240
+ and improves SDR by up to 0.2 points.
241
+ split (bool): if True, the input will be broken down in 8 seconds extracts
242
+ and predictions will be performed individually on each and concatenated.
243
+ Useful for model with large memory footprint like Tasnet.
244
+ progress (bool): if True, show a progress bar (requires split=True)
245
+ """
246
+
247
+ channels, length = mix.size()
248
+ device = mix.device
249
+ progress_value = 0
250
+
251
+ if split:
252
+ out = th.zeros(4, channels, length, device=device)
253
+ shift = model.samplerate * 10
254
+ offsets = range(0, length, shift)
255
+ scale = 10
256
+ if progress:
257
+ offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
258
+ for offset in offsets:
259
+ chunk = mix[..., offset:offset + shift]
260
+ if set_progress_bar:
261
+ progress_value += 1
262
+ set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
263
+ chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
264
+ else:
265
+ chunk_out = apply_model_v1(model, chunk, shifts=shifts)
266
+ out[..., offset:offset + shift] = chunk_out
267
+ offset += shift
268
+ return out
269
+ elif shifts:
270
+ max_shift = int(model.samplerate / 2)
271
+ mix = F.pad(mix, (max_shift, max_shift))
272
+ offsets = list(range(max_shift))
273
+ random.shuffle(offsets)
274
+ out = 0
275
+ for offset in offsets[:shifts]:
276
+ shifted = mix[..., offset:offset + length + max_shift]
277
+ if set_progress_bar:
278
+ shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
279
+ else:
280
+ shifted_out = apply_model_v1(model, shifted)
281
+ out += shifted_out[..., max_shift - offset:max_shift - offset + length]
282
+ out /= shifts
283
+ return out
284
+ else:
285
+ valid_length = model.valid_length(length)
286
+ delta = valid_length - length
287
+ padded = F.pad(mix, (delta // 2, delta - delta // 2))
288
+ with th.no_grad():
289
+ out = model(padded.unsqueeze(0))[0]
290
+ return center_trim(out, mix)
291
+
292
+ def apply_model_v2(model, mix, shifts=None, split=False,
293
+ overlap=0.25, transition_power=1., progress=False, set_progress_bar=None):
294
+ """
295
+ Apply model to a given mixture.
296
+
297
+ Args:
298
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
299
+ and apply the oppositve shift to the output. This is repeated `shifts` time and
300
+ all predictions are averaged. This effectively makes the model time equivariant
301
+ and improves SDR by up to 0.2 points.
302
+ split (bool): if True, the input will be broken down in 8 seconds extracts
303
+ and predictions will be performed individually on each and concatenated.
304
+ Useful for model with large memory footprint like Tasnet.
305
+ progress (bool): if True, show a progress bar (requires split=True)
306
+ """
307
+
308
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
309
+ device = mix.device
310
+ channels, length = mix.shape
311
+ progress_value = 0
312
+
313
+ if split:
314
+ out = th.zeros(len(model.sources), channels, length, device=device)
315
+ sum_weight = th.zeros(length, device=device)
316
+ segment = model.segment_length
317
+ stride = int((1 - overlap) * segment)
318
+ offsets = range(0, length, stride)
319
+ scale = stride / model.samplerate
320
+ if progress:
321
+ offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
322
+ # We start from a triangle shaped weight, with maximal weight in the middle
323
+ # of the segment. Then we normalize and take to the power `transition_power`.
324
+ # Large values of transition power will lead to sharper transitions.
325
+ weight = th.cat([th.arange(1, segment // 2 + 1),
326
+ th.arange(segment - segment // 2, 0, -1)]).to(device)
327
+ assert len(weight) == segment
328
+ # If the overlap < 50%, this will translate to linear transition when
329
+ # transition_power is 1.
330
+ weight = (weight / weight.max())**transition_power
331
+ for offset in offsets:
332
+ chunk = TensorChunk(mix, offset, segment)
333
+ if set_progress_bar:
334
+ progress_value += 1
335
+ set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
336
+ chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
337
+ else:
338
+ chunk_out = apply_model_v2(model, chunk, shifts=shifts)
339
+ chunk_length = chunk_out.shape[-1]
340
+ out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
341
+ sum_weight[offset:offset + segment] += weight[:chunk_length]
342
+ offset += segment
343
+ assert sum_weight.min() > 0
344
+ out /= sum_weight
345
+ return out
346
+ elif shifts:
347
+ max_shift = int(0.5 * model.samplerate)
348
+ mix = tensor_chunk(mix)
349
+ padded_mix = mix.padded(length + 2 * max_shift)
350
+ out = 0
351
+ for _ in range(shifts):
352
+ offset = random.randint(0, max_shift)
353
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
354
+
355
+ if set_progress_bar:
356
+ progress_value += 1
357
+ shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
358
+ else:
359
+ shifted_out = apply_model_v2(model, shifted)
360
+ out += shifted_out[..., max_shift - offset:]
361
+ out /= shifts
362
+ return out
363
+ else:
364
+ valid_length = model.valid_length(length)
365
+ mix = tensor_chunk(mix)
366
+ padded_mix = mix.padded(valid_length)
367
+ with th.no_grad():
368
+ out = model(padded_mix.unsqueeze(0))[0]
369
+ return center_trim(out, length)
370
+
371
+
372
+ @contextmanager
373
+ def temp_filenames(count, delete=True):
374
+ names = []
375
+ try:
376
+ for _ in range(count):
377
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
378
+ yield names
379
+ finally:
380
+ if delete:
381
+ for name in names:
382
+ os.unlink(name)
383
+
384
+
385
+ def get_quantizer(model, args, optimizer=None):
386
+ quantizer = None
387
+ if args.diffq:
388
+ quantizer = DiffQuantizer(
389
+ model, min_size=args.q_min_size, group_size=8)
390
+ if optimizer is not None:
391
+ quantizer.setup_optimizer(optimizer)
392
+ elif args.qat:
393
+ quantizer = UniformQuantizer(
394
+ model, bits=args.qat, min_size=args.q_min_size)
395
+ return quantizer
396
+
397
+
398
+ def load_model(path, strict=False):
399
+ with warnings.catch_warnings():
400
+ warnings.simplefilter("ignore")
401
+ load_from = path
402
+ package = th.load(load_from, 'cpu')
403
+
404
+ klass = package["klass"]
405
+ args = package["args"]
406
+ kwargs = package["kwargs"]
407
+
408
+ if strict:
409
+ model = klass(*args, **kwargs)
410
+ else:
411
+ sig = inspect.signature(klass)
412
+ for key in list(kwargs):
413
+ if key not in sig.parameters:
414
+ warnings.warn("Dropping inexistant parameter " + key)
415
+ del kwargs[key]
416
+ model = klass(*args, **kwargs)
417
+
418
+ state = package["state"]
419
+ training_args = package["training_args"]
420
+ quantizer = get_quantizer(model, training_args)
421
+
422
+ set_state(model, quantizer, state)
423
+ return model
424
+
425
+
426
+ def get_state(model, quantizer):
427
+ if quantizer is None:
428
+ state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
429
+ else:
430
+ state = quantizer.get_quantized_state()
431
+ buf = io.BytesIO()
432
+ th.save(state, buf)
433
+ state = {'compressed': zlib.compress(buf.getvalue())}
434
+ return state
435
+
436
+
437
+ def set_state(model, quantizer, state):
438
+ if quantizer is None:
439
+ model.load_state_dict(state)
440
+ else:
441
+ buf = io.BytesIO(zlib.decompress(state["compressed"]))
442
+ state = th.load(buf, "cpu")
443
+ quantizer.restore_quantized_state(state)
444
+
445
+ return state
446
+
447
+
448
+ def save_state(state, path):
449
+ buf = io.BytesIO()
450
+ th.save(state, buf)
451
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
452
+
453
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
454
+ path.write_bytes(buf.getvalue())
455
+
456
+
457
+ def save_model(model, quantizer, training_args, path):
458
+ args, kwargs = model._init_args_kwargs
459
+ klass = model.__class__
460
+
461
+ state = get_state(model, quantizer)
462
+
463
+ save_to = path
464
+ package = {
465
+ 'klass': klass,
466
+ 'args': args,
467
+ 'kwargs': kwargs,
468
+ 'state': state,
469
+ 'training_args': training_args,
470
+ }
471
+ th.save(package, save_to)
472
+
473
+
474
+ def capture_init(init):
475
+ @functools.wraps(init)
476
+ def __init__(self, *args, **kwargs):
477
+ self._init_args_kwargs = (args, kwargs)
478
+ init(self, *args, **kwargs)
479
+
480
+ return __init__
481
+
482
+ class DummyPoolExecutor:
483
+ class DummyResult:
484
+ def __init__(self, func, *args, **kwargs):
485
+ self.func = func
486
+ self.args = args
487
+ self.kwargs = kwargs
488
+
489
+ def result(self):
490
+ return self.func(*self.args, **self.kwargs)
491
+
492
+ def __init__(self, workers=0):
493
+ pass
494
+
495
+ def submit(self, func, *args, **kwargs):
496
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
497
+
498
+ def __enter__(self):
499
+ return self
500
+
501
+ def __exit__(self, exc_type, exc_value, exc_tb):
502
+ return
gui_data/constants.py ADDED
@@ -0,0 +1,1147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+
3
+ #Platform Details
4
+ OPERATING_SYSTEM = platform.system()
5
+ SYSTEM_ARCH = platform.platform()
6
+ SYSTEM_PROC = platform.processor()
7
+ ARM = 'arm'
8
+
9
+ #Main Font
10
+ MAIN_FONT_NAME = "Century Gothic"
11
+
12
+ #Model Types
13
+ VR_ARCH_TYPE = 'VR Arc'
14
+ MDX_ARCH_TYPE = 'MDX-Net'
15
+ DEMUCS_ARCH_TYPE = 'Demucs'
16
+ VR_ARCH_PM = 'VR Architecture'
17
+ ENSEMBLE_MODE = 'Ensemble Mode'
18
+ ENSEMBLE_STEM_CHECK = 'Ensemble Stem'
19
+ SECONDARY_MODEL = 'Secondary Model'
20
+ DEMUCS_6_STEM_MODEL = 'htdemucs_6s'
21
+
22
+ DEMUCS_V3_ARCH_TYPE = 'Demucs v3'
23
+ DEMUCS_V4_ARCH_TYPE = 'Demucs v4'
24
+ DEMUCS_NEWER_ARCH_TYPES = [DEMUCS_V3_ARCH_TYPE, DEMUCS_V4_ARCH_TYPE]
25
+
26
+ DEMUCS_V1 = 'v1'
27
+ DEMUCS_V2 = 'v2'
28
+ DEMUCS_V3 = 'v3'
29
+ DEMUCS_V4 = 'v4'
30
+
31
+ DEMUCS_V1_TAG = 'v1 | '
32
+ DEMUCS_V2_TAG = 'v2 | '
33
+ DEMUCS_V3_TAG = 'v3 | '
34
+ DEMUCS_V4_TAG = 'v4 | '
35
+ DEMUCS_NEWER_TAGS = [DEMUCS_V3_TAG, DEMUCS_V4_TAG]
36
+
37
+ DEMUCS_VERSION_MAPPER = {
38
+ DEMUCS_V1:DEMUCS_V1_TAG,
39
+ DEMUCS_V2:DEMUCS_V2_TAG,
40
+ DEMUCS_V3:DEMUCS_V3_TAG,
41
+ DEMUCS_V4:DEMUCS_V4_TAG}
42
+
43
+ #Download Center
44
+ DOWNLOAD_FAILED = 'Download Failed'
45
+ DOWNLOAD_STOPPED = 'Download Stopped'
46
+ DOWNLOAD_COMPLETE = 'Download Complete'
47
+ DOWNLOAD_UPDATE_COMPLETE = 'Update Download Complete'
48
+ SETTINGS_MENU_EXIT = 'exit'
49
+ NO_CONNECTION = 'No Internet Connection'
50
+ VIP_SELECTION = 'VIP:'
51
+ DEVELOPER_SELECTION = 'VIP:'
52
+ NO_NEW_MODELS = 'All Available Models Downloaded'
53
+ ENSEMBLE_PARTITION = ': '
54
+ NO_MODEL = 'No Model Selected'
55
+ CHOOSE_MODEL = 'Choose Model'
56
+ SINGLE_DOWNLOAD = 'Downloading Item 1/1...'
57
+ DOWNLOADING_ITEM = 'Downloading Item'
58
+ FILE_EXISTS = 'File already exists!'
59
+ DOWNLOADING_UPDATE = 'Downloading Update...'
60
+ DOWNLOAD_MORE = 'Download More Models'
61
+
62
+ #Menu Options
63
+
64
+ AUTO_SELECT = 'Auto'
65
+
66
+ #LINKS
67
+ DOWNLOAD_CHECKS = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
68
+ MDX_MODEL_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data.json"
69
+ VR_MODEL_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/vr_model_data/model_data.json"
70
+
71
+ DEMUCS_MODEL_NAME_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/demucs_model_data/model_name_mapper.json"
72
+ MDX_MODEL_NAME_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_name_mapper.json"
73
+
74
+ DONATE_LINK_BMAC = "https://www.buymeacoffee.com/uvr5"
75
+ DONATE_LINK_PATREON = "https://www.patreon.com/uvr"
76
+
77
+ #DOWNLOAD REPOS
78
+ NORMAL_REPO = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"
79
+ UPDATE_REPO = "https://github.com/TRvlvr/model_repo/releases/download/uvr_update_patches/"
80
+
81
+ UPDATE_MAC_ARM_REPO = "https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.5.0/Ultimate_Vocal_Remover_v5_5_MacOS_arm64.dmg"
82
+ UPDATE_MAC_X86_64_REPO = "https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.5.0/Ultimate_Vocal_Remover_v5_5_MacOS_x86_64.dmg"
83
+ UPDATE_LINUX_REPO = "https://github.com/Anjok07/ultimatevocalremovergui#linux-installation"
84
+ UPDATE_REPO = "https://github.com/TRvlvr/model_repo/releases/download/uvr_update_patches/"
85
+
86
+ ISSUE_LINK = 'https://github.com/Anjok07/ultimatevocalremovergui/issues/new'
87
+ VIP_REPO = b'\xf3\xc2W\x19\x1foI)\xc2\xa9\xcc\xb67(Z\xf5',\
88
+ b'gAAAAABjQAIQ-NpNMMxMedpKHHb7ze_nqB05hw0YhbOy3pFzuzDrfqumn8_qvraxEoUpZC5ZXC0gGvfDxFMqyq9VWbYKlA67SUFI_wZB6QoVyGI581vs7kaGfUqlXHIdDS6tQ_U-BfjbEAK9EU_74-R2zXjz8Xzekw=='
89
+ NO_CODE = 'incorrect_code'
90
+
91
+ #Extensions
92
+
93
+ ONNX = '.onnx'
94
+ CKPT = '.ckpt'
95
+ YAML = '.yaml'
96
+ PTH = '.pth'
97
+ TH_EXT = '.th'
98
+ JSON = '.json'
99
+
100
+ #GUI Buttons
101
+
102
+ START_PROCESSING = 'Start Processing'
103
+ WAIT_PROCESSING = 'Please wait...'
104
+ STOP_PROCESSING = 'Halting process, please wait...'
105
+ LOADING_MODELS = 'Loading models...'
106
+
107
+ #---Messages and Logs----
108
+
109
+ MISSING_MODEL = 'missing'
110
+ MODEL_PRESENT = 'present'
111
+
112
+ UNRECOGNIZED_MODEL = 'Unrecognized Model Detected', ' is an unrecognized model.\n\n' + \
113
+ 'Would you like to select the correct parameters before continuing?'
114
+
115
+ STOP_PROCESS_CONFIRM = 'Confirmation', 'You are about to stop all active processes.\n\nAre you sure you wish to continue?'
116
+ NO_ENSEMBLE_SELECTED = 'No Models Selected', 'Please select ensemble and try again.'
117
+ PICKLE_CORRU = 'File Corrupted', 'Unable to load this ensemble.\n\n' + \
118
+ 'Would you like to remove this ensemble from your list?'
119
+ DELETE_ENS_ENTRY = 'Confirm Removal', 'Are you sure you want to remove this entry?'
120
+
121
+ ALL_STEMS = 'All Stems'
122
+ VOCAL_STEM = 'Vocals'
123
+ INST_STEM = 'Instrumental'
124
+ OTHER_STEM = 'Other'
125
+ BASS_STEM = 'Bass'
126
+ DRUM_STEM = 'Drums'
127
+ GUITAR_STEM = 'Guitar'
128
+ PIANO_STEM = 'Piano'
129
+ SYNTH_STEM = 'Synthesizer'
130
+ STRINGS_STEM = 'Strings'
131
+ WOODWINDS_STEM = 'Woodwinds'
132
+ BRASS_STEM = 'Brass'
133
+ WIND_INST_STEM = 'Wind Inst'
134
+ NO_OTHER_STEM = 'No Other'
135
+ NO_BASS_STEM = 'No Bass'
136
+ NO_DRUM_STEM = 'No Drums'
137
+ NO_GUITAR_STEM = 'No Guitar'
138
+ NO_PIANO_STEM = 'No Piano'
139
+ NO_SYNTH_STEM = 'No Synthesizer'
140
+ NO_STRINGS_STEM = 'No Strings'
141
+ NO_WOODWINDS_STEM = 'No Woodwinds'
142
+ NO_WIND_INST_STEM = 'No Wind Inst'
143
+ NO_BRASS_STEM = 'No Brass'
144
+ PRIMARY_STEM = 'Primary Stem'
145
+ SECONDARY_STEM = 'Secondary Stem'
146
+
147
+ #Other Constants
148
+ DEMUCS_2_SOURCE = ["instrumental", "vocals"]
149
+ DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"]
150
+
151
+ DEMUCS_2_SOURCE_MAPPER = {
152
+ INST_STEM: 0,
153
+ VOCAL_STEM: 1}
154
+
155
+ DEMUCS_4_SOURCE_MAPPER = {
156
+ BASS_STEM: 0,
157
+ DRUM_STEM: 1,
158
+ OTHER_STEM: 2,
159
+ VOCAL_STEM: 3}
160
+
161
+ DEMUCS_6_SOURCE_MAPPER = {
162
+ BASS_STEM: 0,
163
+ DRUM_STEM: 1,
164
+ OTHER_STEM: 2,
165
+ VOCAL_STEM: 3,
166
+ GUITAR_STEM:4,
167
+ PIANO_STEM:5}
168
+
169
+ DEMUCS_4_SOURCE_LIST = [BASS_STEM, DRUM_STEM, OTHER_STEM, VOCAL_STEM]
170
+ DEMUCS_6_SOURCE_LIST = [BASS_STEM, DRUM_STEM, OTHER_STEM, VOCAL_STEM, GUITAR_STEM, PIANO_STEM]
171
+
172
+ DEMUCS_UVR_MODEL = 'UVR_Model'
173
+
174
+ CHOOSE_STEM_PAIR = 'Choose Stem Pair'
175
+
176
+ STEM_SET_MENU = (VOCAL_STEM,
177
+ INST_STEM,
178
+ OTHER_STEM,
179
+ BASS_STEM,
180
+ DRUM_STEM,
181
+ GUITAR_STEM,
182
+ PIANO_STEM,
183
+ SYNTH_STEM,
184
+ STRINGS_STEM,
185
+ WOODWINDS_STEM,
186
+ BRASS_STEM,
187
+ WIND_INST_STEM,
188
+ NO_OTHER_STEM,
189
+ NO_BASS_STEM,
190
+ NO_DRUM_STEM,
191
+ NO_GUITAR_STEM,
192
+ NO_PIANO_STEM,
193
+ NO_SYNTH_STEM,
194
+ NO_STRINGS_STEM,
195
+ NO_WOODWINDS_STEM,
196
+ NO_BRASS_STEM,
197
+ NO_WIND_INST_STEM)
198
+
199
+ STEM_PAIR_MAPPER = {
200
+ VOCAL_STEM: INST_STEM,
201
+ INST_STEM: VOCAL_STEM,
202
+ OTHER_STEM: NO_OTHER_STEM,
203
+ BASS_STEM: NO_BASS_STEM,
204
+ DRUM_STEM: NO_DRUM_STEM,
205
+ GUITAR_STEM: NO_GUITAR_STEM,
206
+ PIANO_STEM: NO_PIANO_STEM,
207
+ SYNTH_STEM: NO_SYNTH_STEM,
208
+ STRINGS_STEM: NO_STRINGS_STEM,
209
+ WOODWINDS_STEM: NO_WOODWINDS_STEM,
210
+ BRASS_STEM: NO_BRASS_STEM,
211
+ WIND_INST_STEM: NO_WIND_INST_STEM,
212
+ NO_OTHER_STEM: OTHER_STEM,
213
+ NO_BASS_STEM: BASS_STEM,
214
+ NO_DRUM_STEM: DRUM_STEM,
215
+ NO_GUITAR_STEM: GUITAR_STEM,
216
+ NO_PIANO_STEM: PIANO_STEM,
217
+ NO_SYNTH_STEM: SYNTH_STEM,
218
+ NO_STRINGS_STEM: STRINGS_STEM,
219
+ NO_WOODWINDS_STEM: WOODWINDS_STEM,
220
+ NO_BRASS_STEM: BRASS_STEM,
221
+ NO_WIND_INST_STEM: WIND_INST_STEM,
222
+ PRIMARY_STEM: SECONDARY_STEM}
223
+
224
+ NON_ACCOM_STEMS = (
225
+ VOCAL_STEM,
226
+ OTHER_STEM,
227
+ BASS_STEM,
228
+ DRUM_STEM,
229
+ GUITAR_STEM,
230
+ PIANO_STEM,
231
+ SYNTH_STEM,
232
+ STRINGS_STEM,
233
+ WOODWINDS_STEM,
234
+ BRASS_STEM,
235
+ WIND_INST_STEM)
236
+
237
+ MDX_NET_FREQ_CUT = [VOCAL_STEM, INST_STEM]
238
+
239
+ DEMUCS_4_STEM_OPTIONS = (ALL_STEMS, VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM)
240
+ DEMUCS_6_STEM_OPTIONS = (ALL_STEMS, VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM)
241
+ DEMUCS_2_STEM_OPTIONS = (VOCAL_STEM, INST_STEM)
242
+ DEMUCS_4_STEM_CHECK = (OTHER_STEM, BASS_STEM, DRUM_STEM)
243
+
244
+ #Menu Dropdowns
245
+
246
+ VOCAL_PAIR = f'{VOCAL_STEM}/{INST_STEM}'
247
+ INST_PAIR = f'{INST_STEM}/{VOCAL_STEM}'
248
+ OTHER_PAIR = f'{OTHER_STEM}/{NO_OTHER_STEM}'
249
+ DRUM_PAIR = f'{DRUM_STEM}/{NO_DRUM_STEM}'
250
+ BASS_PAIR = f'{BASS_STEM}/{NO_BASS_STEM}'
251
+ FOUR_STEM_ENSEMBLE = '4 Stem Ensemble'
252
+
253
+ ENSEMBLE_MAIN_STEM = (CHOOSE_STEM_PAIR, VOCAL_PAIR, OTHER_PAIR, DRUM_PAIR, BASS_PAIR, FOUR_STEM_ENSEMBLE)
254
+
255
+ MIN_SPEC = 'Min Spec'
256
+ MAX_SPEC = 'Max Spec'
257
+ AUDIO_AVERAGE = 'Average'
258
+
259
+ MAX_MIN = f'{MAX_SPEC}/{MIN_SPEC}'
260
+ MAX_MAX = f'{MAX_SPEC}/{MAX_SPEC}'
261
+ MAX_AVE = f'{MAX_SPEC}/{AUDIO_AVERAGE}'
262
+ MIN_MAX = f'{MIN_SPEC}/{MAX_SPEC}'
263
+ MIN_MIX = f'{MIN_SPEC}/{MIN_SPEC}'
264
+ MIN_AVE = f'{MIN_SPEC}/{AUDIO_AVERAGE}'
265
+ AVE_MAX = f'{AUDIO_AVERAGE}/{MAX_SPEC}'
266
+ AVE_MIN = f'{AUDIO_AVERAGE}/{MIN_SPEC}'
267
+ AVE_AVE = f'{AUDIO_AVERAGE}/{AUDIO_AVERAGE}'
268
+
269
+ ENSEMBLE_TYPE = (MAX_MIN, MAX_MAX, MAX_AVE, MIN_MAX, MIN_MIX, MIN_AVE, AVE_MAX, AVE_MIN, AVE_AVE)
270
+ ENSEMBLE_TYPE_4_STEM = (MAX_SPEC, MIN_SPEC, AUDIO_AVERAGE)
271
+
272
+ BATCH_MODE = 'Batch Mode'
273
+ BETA_VERSION = 'BETA'
274
+ DEF_OPT = 'Default'
275
+
276
+ CHUNKS = (AUTO_SELECT, '1', '5', '10', '15', '20',
277
+ '25', '30', '35', '40', '45', '50',
278
+ '55', '60', '65', '70', '75', '80',
279
+ '85', '90', '95', 'Full')
280
+
281
+ BATCH_SIZE = (DEF_OPT, '2', '3', '4', '5',
282
+ '6', '7', '8', '9', '10')
283
+
284
+ VOL_COMPENSATION = (AUTO_SELECT, '1.035', '1.08')
285
+
286
+ MARGIN_SIZE = ('44100', '22050', '11025')
287
+
288
+ AUDIO_TOOLS = 'Audio Tools'
289
+
290
+ MANUAL_ENSEMBLE = 'Manual Ensemble'
291
+ TIME_STRETCH = 'Time Stretch'
292
+ CHANGE_PITCH = 'Change Pitch'
293
+ ALIGN_INPUTS = 'Align Inputs'
294
+
295
+ if OPERATING_SYSTEM == 'Windows' or OPERATING_SYSTEM == 'Darwin':
296
+ AUDIO_TOOL_OPTIONS = (MANUAL_ENSEMBLE, TIME_STRETCH, CHANGE_PITCH, ALIGN_INPUTS)
297
+ else:
298
+ AUDIO_TOOL_OPTIONS = (MANUAL_ENSEMBLE, ALIGN_INPUTS)
299
+
300
+ MANUAL_ENSEMBLE_OPTIONS = (MIN_SPEC, MAX_SPEC, AUDIO_AVERAGE)
301
+
302
+ PROCESS_METHODS = (VR_ARCH_PM, MDX_ARCH_TYPE, DEMUCS_ARCH_TYPE, ENSEMBLE_MODE, AUDIO_TOOLS)
303
+
304
+ DEMUCS_SEGMENTS = ('Default', '1', '5', '10', '15', '20',
305
+ '25', '30', '35', '40', '45', '50',
306
+ '55', '60', '65', '70', '75', '80',
307
+ '85', '90', '95', '100')
308
+
309
+ DEMUCS_SHIFTS = (0, 1, 2, 3, 4, 5,
310
+ 6, 7, 8, 9, 10, 11,
311
+ 12, 13, 14, 15, 16, 17,
312
+ 18, 19, 20)
313
+
314
+ DEMUCS_OVERLAP = (0.25, 0.50, 0.75, 0.99)
315
+
316
+ VR_AGGRESSION = (1, 2, 3, 4, 5,
317
+ 6, 7, 8, 9, 10, 11,
318
+ 12, 13, 14, 15, 16, 17,
319
+ 18, 19, 20)
320
+
321
+ VR_WINDOW = ('320', '512','1024')
322
+ VR_CROP = ('256', '512', '1024')
323
+ POST_PROCESSES_THREASHOLD_VALUES = ('0.1', '0.2', '0.3')
324
+
325
+ MDX_POP_PRO = ('MDX-NET_Noise_Profile_14_kHz', 'MDX-NET_Noise_Profile_17_kHz', 'MDX-NET_Noise_Profile_Full_Band')
326
+ MDX_POP_STEMS = ('Vocals', 'Instrumental', 'Other', 'Drums', 'Bass')
327
+ MDX_POP_NFFT = ('4096', '5120', '6144', '7680', '8192', '16384')
328
+ MDX_POP_DIMF = ('2048', '3072', '4096')
329
+
330
+ SAVE_ENSEMBLE = 'Save Ensemble'
331
+ CLEAR_ENSEMBLE = 'Clear Selection(s)'
332
+ MENU_SEPARATOR = 35*'•'
333
+ CHOOSE_ENSEMBLE_OPTION = 'Choose Option'
334
+
335
+ INVALID_ENTRY = 'Invalid Input, Please Try Again'
336
+ ENSEMBLE_INPUT_RULE = '1. Only letters, numbers, spaces, and dashes allowed.\n2. No dashes or spaces at the start or end of input.'
337
+
338
+ ENSEMBLE_OPTIONS = (SAVE_ENSEMBLE, CLEAR_ENSEMBLE)
339
+ ENSEMBLE_CHECK = 'ensemble check'
340
+
341
+ SELECT_SAVED_ENSEMBLE = 'Select Saved Ensemble'
342
+ SELECT_SAVED_SETTING = 'Select Saved Setting'
343
+ ENSEMBLE_OPTION = "Ensemble Customization Options"
344
+ MDX_OPTION = "Advanced MDX-Net Options"
345
+ DEMUCS_OPTION = "Advanced Demucs Options"
346
+ VR_OPTION = "Advanced VR Options"
347
+ HELP_OPTION = "Open Information Guide"
348
+ ERROR_OPTION = "Open Error Log"
349
+ VERIFY_BEGIN = 'Verifying file '
350
+ SAMPLE_BEGIN = 'Creating Sample '
351
+ MODEL_MISSING_CHECK = 'Model Missing:'
352
+
353
+ # Audio Player
354
+
355
+ PLAYING_SONG = ": Playing"
356
+ PAUSE_SONG = ": Paused"
357
+ STOP_SONG = ": Stopped"
358
+
359
+ SELECTED_VER = 'Selected'
360
+ DETECTED_VER = 'Detected'
361
+
362
+ SAMPLE_MODE_CHECKBOX = lambda v:f'Sample Mode ({v}s)'
363
+ REMOVED_FILES = lambda r, e:f'Audio Input Verification Report:\n\nRemoved Files:\n\n{r}\n\nError Details:\n\n{e}'
364
+ ADVANCED_SETTINGS = (ENSEMBLE_OPTION, MDX_OPTION, DEMUCS_OPTION, VR_OPTION, HELP_OPTION, ERROR_OPTION)
365
+
366
+ WAV = 'WAV'
367
+ FLAC = 'FLAC'
368
+ MP3 = 'MP3'
369
+
370
+ MP3_BIT_RATES = ('96k', '128k', '160k', '224k', '256k', '320k')
371
+ WAV_TYPE = ('PCM_U8', 'PCM_16', 'PCM_24', 'PCM_32', '32-bit Float', '64-bit Float')
372
+
373
+ SELECT_SAVED_SET = 'Choose Option'
374
+ SAVE_SETTINGS = 'Save Current Settings'
375
+ RESET_TO_DEFAULT = 'Reset to Default'
376
+ RESET_FULL_TO_DEFAULT = 'Reset to Default'
377
+ RESET_PM_TO_DEFAULT = 'Reset All Application Settings to Default'
378
+
379
+ SAVE_SET_OPTIONS = (SAVE_SETTINGS, RESET_TO_DEFAULT)
380
+
381
+ TIME_PITCH = ('1.0', '2.0', '3.0', '4.0')
382
+ TIME_TEXT = '_time_stretched'
383
+ PITCH_TEXT = '_pitch_shifted'
384
+
385
+ #RegEx Input Validation
386
+
387
+ REG_PITCH = r'^[-+]?(1[0]|[0-9]([.][0-9]*)?)$'
388
+ REG_TIME = r'^[+]?(1[0]|[0-9]([.][0-9]*)?)$'
389
+ REG_COMPENSATION = r'\b^(1[0]|[0-9]([.][0-9]*)?|Auto|None)$\b'
390
+ REG_THES_POSTPORCESS = r'\b^([0]([.][0-9]{0,6})?)$\b'
391
+ REG_CHUNKS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Auto|Full)$\b'
392
+ REG_CHUNKS_DEMUCS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Auto|Full)$\b'
393
+ REG_MARGIN = r'\b^[0-9]*$\b'
394
+ REG_SEGMENTS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Default)$\b'
395
+ REG_SAVE_INPUT = r'\b^([a-zA-Z0-9 -]{0,25})$\b'
396
+ REG_AGGRESSION = r'^[-+]?[0-9]\d*?$'
397
+ REG_WINDOW = r'\b^[0-9]{0,4}$\b'
398
+ REG_SHIFTS = r'\b^[0-9]*$\b'
399
+ REG_BATCHES = r'\b^([0-9]*?|Default)$\b'
400
+ REG_OVERLAP = r'\b^([0]([.][0-9]{0,6})?|None)$\b'
401
+
402
+ # Sub Menu
403
+
404
+ VR_ARCH_SETTING_LOAD = 'Load for VR Arch'
405
+ MDX_SETTING_LOAD = 'Load for MDX-Net'
406
+ DEMUCS_SETTING_LOAD = 'Load for Demucs'
407
+ ALL_ARCH_SETTING_LOAD = 'Load for Full Application'
408
+
409
+ # Mappers
410
+
411
+ DEFAULT_DATA = {
412
+
413
+ 'chosen_process_method': MDX_ARCH_TYPE,
414
+ 'vr_model': CHOOSE_MODEL,
415
+ 'aggression_setting': 10,
416
+ 'window_size': 512,
417
+ 'batch_size': 4,
418
+ 'crop_size': 256,
419
+ 'is_tta': False,
420
+ 'is_output_image': False,
421
+ 'is_post_process': False,
422
+ 'is_high_end_process': False,
423
+ 'post_process_threshold': 0.2,
424
+ 'vr_voc_inst_secondary_model': NO_MODEL,
425
+ 'vr_other_secondary_model': NO_MODEL,
426
+ 'vr_bass_secondary_model': NO_MODEL,
427
+ 'vr_drums_secondary_model': NO_MODEL,
428
+ 'vr_is_secondary_model_activate': False,
429
+ 'vr_voc_inst_secondary_model_scale': 0.9,
430
+ 'vr_other_secondary_model_scale': 0.7,
431
+ 'vr_bass_secondary_model_scale': 0.5,
432
+ 'vr_drums_secondary_model_scale': 0.5,
433
+ 'demucs_model': CHOOSE_MODEL,
434
+ 'demucs_stems': ALL_STEMS,
435
+ 'segment': DEMUCS_SEGMENTS[0],
436
+ 'overlap': DEMUCS_OVERLAP[0],
437
+ 'shifts': 2,
438
+ 'chunks_demucs': CHUNKS[0],
439
+ 'margin_demucs': 44100,
440
+ 'is_chunk_demucs': False,
441
+ 'is_chunk_mdxnet': False,
442
+ 'is_primary_stem_only_Demucs': False,
443
+ 'is_secondary_stem_only_Demucs': False,
444
+ 'is_split_mode': True,
445
+ 'is_demucs_combine_stems': True,
446
+ 'demucs_voc_inst_secondary_model': NO_MODEL,
447
+ 'demucs_other_secondary_model': NO_MODEL,
448
+ 'demucs_bass_secondary_model': NO_MODEL,
449
+ 'demucs_drums_secondary_model': NO_MODEL,
450
+ 'demucs_is_secondary_model_activate': False,
451
+ 'demucs_voc_inst_secondary_model_scale': 0.9,
452
+ 'demucs_other_secondary_model_scale': 0.7,
453
+ 'demucs_bass_secondary_model_scale': 0.5,
454
+ 'demucs_drums_secondary_model_scale': 0.5,
455
+ 'demucs_stems': ALL_STEMS,
456
+ 'demucs_pre_proc_model': NO_MODEL,
457
+ 'is_demucs_pre_proc_model_activate': False,
458
+ 'is_demucs_pre_proc_model_inst_mix': False,
459
+ 'mdx_net_model': CHOOSE_MODEL,
460
+ 'chunks': CHUNKS[0],
461
+ 'margin': 44100,
462
+ 'compensate': AUTO_SELECT,
463
+ 'is_denoise': False,
464
+ 'is_invert_spec': False,
465
+ 'is_mixer_mode': False,
466
+ 'mdx_batch_size': DEF_OPT,
467
+ 'mdx_voc_inst_secondary_model': NO_MODEL,
468
+ 'mdx_other_secondary_model': NO_MODEL,
469
+ 'mdx_bass_secondary_model': NO_MODEL,
470
+ 'mdx_drums_secondary_model': NO_MODEL,
471
+ 'mdx_is_secondary_model_activate': False,
472
+ 'mdx_voc_inst_secondary_model_scale': 0.9,
473
+ 'mdx_other_secondary_model_scale': 0.7,
474
+ 'mdx_bass_secondary_model_scale': 0.5,
475
+ 'mdx_drums_secondary_model_scale': 0.5,
476
+ 'is_save_all_outputs_ensemble': True,
477
+ 'is_append_ensemble_name': False,
478
+ 'chosen_audio_tool': AUDIO_TOOL_OPTIONS[0],
479
+ 'choose_algorithm': MANUAL_ENSEMBLE_OPTIONS[0],
480
+ 'time_stretch_rate': 2.0,
481
+ 'pitch_rate': 2.0,
482
+ 'is_gpu_conversion': False,
483
+ 'is_primary_stem_only': False,
484
+ 'is_secondary_stem_only': False,
485
+ 'is_testing_audio': False,
486
+ 'is_add_model_name': False,
487
+ 'is_accept_any_input': False,
488
+ 'is_task_complete': False,
489
+ 'is_normalization': False,
490
+ 'is_create_model_folder': False,
491
+ 'mp3_bit_set': '320k',
492
+ 'save_format': WAV,
493
+ 'wav_type_set': 'PCM_16',
494
+ 'user_code': '',
495
+ 'export_path': '',
496
+ 'input_paths': [],
497
+ 'lastDir': None,
498
+ 'export_path': '',
499
+ 'model_hash_table': None,
500
+ 'help_hints_var': False,
501
+ 'model_sample_mode': False,
502
+ 'model_sample_mode_duration': 30
503
+ }
504
+
505
+ SETTING_CHECK = ('vr_model',
506
+ 'aggression_setting',
507
+ 'window_size',
508
+ 'batch_size',
509
+ 'crop_size',
510
+ 'is_tta',
511
+ 'is_output_image',
512
+ 'is_post_process',
513
+ 'is_high_end_process',
514
+ 'post_process_threshold',
515
+ 'vr_voc_inst_secondary_model',
516
+ 'vr_other_secondary_model',
517
+ 'vr_bass_secondary_model',
518
+ 'vr_drums_secondary_model',
519
+ 'vr_is_secondary_model_activate',
520
+ 'vr_voc_inst_secondary_model_scale',
521
+ 'vr_other_secondary_model_scale',
522
+ 'vr_bass_secondary_model_scale',
523
+ 'vr_drums_secondary_model_scale',
524
+ 'demucs_model',
525
+ 'segment',
526
+ 'overlap',
527
+ 'shifts',
528
+ 'chunks_demucs',
529
+ 'margin_demucs',
530
+ 'is_chunk_demucs',
531
+ 'is_primary_stem_only_Demucs',
532
+ 'is_secondary_stem_only_Demucs',
533
+ 'is_split_mode',
534
+ 'is_demucs_combine_stems',
535
+ 'demucs_voc_inst_secondary_model',
536
+ 'demucs_other_secondary_model',
537
+ 'demucs_bass_secondary_model',
538
+ 'demucs_drums_secondary_model',
539
+ 'demucs_is_secondary_model_activate',
540
+ 'demucs_voc_inst_secondary_model_scale',
541
+ 'demucs_other_secondary_model_scale',
542
+ 'demucs_bass_secondary_model_scale',
543
+ 'demucs_drums_secondary_model_scale',
544
+ 'demucs_stems',
545
+ 'mdx_net_model',
546
+ 'chunks',
547
+ 'margin',
548
+ 'compensate',
549
+ 'is_denoise',
550
+ 'is_invert_spec',
551
+ 'mdx_batch_size',
552
+ 'mdx_voc_inst_secondary_model',
553
+ 'mdx_other_secondary_model',
554
+ 'mdx_bass_secondary_model',
555
+ 'mdx_drums_secondary_model',
556
+ 'mdx_is_secondary_model_activate',
557
+ 'mdx_voc_inst_secondary_model_scale',
558
+ 'mdx_other_secondary_model_scale',
559
+ 'mdx_bass_secondary_model_scale',
560
+ 'mdx_drums_secondary_model_scale',
561
+ 'is_save_all_outputs_ensemble',
562
+ 'is_append_ensemble_name',
563
+ 'chosen_audio_tool',
564
+ 'choose_algorithm',
565
+ 'time_stretch_rate',
566
+ 'pitch_rate',
567
+ 'is_primary_stem_only',
568
+ 'is_secondary_stem_only',
569
+ 'is_testing_audio',
570
+ 'is_add_model_name',
571
+ "is_accept_any_input",
572
+ 'is_task_complete',
573
+ 'is_create_model_folder',
574
+ 'mp3_bit_set',
575
+ 'save_format',
576
+ 'wav_type_set',
577
+ 'user_code',
578
+ 'is_gpu_conversion',
579
+ 'is_normalization',
580
+ 'help_hints_var',
581
+ 'model_sample_mode',
582
+ 'model_sample_mode_duration')
583
+
584
+ # Message Box Text
585
+
586
+ INVALID_INPUT = 'Invalid Input', 'The input is invalid.\n\nPlease verify the input still exists or is valid and try again.'
587
+ INVALID_EXPORT = 'Invalid Export Directory', 'You have selected an invalid export directory.\n\nPlease make sure the selected directory still exists.'
588
+ INVALID_ENSEMBLE = 'Not Enough Models', 'You must select 2 or more models to run ensemble.'
589
+ INVALID_MODEL = 'No Model Chosen', 'You must select an model to continue.'
590
+ MISSING_MODEL = 'Model Missing', 'The selected model is missing or not valid.'
591
+ ERROR_OCCURED = 'Error Occured', '\n\nWould you like to open the error log for more details?\n'
592
+
593
+ # GUI Text Constants
594
+
595
+ BACK_TO_MAIN_MENU = 'Back to Main Menu'
596
+
597
+ # Help Hint Text
598
+
599
+ INTERNAL_MODEL_ATT = 'Internal model attribute. \n\n ***Do not change this setting if you are unsure!***'
600
+ STOP_HELP = 'Halts any running processes. \n A pop-up window will ask the user to confirm the action.'
601
+ SETTINGS_HELP = 'Opens the main settings guide. This window includes the \"Download Center\"'
602
+ COMMAND_TEXT_HELP = 'Provides information on the progress of the current process.'
603
+ SAVE_CURRENT_SETTINGS_HELP = 'Allows the user to open any saved settings or save the current application settings.'
604
+ CHUNKS_HELP = ('For MDX-Net, all values use the same amount of resources. Using chunks is no longer recommended.\n\n' + \
605
+ '• This option is now only for output quality.\n' + \
606
+ '• Some tracks may fare better depending on the value.\n' + \
607
+ '• Some tracks may fare worse depending on the value.\n' + \
608
+ '• Larger chunk sizes use will take less time to process.\n' +\
609
+ '• Smaller chunk sizes use will take more time to process.\n')
610
+ CHUNKS_DEMUCS_HELP = ('This option allows the user to reduce (or increase) RAM or V-RAM usage.\n\n' + \
611
+ '• Smaller chunk sizes use less RAM or V-RAM but can also increase processing times.\n' + \
612
+ '• Larger chunk sizes use more RAM or V-RAM but can also reduce processing times.\n' + \
613
+ '• Selecting \"Auto\" calculates an appropriate chuck size based on how much RAM or V-RAM your system has.\n' + \
614
+ '• Selecting \"Full\" will process the track as one whole chunk. (not recommended)\n' + \
615
+ '• The default selection is \"Auto\".')
616
+ MARGIN_HELP = 'Selects the frequency margins to slice the chunks from.\n\n• The recommended margin size is 44100.\n• Other values can give unpredictable results.'
617
+ AGGRESSION_SETTING_HELP = ('This option allows you to set how strong the primary stem extraction will be.\n\n' + \
618
+ '• The range is 0-100.\n' + \
619
+ '• Higher values perform deeper extractions.\n' + \
620
+ '• The default is 10 for instrumental & vocal models.\n' + \
621
+ '• Values over 10 can result in muddy-sounding instrumentals for the non-vocal models')
622
+ WINDOW_SIZE_HELP = ('The smaller your window size, the better your conversions will be. \nHowever, a smaller window means longer conversion times and heavier resource usage.\n\n' + \
623
+ 'Breakdown of the selectable window size values:\n' + \
624
+ '• 1024 - Low conversion quality, shortest conversion time, low resource usage.\n' + \
625
+ '• 512 - Average conversion quality, average conversion time, normal resource usage.\n' + \
626
+ '• 320 - Better conversion quality.')
627
+ DEMUCS_STEMS_HELP = ('Here, you can choose which stem to extract using the selected model.\n\n' +\
628
+ 'Stem Selections:\n\n' +\
629
+ '• All Stems - Saves all of the stems the model is able to extract.\n' +\
630
+ '• Vocals - Pulls vocal stem only.\n' +\
631
+ '• Other - Pulls other stem only.\n' +\
632
+ '• Bass - Pulls bass stem only.\n' +\
633
+ '• Drums - Pulls drum stem only.\n')
634
+ SEGMENT_HELP = ('This option allows the user to reduce (or increase) RAM or V-RAM usage.\n\n' + \
635
+ '• Smaller segment sizes use less RAM or V-RAM but can also increase processing times.\n' + \
636
+ '• Larger segment sizes use more RAM or V-RAM but can also reduce processing times.\n' + \
637
+ '• Selecting \"Default\" uses the recommended segment size.\n' + \
638
+ '• It is recommended that you not use segments with \"Chunking\".')
639
+ ENSEMBLE_MAIN_STEM_HELP = 'Allows the user to select the type of stems they wish to ensemble.\n\nOptions:\n\n' +\
640
+ f'• {VOCAL_PAIR} - The primary stem will be the vocals and the secondary stem will be the the instrumental\n' +\
641
+ f'• {OTHER_PAIR} - The primary stem will be other and the secondary stem will be no other (the mixture without the \'other\' stem)\n' +\
642
+ f'• {BASS_PAIR} - The primary stem will be bass and the secondary stem will be no bass (the mixture without the \'bass\' stem)\n' +\
643
+ f'• {DRUM_PAIR} - The primary stem will be drums and the secondary stem will be no drums (the mixture without the \'drums\' stem)\n' +\
644
+ f'• {FOUR_STEM_ENSEMBLE} - This option will gather all the 4 stem Demucs models and ensemble all of the outputs.\n'
645
+ ENSEMBLE_TYPE_HELP = 'Allows the user to select the ensemble algorithm to be used to generate the final output.\n\nExample & Other Note:\n\n' +\
646
+ f'• {MAX_MIN} - If this option is chosen, the primary stem outputs will be processed through \nthe \'Max Spec\' algorithm, and the secondary stem will be processed through the \'Min Spec\' algorithm.\n' +\
647
+ f'• Only a single algorithm will be shown when the \'4 Stem Ensemble\' option is chosen.\n\nAlgorithm Details:\n\n' +\
648
+ f'• {MAX_SPEC} - This algorithm combines the final results and generates the highest possible output from them.\nFor example, if this algorithm were processing vocal stems, you would get the fullest possible \n' +\
649
+ 'result making the ensembled vocal stem sound cleaner. However, it might result in more unwanted artifacts.\n' +\
650
+ f'• {MIN_SPEC} - This algorithm combines the results and generates the lowest possible output from them.\nFor example, if this algorithm were processing instrumental stems, you would get the cleanest possible result \n' +\
651
+ 'result, eliminating more unwanted artifacts. However, the result might also sound \'muddy\' and lack a fuller sound.\n' +\
652
+ f'• {AUDIO_AVERAGE} - This algorithm simply combines the results and averages all of them together. \n'
653
+ ENSEMBLE_LISTBOX_HELP = 'List of the all the models available for the main stem pair selected.'
654
+ IS_GPU_CONVERSION_HELP = ('When checked, the application will attempt to use your GPU (if you have one).\n' +\
655
+ 'If you do not have a GPU but have this checked, the application will default to your CPU.\n\n' +\
656
+ 'Note: CPU conversions are much slower than those processed through the GPU.')
657
+ SAVE_STEM_ONLY_HELP = 'Allows the user to save only the selected stem.'
658
+ IS_NORMALIZATION_HELP = 'Normalizes output to prevent clipping.'
659
+ CROP_SIZE_HELP = '**Only compatible with select models only!**\n\n Setting should match training crop-size value. Leave as is if unsure.'
660
+ IS_TTA_HELP = ('This option performs Test-Time-Augmentation to improve the separation quality.\n\n' +\
661
+ 'Note: Having this selected will increase the time it takes to complete a conversion')
662
+ IS_POST_PROCESS_HELP = ('This option can potentially identify leftover instrumental artifacts within the vocal outputs. \nThis option may improve the separation of some songs.\n\n' +\
663
+ 'Note: Selecting this option can adversely affect the conversion process, depending on the track. Because of this, it is only recommended as a last resort.')
664
+ IS_HIGH_END_PROCESS_HELP = 'The application will mirror the missing frequency range of the output.'
665
+ SHIFTS_HELP = ('Performs multiple predictions with random shifts of the input and averages them.\n\n' +\
666
+ '• The higher number of shifts, the longer the prediction will take. \n- Not recommended unless you have a GPU.')
667
+ OVERLAP_HELP = 'This option controls the amount of overlap between prediction windows (for Demucs one window is 10 seconds)'
668
+ IS_CHUNK_DEMUCS_HELP = '• Enables \"Chunks\".\n• We recommend you not enable this option with \"Split Mode\" enabled or with the Demucs v4 Models.'
669
+ IS_CHUNK_MDX_NET_HELP = '• Enables \"Chunks\".\n• Using this option for MDX-Net no longer effects RAM usage.\n• Having this enabled will effect output quality, for better or worse depending on the set value.'
670
+ IS_SPLIT_MODE_HELP = ('• Enables \"Segments\". \n• We recommend you not enable this option with \"Enable Chunks\".\n' +\
671
+ '• Deselecting this option is only recommended for those with powerful PCs or if using \"Chunk\" mode instead.')
672
+ IS_DEMUCS_COMBINE_STEMS_HELP = 'The application will create the secondary stem by combining the remaining stems \ninstead of inverting the primary stem with the mixture.'
673
+ COMPENSATE_HELP = 'Compensates the audio of the primary stems to allow for a better secondary stem.'
674
+ IS_DENOISE_HELP = '• This option removes a majority of the noise generated by the MDX-Net models.\n• The conversion will take nearly twice as long with this enabled.'
675
+ CLEAR_CACHE_HELP = 'Clears any user selected model settings for previously unrecognized models.'
676
+ IS_SAVE_ALL_OUTPUTS_ENSEMBLE_HELP = 'Enabling this option will keep all indivudual outputs generated by an ensemble.'
677
+ IS_APPEND_ENSEMBLE_NAME_HELP = 'The application will append the ensemble name to the final output \nwhen this option is enabled.'
678
+ DONATE_HELP = 'Takes the user to an external web-site to donate to this project!'
679
+ IS_INVERT_SPEC_HELP = '• This option may produce a better secondary stem.\n• Inverts primary stem with mixture using spectragrams instead of wavforms.\n• This inversion method is slightly slower.'
680
+ IS_MIXER_MODE_HELP = '• This option may improve separations for outputs from 4-stem models.\n• Might produce more noise.\n• This option might slow down separation time.'
681
+ IS_TESTING_AUDIO_HELP = 'Appends a unique 10 digit number to output files so the user \ncan compare results with different settings.'
682
+ IS_MODEL_TESTING_AUDIO_HELP = 'Appends the model name to output files so the user \ncan compare results with different settings.'
683
+ IS_ACCEPT_ANY_INPUT_HELP = 'The application will accept any input when enabled, even if it does not have an audio format extension.\n\nThis is for experimental purposes, and having it enabled is not recommended.'
684
+ IS_TASK_COMPLETE_HELP = 'When enabled, chimes will be heard when a process completes or fails.'
685
+ IS_CREATE_MODEL_FOLDER_HELP = 'Two new directories will be generated for the outputs in \nthe export directory after each conversion.\n\n' +\
686
+ '• First directory - Named after the model.\n' +\
687
+ '• Second directory - Named after the track.\n\n' +\
688
+ '• Example: \n\n' +\
689
+ '─ Export Directory\n' +\
690
+ ' └── First Directory\n' +\
691
+ ' └── Second Directory\n' +\
692
+ ' └── Output File(s)'
693
+ DELETE_YOUR_SETTINGS_HELP = 'This menu contains your saved settings. You will be asked to \nconfirm if you wish to delete the selected setting.'
694
+ SET_STEM_NAME_HELP = 'Choose the primary stem for the selected model.'
695
+ MDX_DIM_T_SET_HELP = INTERNAL_MODEL_ATT
696
+ MDX_DIM_F_SET_HELP = INTERNAL_MODEL_ATT
697
+ MDX_N_FFT_SCALE_SET_HELP = 'Set the N_FFT size the model was trained with.'
698
+ POPUP_COMPENSATE_HELP = f'Choose the appropriate voluem compensattion for the selected model\n\nReminder: {COMPENSATE_HELP}'
699
+ VR_MODEL_PARAM_HELP = 'Choose the parameters needed to run the selected model.'
700
+ CHOSEN_ENSEMBLE_HELP = 'Select saved enselble or save current ensemble.\n\nDefault Selections:\n\n• Save the current ensemble.\n• Clears all current model selections.'
701
+ CHOSEN_PROCESS_METHOD_HELP = 'Here, you choose between different Al networks and algorithms to process your track.\n\n' +\
702
+ 'There are five options:\n\n' +\
703
+ '• VR Architecture - These models use magnitude spectrograms for Source Separation.\n' +\
704
+ '• MDX-Net - These models use Hybrid Spectrogram/Waveform for Source Separation.\n' +\
705
+ '• Demucs v3 - These models use Hybrid Spectrogram/Waveform for Source Separation.\n' +\
706
+ '• Ensemble Mode - Here, you can get the best results from multiple models and networks.\n' +\
707
+ '• Audio Tools - These are additional tools for added convenience.'
708
+ INPUT_FOLDER_ENTRY_HELP = 'Select Input:\n\nHere is where you select the audio files(s) you wish to process.'
709
+ INPUT_FOLDER_ENTRY_HELP_2 = 'Input Option Menu:\n\nClick here to access the input option menu.'
710
+ OUTPUT_FOLDER_ENTRY_HELP = 'Select Output:\n\nHere is where you select the directory where your processed files are to be saved.'
711
+ INPUT_FOLDER_BUTTON_HELP = 'Open Input Folder Button: \n\nOpens the directory containing the selected input audio file(s).'
712
+ OUTPUT_FOLDER_BUTTON_HELP = 'Open Output Folder Button: \n\nOpens the selected output folder.'
713
+ CHOOSE_MODEL_HELP = 'Each process method comes with its own set of options and models.\n\nHere is where you choose the model associated with the selected process method.'
714
+ FORMAT_SETTING_HELP = 'Save outputs as '
715
+ SECONDARY_MODEL_ACTIVATE_HELP = 'When enabled, the application will run an additional inference with the selected model(s) above.'
716
+ SECONDARY_MODEL_HELP = 'Choose the secondary model associated with this stem you wish to run with the current process method.'
717
+ SECONDARY_MODEL_SCALE_HELP = 'The scale determines how the final audio outputs will be averaged between the primary and secondary models.\n\nFor example:\n\n' +\
718
+ '• 10% - 10 percent of the main model result will be factored into the final result.\n' +\
719
+ '• 50% - The results from the main and secondary models will be averaged evenly.\n' +\
720
+ '• 90% - 90 percent of the main model result will be factored into the final result.'
721
+ PRE_PROC_MODEL_ACTIVATE_HELP = 'The application will run an inference with the selected model above, pulling only the instrumental stem when enabled. \nFrom there, all of the non-vocal stems will be pulled from the generated instrumental.\n\nNotes:\n\n' +\
722
+ '• This option can significantly reduce vocal bleed within the non-vocal stems.\n' +\
723
+ '• It is only available in Demucs.\n' +\
724
+ '• It is only compatible with non-vocal and non-instrumental stem outputs.\n' +\
725
+ '• This will increase thetotal processing time.\n' +\
726
+ '• Only VR and MDX-Net Vocal or Instrumental models are selectable above.'
727
+
728
+ AUDIO_TOOLS_HELP = 'Here, you choose between different audio tools to process your track.\n\n' +\
729
+ '• Manual Ensemble - You must have 2 or more files selected as your inputs. Allows the user to run their tracks through \nthe same algorithms used in Ensemble Mode.\n' +\
730
+ '• Align Inputs - You must have exactly 2 files selected as your inputs. The second input will be aligned with the first input.\n' +\
731
+ '• Time Stretch - The user can speed up or slow down the selected inputs.\n' +\
732
+ '• Change Pitch - The user can change the pitch for the selected inputs.\n'
733
+ PRE_PROC_MODEL_INST_MIX_HELP = 'When enabled, the application will generate a third output without the selected stem and vocals.'
734
+ MODEL_SAMPLE_MODE_HELP = 'Allows the user to process only part of a track to sample settings or a model without \nrunning a full conversion.\n\nNotes:\n\n' +\
735
+ '• The number in the parentheses is the current number of seconds the generated sample will be.\n' +\
736
+ '• You can choose the number of seconds to extract from the track in the \"Additional Settings\" menu.'
737
+
738
+ POST_PROCESS_THREASHOLD_HELP = 'Allows the user to control the intensity of the Post_process option.\n\nNotes:\n\n' +\
739
+ '• Higher values potentially remove more artifacts. However, bleed might increase.\n' +\
740
+ '• Lower values limit artifact removal.'
741
+
742
+ BATCH_SIZE_HELP = 'Specify the number of batches to be processed at a time.\n\nNotes:\n\n' +\
743
+ '• Higher values mean more RAM usage but slightly faster processing times.\n' +\
744
+ '• Lower values mean less RAM usage but slightly longer processing times.\n' +\
745
+ '• Batch size value has no effect on output quality.'
746
+
747
+ # Warning Messages
748
+
749
+ STORAGE_ERROR = 'Insufficient Storage', 'There is not enough storage on main drive to continue. Your main drive must have at least 3 GB\'s of storage in order for this application function properly. \n\nPlease ensure your main drive has at least 3 GB\'s of storage and try again.\n\n'
750
+ STORAGE_WARNING = 'Available Storage Low', 'Your main drive is running low on storage. Your main drive must have at least 3 GB\'s of storage in order for this application function properly.\n\n'
751
+ CONFIRM_WARNING = '\nAre you sure you wish to continue?'
752
+ PROCESS_FAILED = 'Process failed, please see error log\n'
753
+ EXIT_PROCESS_ERROR = 'Active Process', 'Please stop the active process or wait for it to complete before you exit.'
754
+ EXIT_HALTED_PROCESS_ERROR = 'Halting Process', 'Please wait for the application to finish halting the process before exiting.'
755
+ EXIT_DOWNLOAD_ERROR = 'Active Download', 'Please stop the download or wait for it to complete before you exit.'
756
+ SET_TO_DEFAULT_PROCESS_ERROR = 'Active Process', 'You cannot reset all of the application settings during an active process.'
757
+ SET_TO_ANY_PROCESS_ERROR = 'Active Process', 'You cannot reset the application settings during an active process.'
758
+ RESET_ALL_TO_DEFAULT_WARNING = 'Reset Settings Confirmation', 'All application settings will be set to factory default.\n\nAre you sure you wish to continue?'
759
+ AUDIO_VERIFICATION_CHECK = lambda i, e:f'++++++++++++++++++++++++++++++++++++++++++++++++++++\n\nBroken File Removed: \n\n{i}\n\nError Details:\n\n{e}\n++++++++++++++++++++++++++++++++++++++++++++++++++++'
760
+ INVALID_ONNX_MODEL_ERROR = 'Invalid Model', 'The file selected is not a valid MDX-Net model. Please see the error log for more information.'
761
+
762
+
763
+ # Separation Text
764
+
765
+ LOADING_MODEL = 'Loading model...'
766
+ INFERENCE_STEP_1 = 'Running inference...'
767
+ INFERENCE_STEP_1_SEC = 'Running inference (secondary model)...'
768
+ INFERENCE_STEP_1_4_STEM = lambda stem:f'Running inference (secondary model for {stem})...'
769
+ INFERENCE_STEP_1_PRE = 'Running inference (pre-process model)...'
770
+ INFERENCE_STEP_2_PRE = lambda pm, m:f'Loading pre-process model ({pm}: {m})...'
771
+ INFERENCE_STEP_2_SEC = lambda pm, m:f'Loading secondary model ({pm}: {m})...'
772
+ INFERENCE_STEP_2_SEC_CACHED_MODOEL = lambda pm, m:f'Secondary model ({pm}: {m}) cache loaded.\n'
773
+ INFERENCE_STEP_2_PRE_CACHED_MODOEL = lambda pm, m:f'Pre-process model ({pm}: {m}) cache loaded.\n'
774
+ INFERENCE_STEP_2_SEC_CACHED = 'Loading cached secondary model source(s)... Done!\n'
775
+ INFERENCE_STEP_2_PRIMARY_CACHED = 'Model cache loaded.\n'
776
+ INFERENCE_STEP_2 = 'Inference complete.'
777
+ SAVING_STEM = 'Saving ', ' stem...'
778
+ SAVING_ALL_STEMS = 'Saving all stems...'
779
+ ENSEMBLING_OUTPUTS = 'Ensembling outputs...'
780
+ DONE = ' Done!\n'
781
+ ENSEMBLES_SAVED = 'Ensembled outputs saved!\n\n'
782
+ NEW_LINES = "\n\n"
783
+ NEW_LINE = "\n"
784
+ NO_LINE = ''
785
+
786
+ # Widget Placements
787
+
788
+ MAIN_ROW_Y = -15, -17
789
+ MAIN_ROW_X = -4, 21
790
+ MAIN_ROW_WIDTH = -53
791
+ MAIN_ROW_2_Y = -15, -17
792
+ MAIN_ROW_2_X = -28, 1
793
+ CHECK_BOX_Y = 0
794
+ CHECK_BOX_X = 20
795
+ CHECK_BOX_WIDTH = -50
796
+ CHECK_BOX_HEIGHT = 2
797
+ LEFT_ROW_WIDTH = -10
798
+ LABEL_HEIGHT = -5
799
+ OPTION_HEIGHT = 7
800
+ LOW_MENU_Y = 18, 16
801
+ FFMPEG_EXT = (".aac", ".aiff", ".alac" ,".flac", ".FLAC", ".mov", ".mp4", ".MP4",
802
+ ".m4a", ".M4A", ".mp2", ".mp3", "MP3", ".mpc", ".mpc8",
803
+ ".mpeg", ".ogg", ".OGG", ".tta", ".wav", ".wave", ".WAV", ".WAVE", ".wma", ".webm", ".eac3", ".mkv")
804
+
805
+ FFMPEG_MORE_EXT = (".aa", ".aac", ".ac3", ".aiff", ".alac", ".avi", ".f4v",".flac", ".flic", ".flv",
806
+ ".m4v",".mlv", ".mov", ".mp4", ".m4a", ".mp2", ".mp3", ".mp4", ".mpc", ".mpc8",
807
+ ".mpeg", ".ogg", ".tta", ".tty", ".vcd", ".wav", ".wma")
808
+ ANY_EXT = ""
809
+
810
+ # Secondary Menu Constants
811
+
812
+ VOCAL_PAIR_PLACEMENT = 1, 2, 3, 4
813
+ OTHER_PAIR_PLACEMENT = 5, 6, 7, 8
814
+ BASS_PAIR_PLACEMENT = 9, 10, 11, 12
815
+ DRUMS_PAIR_PLACEMENT = 13, 14, 15, 16
816
+
817
+ # Drag n Drop String Checks
818
+
819
+ DOUBLE_BRACKET = "} {"
820
+ RIGHT_BRACKET = "}"
821
+ LEFT_BRACKET = "{"
822
+
823
+ # Manual Downloads
824
+
825
+ VR_PLACEMENT_TEXT = 'Place models in \"models/VR_Models\" directory.'
826
+ MDX_PLACEMENT_TEXT = 'Place models in \"models/MDX_Net_Models\" directory.'
827
+ DEMUCS_PLACEMENT_TEXT = 'Place models in \"models/Demucs_Models\" directory.'
828
+ DEMUCS_V3_V4_PLACEMENT_TEXT = 'Place items in \"models/Demucs_Models/v3_v4_repo\" directory.'
829
+
830
+ FULL_DOWNLOAD_LIST_VR = {
831
+ "VR Arch Single Model v5: 1_HP-UVR": "1_HP-UVR.pth",
832
+ "VR Arch Single Model v5: 2_HP-UVR": "2_HP-UVR.pth",
833
+ "VR Arch Single Model v5: 3_HP-Vocal-UVR": "3_HP-Vocal-UVR.pth",
834
+ "VR Arch Single Model v5: 4_HP-Vocal-UVR": "4_HP-Vocal-UVR.pth",
835
+ "VR Arch Single Model v5: 5_HP-Karaoke-UVR": "5_HP-Karaoke-UVR.pth",
836
+ "VR Arch Single Model v5: 6_HP-Karaoke-UVR": "6_HP-Karaoke-UVR.pth",
837
+ "VR Arch Single Model v5: 7_HP2-UVR": "7_HP2-UVR.pth",
838
+ "VR Arch Single Model v5: 8_HP2-UVR": "8_HP2-UVR.pth",
839
+ "VR Arch Single Model v5: 9_HP2-UVR": "9_HP2-UVR.pth",
840
+ "VR Arch Single Model v5: 10_SP-UVR-2B-32000-1": "10_SP-UVR-2B-32000-1.pth",
841
+ "VR Arch Single Model v5: 11_SP-UVR-2B-32000-2": "11_SP-UVR-2B-32000-2.pth",
842
+ "VR Arch Single Model v5: 12_SP-UVR-3B-44100": "12_SP-UVR-3B-44100.pth",
843
+ "VR Arch Single Model v5: 13_SP-UVR-4B-44100-1": "13_SP-UVR-4B-44100-1.pth",
844
+ "VR Arch Single Model v5: 14_SP-UVR-4B-44100-2": "14_SP-UVR-4B-44100-2.pth",
845
+ "VR Arch Single Model v5: 15_SP-UVR-MID-44100-1": "15_SP-UVR-MID-44100-1.pth",
846
+ "VR Arch Single Model v5: 16_SP-UVR-MID-44100-2": "16_SP-UVR-MID-44100-2.pth",
847
+ "VR Arch Single Model v4: MGM_HIGHEND_v4": "MGM_HIGHEND_v4.pth",
848
+ "VR Arch Single Model v4: MGM_LOWEND_A_v4": "MGM_LOWEND_A_v4.pth",
849
+ "VR Arch Single Model v4: MGM_LOWEND_B_v4": "MGM_LOWEND_B_v4.pth",
850
+ "VR Arch Single Model v4: MGM_MAIN_v4": "MGM_MAIN_v4.pth"
851
+ }
852
+
853
+ FULL_DOWNLOAD_LIST_MDX = {
854
+ "MDX-Net Model: UVR-MDX-NET Main": "UVR_MDXNET_Main.onnx",
855
+ "MDX-Net Model: UVR-MDX-NET Inst Main": "UVR-MDX-NET-Inst_Main.onnx",
856
+ "MDX-Net Model: UVR-MDX-NET 1": "UVR_MDXNET_1_9703.onnx",
857
+ "MDX-Net Model: UVR-MDX-NET 2": "UVR_MDXNET_2_9682.onnx",
858
+ "MDX-Net Model: UVR-MDX-NET 3": "UVR_MDXNET_3_9662.onnx",
859
+ "MDX-Net Model: UVR-MDX-NET Inst 1": "UVR-MDX-NET-Inst_1.onnx",
860
+ "MDX-Net Model: UVR-MDX-NET Inst 2": "UVR-MDX-NET-Inst_2.onnx",
861
+ "MDX-Net Model: UVR-MDX-NET Inst 3": "UVR-MDX-NET-Inst_3.onnx",
862
+ "MDX-Net Model: UVR-MDX-NET Karaoke": "UVR_MDXNET_KARA.onnx",
863
+ "MDX-Net Model: UVR_MDXNET_9482": "UVR_MDXNET_9482.onnx",
864
+ "MDX-Net Model: Kim_Vocal_1": "Kim_Vocal_1.onnx",
865
+ "MDX-Net Model: kuielab_a_vocals": "kuielab_a_vocals.onnx",
866
+ "MDX-Net Model: kuielab_a_other": "kuielab_a_other.onnx",
867
+ "MDX-Net Model: kuielab_a_bass": "kuielab_a_bass.onnx",
868
+ "MDX-Net Model: kuielab_a_drums": "kuielab_a_drums.onnx",
869
+ "MDX-Net Model: kuielab_b_vocals": "kuielab_b_vocals.onnx",
870
+ "MDX-Net Model: kuielab_b_other": "kuielab_b_other.onnx",
871
+ "MDX-Net Model: kuielab_b_bass": "kuielab_b_bass.onnx",
872
+ "MDX-Net Model: kuielab_b_drums": "kuielab_b_drums.onnx"}
873
+
874
+ FULL_DOWNLOAD_LIST_DEMUCS = {
875
+
876
+ "Demucs v4: htdemucs_ft":{
877
+ "f7e0c4bc-ba3fe64a.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th",
878
+ "d12395a8-e57c48e6.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th",
879
+ "92cfc3b6-ef3bcb9c.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th",
880
+ "04573f0d-f3cf25b2.th":"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th",
881
+ "htdemucs_ft.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_ft.yaml"
882
+ },
883
+
884
+ "Demucs v4: htdemucs":{
885
+ "955717e8-8726e21a.th": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/955717e8-8726e21a.th",
886
+ "htdemucs.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs.yaml"
887
+ },
888
+
889
+ "Demucs v4: hdemucs_mmi":{
890
+ "75fc33f5-1941ce65.th": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/75fc33f5-1941ce65.th",
891
+ "hdemucs_mmi.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/hdemucs_mmi.yaml"
892
+ },
893
+ "Demucs v4: htdemucs_6s":{
894
+ "5c90dfd2-34c22ccb.th": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/5c90dfd2-34c22ccb.th",
895
+ "htdemucs_6s.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_6s.yaml"
896
+ },
897
+ "Demucs v3: mdx":{
898
+ "0d19c1c6-0f06f20e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/0d19c1c6-0f06f20e.th",
899
+ "7ecf8ec1-70f50cc9.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/7ecf8ec1-70f50cc9.th",
900
+ "c511e2ab-fe698775.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/c511e2ab-fe698775.th",
901
+ "7d865c68-3d5dd56b.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/7d865c68-3d5dd56b.th",
902
+ "mdx.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/mdx.yaml"
903
+ },
904
+
905
+ "Demucs v3: mdx_q":{
906
+ "6b9c2ca1-3fd82607.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/6b9c2ca1-3fd82607.th",
907
+ "b72baf4e-8778635e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/b72baf4e-8778635e.th",
908
+ "42e558d4-196e0e1b.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/42e558d4-196e0e1b.th",
909
+ "305bc58f-18378783.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/305bc58f-18378783.th",
910
+ "mdx_q.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/mdx_q.yaml"
911
+ },
912
+
913
+ "Demucs v3: mdx_extra":{
914
+ "e51eebcc-c1b80bdd.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/e51eebcc-c1b80bdd.th",
915
+ "a1d90b5c-ae9d2452.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/a1d90b5c-ae9d2452.th",
916
+ "5d2d6c55-db83574e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/5d2d6c55-db83574e.th",
917
+ "cfa93e08-61801ae1.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/cfa93e08-61801ae1.th",
918
+ "mdx_extra.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/mdx_extra.yaml"
919
+ },
920
+
921
+ "Demucs v3: mdx_extra_q": {
922
+ "83fc094f-4a16d450.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/83fc094f-4a16d450.th",
923
+ "464b36d7-e5a9386e.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/464b36d7-e5a9386e.th",
924
+ "14fc6a69-a89dd0ee.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/14fc6a69-a89dd0ee.th",
925
+ "7fd6ef75-a905dd85.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/7fd6ef75-a905dd85.th",
926
+ "mdx_extra_q.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/mdx_extra_q.yaml"
927
+ },
928
+
929
+ "Demucs v3: UVR Model":{
930
+ "ebf34a2db.th": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/ebf34a2db.th",
931
+ "UVR_Demucs_Model_1.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR_Demucs_Model_1.yaml"
932
+ },
933
+
934
+ "Demucs v3: repro_mdx_a":{
935
+ "9a6b4851-03af0aa6.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/9a6b4851-03af0aa6.th",
936
+ "1ef250f1-592467ce.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/1ef250f1-592467ce.th",
937
+ "fa0cb7f9-100d8bf4.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/fa0cb7f9-100d8bf4.th",
938
+ "902315c2-b39ce9c9.th": "https://dl.fbaipublicfiles.com/demucs/mdx_final/902315c2-b39ce9c9.th",
939
+ "repro_mdx_a.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/repro_mdx_a.yaml"
940
+ },
941
+
942
+ "Demucs v3: repro_mdx_a_time_only":{
943
+ "9a6b4851-03af0aa6.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/9a6b4851-03af0aa6.th",
944
+ "1ef250f1-592467ce.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/1ef250f1-592467ce.th",
945
+ "repro_mdx_a_time_only.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/repro_mdx_a_time_only.yaml"
946
+ },
947
+
948
+ "Demucs v3: repro_mdx_a_hybrid_only":{
949
+ "fa0cb7f9-100d8bf4.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/fa0cb7f9-100d8bf4.th",
950
+ "902315c2-b39ce9c9.th":"https://dl.fbaipublicfiles.com/demucs/mdx_final/902315c2-b39ce9c9.th",
951
+ "repro_mdx_a_hybrid_only.yaml": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/repro_mdx_a_hybrid_only.yaml"
952
+ },
953
+
954
+ "Demucs v2: demucs": {
955
+ "demucs-e07c671f.th": "https://dl.fbaipublicfiles.com/demucs/v3.0/demucs-e07c671f.th"
956
+ },
957
+
958
+ "Demucs v2: demucs_extra": {
959
+ "demucs_extra-3646af93.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/demucs_extra-3646af93.th"
960
+ },
961
+
962
+ "Demucs v2: demucs48_hq": {
963
+ "demucs48_hq-28a1282c.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/demucs48_hq-28a1282c.th"
964
+ },
965
+
966
+ "Demucs v2: tasnet": {
967
+ "tasnet-beb46fac.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/tasnet-beb46fac.th"
968
+ },
969
+
970
+ "Demucs v2: tasnet_extra": {
971
+ "tasnet_extra-df3777b2.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/tasnet_extra-df3777b2.th"
972
+ },
973
+
974
+ "Demucs v2: demucs_unittest": {
975
+ "demucs_unittest-09ebc15f.th":"https://dl.fbaipublicfiles.com/demucs/v3.0/demucs_unittest-09ebc15f.th"
976
+ },
977
+
978
+ "Demucs v1: demucs": {
979
+ "demucs.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/demucs.th"
980
+ },
981
+
982
+ "Demucs v1: demucs_extra": {
983
+ "demucs_extra.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/demucs_extra.th"
984
+ },
985
+
986
+ "Demucs v1: light": {
987
+ "light.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/light.th"
988
+ },
989
+
990
+ "Demucs v1: light_extra": {
991
+ "light_extra.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/light_extra.th"
992
+ },
993
+
994
+ "Demucs v1: tasnet": {
995
+ "tasnet.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/tasnet.th"
996
+ },
997
+
998
+ "Demucs v1: tasnet_extra": {
999
+ "tasnet_extra.th":"https://dl.fbaipublicfiles.com/demucs/v2.0/tasnet_extra.th"
1000
+ }
1001
+ }
1002
+
1003
+ # Main Menu Labels
1004
+
1005
+ CHOOSE_PROC_METHOD_MAIN_LABEL = 'CHOOSE PROCESS METHOD (เลือก "MDX-Net")'
1006
+ SELECT_SAVED_SETTINGS_MAIN_LABEL = 'SELECT SAVED SETTINGS'
1007
+ CHOOSE_MDX_MODEL_MAIN_LABEL = 'CHOOSE MDX-NET MODEL (เลือก "UVR_MDXNET_Main" - แนะนำ)'
1008
+ BATCHES_MDX_MAIN_LABEL = 'BATCH SIZE'
1009
+ VOL_COMP_MDX_MAIN_LABEL = 'VOLUME COMPENSATION'
1010
+ SELECT_VR_MODEL_MAIN_LABEL = 'CHOOSE VR MODEL'
1011
+ AGGRESSION_SETTING_MAIN_LABEL = 'AGGRESSION SETTING'
1012
+ WINDOW_SIZE_MAIN_LABEL = 'WINDOW SIZE'
1013
+ CHOOSE_DEMUCS_MODEL_MAIN_LABEL = 'CHOOSE DEMUCS MODEL'
1014
+ CHOOSE_DEMUCS_STEMS_MAIN_LABEL = 'CHOOSE STEM(S)'
1015
+ CHOOSE_SEGMENT_MAIN_LABEL = 'SEGMENT'
1016
+ ENSEMBLE_OPTIONS_MAIN_LABEL = 'ENSEMBLE OPTIONS'
1017
+ CHOOSE_MAIN_PAIR_MAIN_LABEL = 'MAIN STEM PAIR'
1018
+ CHOOSE_ENSEMBLE_ALGORITHM_MAIN_LABEL = 'ENSEMBLE ALGORITHM'
1019
+ AVAILABLE_MODELS_MAIN_LABEL = 'AVAILABLE MODELS'
1020
+ CHOOSE_AUDIO_TOOLS_MAIN_LABEL = 'CHOOSE AUDIO TOOL'
1021
+ CHOOSE_MANUAL_ALGORITHM_MAIN_LABEL = 'CHOOSE ALGORITHM'
1022
+ CHOOSE_RATE_MAIN_LABEL = 'RATE'
1023
+ CHOOSE_SEMITONES_MAIN_LABEL = 'SEMITONES'
1024
+ GPU_CONVERSION_MAIN_LABEL = 'GPU Conversion'
1025
+
1026
+ if OPERATING_SYSTEM=="Darwin":
1027
+ LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running macOS Catalina and above.\n' +\
1028
+ '• Application functionality for systems running macOS Mojave or lower is not guaranteed.\n' +\
1029
+ '• Application functionality for older or budget Mac systems is not guaranteed.\n\n'
1030
+ FONT_SIZE_F1 = 13
1031
+ FONT_SIZE_F2 = 11
1032
+ FONT_SIZE_F3 = 12
1033
+ FONT_SIZE_0 = 9
1034
+ FONT_SIZE_1 = 11
1035
+ FONT_SIZE_2 = 12
1036
+ FONT_SIZE_3 = 13
1037
+ FONT_SIZE_4 = 14
1038
+ FONT_SIZE_5 = 15
1039
+ FONT_SIZE_6 = 17
1040
+ HELP_HINT_CHECKBOX_WIDTH = 13
1041
+ MDX_CHECKBOXS_WIDTH = 14
1042
+ VR_CHECKBOXS_WIDTH = 14
1043
+ ENSEMBLE_CHECKBOXS_WIDTH = 18
1044
+ DEMUCS_CHECKBOXS_WIDTH = 14
1045
+ DEMUCS_PRE_CHECKBOXS_WIDTH = 20
1046
+ GEN_SETTINGS_WIDTH = 17
1047
+ MENU_COMBOBOX_WIDTH = 16
1048
+
1049
+ elif OPERATING_SYSTEM=="Linux":
1050
+ LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running Linux Ubuntu 18.04+.\n' +\
1051
+ '• Application functionality for systems running other Linux platforms is not guaranteed.\n' +\
1052
+ '• Application functionality for older or budget systems is not guaranteed.\n\n'
1053
+ FONT_SIZE_F1 = 10
1054
+ FONT_SIZE_F2 = 8
1055
+ FONT_SIZE_F3 = 9
1056
+ FONT_SIZE_0 = 7
1057
+ FONT_SIZE_1 = 8
1058
+ FONT_SIZE_2 = 9
1059
+ FONT_SIZE_3 = 10
1060
+ FONT_SIZE_4 = 11
1061
+ FONT_SIZE_5 = 12
1062
+ FONT_SIZE_6 = 15
1063
+ HELP_HINT_CHECKBOX_WIDTH = 13
1064
+ MDX_CHECKBOXS_WIDTH = 14
1065
+ VR_CHECKBOXS_WIDTH = 16
1066
+ ENSEMBLE_CHECKBOXS_WIDTH = 25
1067
+ DEMUCS_CHECKBOXS_WIDTH = 18
1068
+ DEMUCS_PRE_CHECKBOXS_WIDTH = 27
1069
+ GEN_SETTINGS_WIDTH = 17
1070
+ MENU_COMBOBOX_WIDTH = 19
1071
+
1072
+ elif OPERATING_SYSTEM=="Windows":
1073
+ LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running Windows 10 or higher.\n' +\
1074
+ '• Application functionality for systems running Windows 7 or lower is not guaranteed.\n' +\
1075
+ '• Application functionality for Intel Pentium & Celeron CPUs systems is not guaranteed.\n\n'
1076
+ FONT_SIZE_F1 = 10
1077
+ FONT_SIZE_F2 = 8
1078
+ FONT_SIZE_F3 = 9
1079
+ FONT_SIZE_0 = 7
1080
+ FONT_SIZE_1 = 8
1081
+ FONT_SIZE_2 = 9
1082
+ FONT_SIZE_3 = 10
1083
+ FONT_SIZE_4 = 11
1084
+ FONT_SIZE_5 = 12
1085
+ FONT_SIZE_6 = 15
1086
+ HELP_HINT_CHECKBOX_WIDTH = 16
1087
+ MDX_CHECKBOXS_WIDTH = 16
1088
+ VR_CHECKBOXS_WIDTH = 16
1089
+ ENSEMBLE_CHECKBOXS_WIDTH = 25
1090
+ DEMUCS_CHECKBOXS_WIDTH = 18
1091
+ DEMUCS_PRE_CHECKBOXS_WIDTH = 27
1092
+ GEN_SETTINGS_WIDTH = 23
1093
+ MENU_COMBOBOX_WIDTH = 19
1094
+
1095
+
1096
+ LICENSE_TEXT = lambda a, p:f'Current Application Version: Ultimate Vocal Remover {a}\n' +\
1097
+ f'Current Patch Version: {p}\n\n' +\
1098
+ 'Copyright (c) 2022 Ultimate Vocal Remover\n\n' +\
1099
+ 'UVR is free and open-source, but MIT licensed. Please credit us if you use our\n' +\
1100
+ f'models or code for projects unrelated to UVR.\n\n{LICENSE_OS_SPECIFIC_TEXT}' +\
1101
+ 'This bundle contains the UVR interface, Python, PyTorch, and other\n' +\
1102
+ 'dependencies needed to run the application effectively.\n\n' +\
1103
+ 'Website Links: This application, System or Service(s) may contain links to\n' +\
1104
+ 'other websites and downloads, and they are solely provided to you as an\n' +\
1105
+ 'additional convenience. You understand and acknowledge that by clicking\n' +\
1106
+ 'or activating such links you are accessing a site or service outside of\n' +\
1107
+ 'this application, and that we do not screen, review, approve, or otherwise\n' +\
1108
+ 'endorse any content or information contained in these linked websites.\n' +\
1109
+ 'You acknowledge and agree that we, our affiliates and partners are not\n' +\
1110
+ 'responsible for the contents of any of these linked websites, including\n' +\
1111
+ 'the accuracy or availability of information provided by the linked websites,\n' +\
1112
+ 'and we make no representations or warranties regarding your use of\n' +\
1113
+ 'the linked websites.\n\n' +\
1114
+ 'This application is MIT Licensed\n\n' +\
1115
+ 'Permission is hereby granted, free of charge, to any person obtaining a copy\n' +\
1116
+ 'of this software and associated documentation files (the "Software"), to deal\n' +\
1117
+ 'in the Software without restriction, including without limitation the rights\n' +\
1118
+ 'to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n' +\
1119
+ 'copies of the Software, and to permit persons to whom the Software is\n' +\
1120
+ 'furnished to do so, subject to the following conditions:\n\n' +\
1121
+ 'The above copyright notice and this permission notice shall be included in all\n' +\
1122
+ 'copies or substantial portions of the Software.\n\n' +\
1123
+ 'THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n' +\
1124
+ 'IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n' +\
1125
+ 'FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n' +\
1126
+ 'AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n' +\
1127
+ 'LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n' +\
1128
+ 'OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n' +\
1129
+ 'SOFTWARE.'
1130
+
1131
+ CHANGE_LOG_HEADER = lambda patch:f"Patch Version:\n\n{patch}"
1132
+
1133
+ #DND CONSTS
1134
+
1135
+ MAC_DND_CHECK = ('/Users/',
1136
+ '/Applications/',
1137
+ '/Library/',
1138
+ '/System/')
1139
+ LINUX_DND_CHECK = ('/home/',
1140
+ '/usr/')
1141
+ WINDOWS_DND_CHECK = ('A:', 'B:', 'C:', 'D:', 'E:', 'F:', 'G:', 'H:', 'I:', 'J:', 'K:', 'L:', 'M:', 'N:', 'O:', 'P:', 'Q:', 'R:', 'S:', 'T:', 'U:', 'V:', 'W:', 'X:', 'Y:', 'Z:')
1142
+
1143
+ WOOD_INST_MODEL_HASH = '0ec76fd9e65f81d8b4fbd13af4826ed8'
1144
+ WOOD_INST_PARAMS = {
1145
+ "vr_model_param": "4band_v3",
1146
+ "primary_stem": NO_WIND_INST_STEM
1147
+ }
gui_data/error_handling.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import traceback
3
+
4
+ CUDA_MEMORY_ERROR = "CUDA out of memory"
5
+ CUDA_RUNTIME_ERROR = "CUDNN error executing cudnnSetTensorNdDescriptor"
6
+ DEMUCS_MODEL_MISSING_ERROR = "is neither a single pre-trained model or a bag of models."
7
+ ENSEMBLE_MISSING_MODEL_ERROR = "local variable \'enseExport\' referenced before assignment"
8
+ FFMPEG_MISSING_ERROR = """audioread\__init__.py", line 116, in audio_open"""
9
+ FILE_MISSING_ERROR = "FileNotFoundError"
10
+ MDX_MEMORY_ERROR = "onnxruntime::CudaCall CUDA failure 2: out of memory"
11
+ MDX_MODEL_MISSING = "[ONNXRuntimeError] : 3 : NO_SUCHFILE"
12
+ MDX_MODEL_SETTINGS_ERROR = "Got invalid dimensions for input"
13
+ MDX_RUNTIME_ERROR = "onnxruntime::BFCArena::AllocateRawInternal"
14
+ MODULE_ERROR = "ModuleNotFoundError"
15
+ WINDOW_SIZE_ERROR = "h1_shape[3] must be greater than h2_shape[3]"
16
+ SF_WRITE_ERROR = "sf.write"
17
+ SYSTEM_MEMORY_ERROR = "DefaultCPUAllocator: not enough memory"
18
+ MISSING_MODEL_ERROR = "'NoneType\' object has no attribute \'model_basename\'"
19
+ ARRAY_SIZE_ERROR = "ValueError: \"array is too big; `arr.size * arr.dtype.itemsize` is larger than the maximum possible size.\""
20
+ GPU_INCOMPATIBLE_ERROR = "no kernel image is available for execution on the device"
21
+
22
+ CONTACT_DEV = 'If this error persists, please contact the developers with the error details.'
23
+
24
+ ERROR_MAPPER = {
25
+ CUDA_MEMORY_ERROR:
26
+ ('The application was unable to allocate enough GPU memory to use this model. ' +
27
+ 'Please close any GPU intensive applications and try again.\n' +
28
+ 'If the error persists, your GPU might not be supported.') ,
29
+ CUDA_RUNTIME_ERROR:
30
+ (f'Your PC cannot process this audio file with the chunk size selected. Please lower the chunk size and try again.\n\n{CONTACT_DEV}'),
31
+ DEMUCS_MODEL_MISSING_ERROR:
32
+ ('The selected Demucs model is missing. ' +
33
+ 'Please download the model or make sure it is in the correct directory.'),
34
+ ENSEMBLE_MISSING_MODEL_ERROR:
35
+ ('The application was unable to locate a model you selected for this ensemble.\n\n' +
36
+ 'Please do the following to use all compatible models:\n\n1. Navigate to the \"Updates\" tab in the Help Guide.\n2. Download and install the model expansion pack.\n3. Then try again.\n\n' +
37
+ 'If the error persists, please verify all models are present.'),
38
+ FFMPEG_MISSING_ERROR:
39
+ ('The input file type is not supported or FFmpeg is missing. Please select a file type supported by FFmpeg and try again. ' +
40
+ 'If FFmpeg is missing or not installed, you will only be able to process \".wav\" files until it is available on this system. ' +
41
+ f'See the \"More Info\" tab in the Help Guide.\n\n{CONTACT_DEV}'),
42
+ FILE_MISSING_ERROR:
43
+ (f'Missing file error raised. Please address the error and try again.\n\n{CONTACT_DEV}'),
44
+ MDX_MEMORY_ERROR:
45
+ ('The application was unable to allocate enough GPU memory to use this model.\n\n' +
46
+ 'Please do the following:\n\n1. Close any GPU intensive applications.\n2. Lower the set chunk size.\n3. Then try again.\n\n' +
47
+ 'If the error persists, your GPU might not be supported.'),
48
+ MDX_MODEL_MISSING:
49
+ ('The application could not detect this MDX-Net model on your system. ' +
50
+ 'Please make sure all the models are present in the correct directory.\n\n' +
51
+ 'If the error persists, please reinstall application or contact the developers.'),
52
+ MDX_RUNTIME_ERROR:
53
+ ('The application was unable to allocate enough GPU memory to use this model.\n\n' +
54
+ 'Please do the following:\n\n1. Close any GPU intensive applications.\n2. Lower the set chunk size.\n3. Then try again.\n\n' +
55
+ 'If the error persists, your GPU might not be supported.'),
56
+ WINDOW_SIZE_ERROR:
57
+ ('Invalid window size.\n\n' +
58
+ 'The chosen window size is likely not compatible with this model. Please select a different size and try again.'),
59
+ SF_WRITE_ERROR:
60
+ ('Could not write audio file.\n\n' +
61
+ 'This could be due to one of the following:\n\n1. Low storage on target device.\n2. The export directory no longer exists.\n3. A system permissions issue.'),
62
+ SYSTEM_MEMORY_ERROR:
63
+ ('The application was unable to allocate enough system memory to use this model.\n\n' +
64
+ 'Please do the following:\n\n1. Restart this application.\n2. Ensure any CPU intensive applications are closed.\n3. Then try again.\n\n' +
65
+ 'Please Note: Intel Pentium and Intel Celeron processors do not work well with this application.\n\n' +
66
+ 'If the error persists, the system may not have enough RAM, or your CPU might not be supported.'),
67
+ MISSING_MODEL_ERROR:
68
+ ('Model Missing: The application was unable to locate the chosen model.\n\n' +
69
+ 'If the error persists, please verify any selected models are present.'),
70
+ GPU_INCOMPATIBLE_ERROR:
71
+ ('This process is not compatible with your GPU.\n\n' +
72
+ 'Please uncheck \"GPU Conversion\" and try again'),
73
+ ARRAY_SIZE_ERROR:
74
+ ('The application was not able to process the given audiofile. Please convert the audiofile to another format and try again.'),
75
+ }
76
+
77
+ def error_text(process_method, exception):
78
+
79
+ traceback_text = ''.join(traceback.format_tb(exception.__traceback__))
80
+ message = f'{type(exception).__name__}: "{exception}"\nTraceback Error: "\n{traceback_text}"\n'
81
+ error_message = f'\n\nRaw Error Details:\n\n{message}\nError Time Stamp [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]\n'
82
+ process = f'Last Error Received:\n\nProcess: {process_method}\n\n'
83
+
84
+ for error_type, full_text in ERROR_MAPPER.items():
85
+ if error_type in message:
86
+ final_message = full_text
87
+ break
88
+ else:
89
+ final_message = (CONTACT_DEV)
90
+
91
+ return f"{process}{final_message}{error_message}"
92
+
93
+ def error_dialouge(exception):
94
+
95
+ error_name = f'{type(exception).__name__}'
96
+ traceback_text = ''.join(traceback.format_tb(exception.__traceback__))
97
+ message = f'{error_name}: "{exception}"\n{traceback_text}"'
98
+
99
+ for error_type, full_text in ERROR_MAPPER.items():
100
+ if error_type in message:
101
+ final_message = full_text
102
+ break
103
+ else:
104
+ final_message = (f'{error_name}: {exception}\n\n{CONTACT_DEV}')
105
+
106
+ return final_message
gui_data/old_data_check.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+ def file_check(original_dir, new_dir):
5
+
6
+ if os.path.isdir(original_dir):
7
+ for file in os.listdir(original_dir):
8
+ shutil.move(os.path.join(original_dir, file), os.path.join(new_dir, file))
9
+
10
+ if len(os.listdir(original_dir)) == 0:
11
+ shutil.rmtree(original_dir)
12
+
13
+ def remove_unneeded_yamls(demucs_dir):
14
+
15
+ for file in os.listdir(demucs_dir):
16
+ if file.endswith('.yaml'):
17
+ if os.path.isfile(os.path.join(demucs_dir, file)):
18
+ os.remove(os.path.join(demucs_dir, file))
19
+
20
+ def remove_temps(remove_dir):
21
+
22
+ if os.path.isdir(remove_dir):
23
+ try:
24
+ shutil.rmtree(remove_dir)
25
+ except Exception as e:
26
+ print(e)
27
+
lib_v5/mdxnet.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from pytorch_lightning import LightningModule
6
+ from .modules import TFC_TDF
7
+
8
+ dim_s = 4
9
+
10
+ class AbstractMDXNet(LightningModule):
11
+ __metaclass__ = ABCMeta
12
+
13
+ def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
14
+ super().__init__()
15
+ self.target_name = target_name
16
+ self.lr = lr
17
+ self.optimizer = optimizer
18
+ self.dim_c = dim_c
19
+ self.dim_f = dim_f
20
+ self.dim_t = dim_t
21
+ self.n_fft = n_fft
22
+ self.n_bins = n_fft // 2 + 1
23
+ self.hop_length = hop_length
24
+ self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
25
+ self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)
26
+
27
+ def configure_optimizers(self):
28
+ if self.optimizer == 'rmsprop':
29
+ return torch.optim.RMSprop(self.parameters(), self.lr)
30
+
31
+ if self.optimizer == 'adamw':
32
+ return torch.optim.AdamW(self.parameters(), self.lr)
33
+
34
+ class ConvTDFNet(AbstractMDXNet):
35
+ def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
36
+ num_blocks, l, g, k, bn, bias, overlap):
37
+
38
+ super(ConvTDFNet, self).__init__(
39
+ target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
40
+ self.save_hyperparameters()
41
+
42
+ self.num_blocks = num_blocks
43
+ self.l = l
44
+ self.g = g
45
+ self.k = k
46
+ self.bn = bn
47
+ self.bias = bias
48
+
49
+ if optimizer == 'rmsprop':
50
+ norm = nn.BatchNorm2d
51
+
52
+ if optimizer == 'adamw':
53
+ norm = lambda input:nn.GroupNorm(2, input)
54
+
55
+ self.n = num_blocks // 2
56
+ scale = (2, 2)
57
+
58
+ self.first_conv = nn.Sequential(
59
+ nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
60
+ norm(g),
61
+ nn.ReLU(),
62
+ )
63
+
64
+ f = self.dim_f
65
+ c = g
66
+ self.encoding_blocks = nn.ModuleList()
67
+ self.ds = nn.ModuleList()
68
+ for i in range(self.n):
69
+ self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
70
+ self.ds.append(
71
+ nn.Sequential(
72
+ nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
73
+ norm(c + g),
74
+ nn.ReLU()
75
+ )
76
+ )
77
+ f = f // 2
78
+ c += g
79
+
80
+ self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)
81
+
82
+ self.decoding_blocks = nn.ModuleList()
83
+ self.us = nn.ModuleList()
84
+ for i in range(self.n):
85
+ self.us.append(
86
+ nn.Sequential(
87
+ nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
88
+ norm(c - g),
89
+ nn.ReLU()
90
+ )
91
+ )
92
+ f = f * 2
93
+ c -= g
94
+
95
+ self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
96
+
97
+ self.final_conv = nn.Sequential(
98
+ nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
99
+ )
100
+
101
+ def forward(self, x):
102
+
103
+ x = self.first_conv(x)
104
+
105
+ x = x.transpose(-1, -2)
106
+
107
+ ds_outputs = []
108
+ for i in range(self.n):
109
+ x = self.encoding_blocks[i](x)
110
+ ds_outputs.append(x)
111
+ x = self.ds[i](x)
112
+
113
+ x = self.bottleneck_block(x)
114
+
115
+ for i in range(self.n):
116
+ x = self.us[i](x)
117
+ x *= ds_outputs[-i - 1]
118
+ x = self.decoding_blocks[i](x)
119
+
120
+ x = x.transpose(-1, -2)
121
+
122
+ x = self.final_conv(x)
123
+
124
+ return x
125
+
126
+ class Mixer(nn.Module):
127
+ def __init__(self, device, mixer_path):
128
+
129
+ super(Mixer, self).__init__()
130
+
131
+ self.linear = nn.Linear((dim_s+1)*2, dim_s*2, bias=False)
132
+
133
+ self.load_state_dict(
134
+ torch.load(mixer_path, map_location=device)
135
+ )
136
+
137
+ def forward(self, x):
138
+ x = x.reshape(1,(dim_s+1)*2,-1).transpose(-1,-2)
139
+ x = self.linear(x)
140
+ return x.transpose(-1,-2).reshape(dim_s,2,-1)
lib_v5/mixer.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35121bf49b892b112b073b605c86efa1812bc37c9603c0fbc1726f4320c2aa91
3
+ size 132
lib_v5/modules.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class TFC(nn.Module):
6
+ def __init__(self, c, l, k, norm):
7
+ super(TFC, self).__init__()
8
+
9
+ self.H = nn.ModuleList()
10
+ for i in range(l):
11
+ self.H.append(
12
+ nn.Sequential(
13
+ nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
14
+ norm(c),
15
+ nn.ReLU(),
16
+ )
17
+ )
18
+
19
+ def forward(self, x):
20
+ for h in self.H:
21
+ x = h(x)
22
+ return x
23
+
24
+
25
+ class DenseTFC(nn.Module):
26
+ def __init__(self, c, l, k, norm):
27
+ super(DenseTFC, self).__init__()
28
+
29
+ self.conv = nn.ModuleList()
30
+ for i in range(l):
31
+ self.conv.append(
32
+ nn.Sequential(
33
+ nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
34
+ norm(c),
35
+ nn.ReLU(),
36
+ )
37
+ )
38
+
39
+ def forward(self, x):
40
+ for layer in self.conv[:-1]:
41
+ x = torch.cat([layer(x), x], 1)
42
+ return self.conv[-1](x)
43
+
44
+
45
+ class TFC_TDF(nn.Module):
46
+ def __init__(self, c, l, f, k, bn, dense=False, bias=True, norm=nn.BatchNorm2d):
47
+
48
+ super(TFC_TDF, self).__init__()
49
+
50
+ self.use_tdf = bn is not None
51
+
52
+ self.tfc = DenseTFC(c, l, k, norm) if dense else TFC(c, l, k, norm)
53
+
54
+ if self.use_tdf:
55
+ if bn == 0:
56
+ self.tdf = nn.Sequential(
57
+ nn.Linear(f, f, bias=bias),
58
+ norm(c),
59
+ nn.ReLU()
60
+ )
61
+ else:
62
+ self.tdf = nn.Sequential(
63
+ nn.Linear(f, f // bn, bias=bias),
64
+ norm(c),
65
+ nn.ReLU(),
66
+ nn.Linear(f // bn, f, bias=bias),
67
+ norm(c),
68
+ nn.ReLU()
69
+ )
70
+
71
+ def forward(self, x):
72
+ x = self.tfc(x)
73
+ return x + self.tdf(x) if self.use_tdf else x
74
+
lib_v5/pyrb.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import tempfile
4
+ import six
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import sys
8
+
9
+ if getattr(sys, 'frozen', False):
10
+ BASE_PATH_RUB = sys._MEIPASS
11
+ else:
12
+ BASE_PATH_RUB = os.path.dirname(os.path.abspath(__file__))
13
+
14
+ __all__ = ['time_stretch', 'pitch_shift']
15
+
16
+ __RUBBERBAND_UTIL = os.path.join(BASE_PATH_RUB, 'rubberband')
17
+
18
+ if six.PY2:
19
+ DEVNULL = open(os.devnull, 'w')
20
+ else:
21
+ DEVNULL = subprocess.DEVNULL
22
+
23
+ def __rubberband(y, sr, **kwargs):
24
+
25
+ assert sr > 0
26
+
27
+ # Get the input and output tempfile
28
+ fd, infile = tempfile.mkstemp(suffix='.wav')
29
+ os.close(fd)
30
+ fd, outfile = tempfile.mkstemp(suffix='.wav')
31
+ os.close(fd)
32
+
33
+ # dump the audio
34
+ sf.write(infile, y, sr)
35
+
36
+ try:
37
+ # Execute rubberband
38
+ arguments = [__RUBBERBAND_UTIL, '-q']
39
+
40
+ for key, value in six.iteritems(kwargs):
41
+ arguments.append(str(key))
42
+ arguments.append(str(value))
43
+
44
+ arguments.extend([infile, outfile])
45
+
46
+ subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL)
47
+
48
+ # Load the processed audio.
49
+ y_out, _ = sf.read(outfile, always_2d=True)
50
+
51
+ # make sure that output dimensions matches input
52
+ if y.ndim == 1:
53
+ y_out = np.squeeze(y_out)
54
+
55
+ except OSError as exc:
56
+ six.raise_from(RuntimeError('Failed to execute rubberband. '
57
+ 'Please verify that rubberband-cli '
58
+ 'is installed.'),
59
+ exc)
60
+
61
+ finally:
62
+ # Remove temp files
63
+ os.unlink(infile)
64
+ os.unlink(outfile)
65
+
66
+ return y_out
67
+
68
+ def time_stretch(y, sr, rate, rbargs=None):
69
+ if rate <= 0:
70
+ raise ValueError('rate must be strictly positive')
71
+
72
+ if rate == 1.0:
73
+ return y
74
+
75
+ if rbargs is None:
76
+ rbargs = dict()
77
+
78
+ rbargs.setdefault('--tempo', rate)
79
+
80
+ return __rubberband(y, sr, **rbargs)
81
+
82
+ def pitch_shift(y, sr, n_steps, rbargs=None):
83
+
84
+ if n_steps == 0:
85
+ return y
86
+
87
+ if rbargs is None:
88
+ rbargs = dict()
89
+
90
+ rbargs.setdefault('--pitch', n_steps)
91
+
92
+ return __rubberband(y, sr, **rbargs)
lib_v5/spec_utils.py ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import soundfile as sf
4
+ import math
5
+ import random
6
+ import math
7
+ import platform
8
+ import traceback
9
+ from . import pyrb
10
+ #cur
11
+ OPERATING_SYSTEM = platform.system()
12
+ SYSTEM_ARCH = platform.platform()
13
+ SYSTEM_PROC = platform.processor()
14
+ ARM = 'arm'
15
+
16
+ if OPERATING_SYSTEM == 'Windows':
17
+ from pyrubberband import pyrb
18
+ else:
19
+ from . import pyrb
20
+
21
+ if OPERATING_SYSTEM == 'Darwin':
22
+ wav_resolution = "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else "sinc_fastest"
23
+ else:
24
+ wav_resolution = "sinc_fastest"
25
+
26
+ MAX_SPEC = 'Max Spec'
27
+ MIN_SPEC = 'Min Spec'
28
+ AVERAGE = 'Average'
29
+
30
+ def crop_center(h1, h2):
31
+ h1_shape = h1.size()
32
+ h2_shape = h2.size()
33
+
34
+ if h1_shape[3] == h2_shape[3]:
35
+ return h1
36
+ elif h1_shape[3] < h2_shape[3]:
37
+ raise ValueError('h1_shape[3] must be greater than h2_shape[3]')
38
+
39
+ s_time = (h1_shape[3] - h2_shape[3]) // 2
40
+ e_time = s_time + h2_shape[3]
41
+ h1 = h1[:, :, :, s_time:e_time]
42
+
43
+ return h1
44
+
45
+ def preprocess(X_spec):
46
+ X_mag = np.abs(X_spec)
47
+ X_phase = np.angle(X_spec)
48
+
49
+ return X_mag, X_phase
50
+
51
+ def make_padding(width, cropsize, offset):
52
+ left = offset
53
+ roi_size = cropsize - offset * 2
54
+ if roi_size == 0:
55
+ roi_size = cropsize
56
+ right = roi_size - (width % roi_size) + left
57
+
58
+ return left, right, roi_size
59
+
60
+ def wave_to_spectrogram(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False):
61
+ if reverse:
62
+ wave_left = np.flip(np.asfortranarray(wave[0]))
63
+ wave_right = np.flip(np.asfortranarray(wave[1]))
64
+ elif mid_side:
65
+ wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
66
+ wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
67
+ elif mid_side_b2:
68
+ wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5))
69
+ wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5))
70
+ else:
71
+ wave_left = np.asfortranarray(wave[0])
72
+ wave_right = np.asfortranarray(wave[1])
73
+
74
+ spec_left = librosa.stft(wave_left, n_fft, hop_length=hop_length)
75
+ spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length)
76
+
77
+ spec = np.asfortranarray([spec_left, spec_right])
78
+
79
+ return spec
80
+
81
+ def wave_to_spectrogram_mt(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False):
82
+ import threading
83
+
84
+ if reverse:
85
+ wave_left = np.flip(np.asfortranarray(wave[0]))
86
+ wave_right = np.flip(np.asfortranarray(wave[1]))
87
+ elif mid_side:
88
+ wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
89
+ wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
90
+ elif mid_side_b2:
91
+ wave_left = np.asfortranarray(np.add(wave[1], wave[0] * .5))
92
+ wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * .5))
93
+ else:
94
+ wave_left = np.asfortranarray(wave[0])
95
+ wave_right = np.asfortranarray(wave[1])
96
+
97
+ def run_thread(**kwargs):
98
+ global spec_left
99
+ spec_left = librosa.stft(**kwargs)
100
+
101
+ thread = threading.Thread(target=run_thread, kwargs={'y': wave_left, 'n_fft': n_fft, 'hop_length': hop_length})
102
+ thread.start()
103
+ spec_right = librosa.stft(wave_right, n_fft, hop_length=hop_length)
104
+ thread.join()
105
+
106
+ spec = np.asfortranarray([spec_left, spec_right])
107
+
108
+ return spec
109
+
110
+ def normalize(wave, is_normalize=False):
111
+ """Save output music files"""
112
+ maxv = np.abs(wave).max()
113
+ if maxv > 1.0:
114
+ print(f"\nNormalization Set {is_normalize}: Input above threshold for clipping. Max:{maxv}")
115
+ if is_normalize:
116
+ print(f"The result was normalized.")
117
+ wave /= maxv
118
+ else:
119
+ print(f"The result was not normalized.")
120
+ else:
121
+ print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}")
122
+
123
+ return wave
124
+
125
+ def normalize_two_stem(wave, mix, is_normalize=False):
126
+ """Save output music files"""
127
+
128
+ maxv = np.abs(wave).max()
129
+ max_mix = np.abs(mix).max()
130
+
131
+ if maxv > 1.0:
132
+ print(f"\nNormalization Set {is_normalize}: Primary source above threshold for clipping. Max:{maxv}")
133
+ print(f"\nNormalization Set {is_normalize}: Mixture above threshold for clipping. Max:{max_mix}")
134
+ if is_normalize:
135
+ print(f"The result was normalized.")
136
+ wave /= maxv
137
+ mix /= maxv
138
+ else:
139
+ print(f"The result was not normalized.")
140
+ else:
141
+ print(f"\nNormalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}")
142
+
143
+
144
+ print(f"\nNormalization Set {is_normalize}: Primary source - Max:{np.abs(wave).max()}")
145
+ print(f"\nNormalization Set {is_normalize}: Mixture - Max:{np.abs(mix).max()}")
146
+
147
+ return wave, mix
148
+
149
+ def combine_spectrograms(specs, mp):
150
+ l = min([specs[i].shape[2] for i in specs])
151
+ spec_c = np.zeros(shape=(2, mp.param['bins'] + 1, l), dtype=np.complex64)
152
+ offset = 0
153
+ bands_n = len(mp.param['band'])
154
+
155
+ for d in range(1, bands_n + 1):
156
+ h = mp.param['band'][d]['crop_stop'] - mp.param['band'][d]['crop_start']
157
+ spec_c[:, offset:offset+h, :l] = specs[d][:, mp.param['band'][d]['crop_start']:mp.param['band'][d]['crop_stop'], :l]
158
+ offset += h
159
+
160
+ if offset > mp.param['bins']:
161
+ raise ValueError('Too much bins')
162
+
163
+ # lowpass fiter
164
+ if mp.param['pre_filter_start'] > 0: # and mp.param['band'][bands_n]['res_type'] in ['scipy', 'polyphase']:
165
+ if bands_n == 1:
166
+ spec_c = fft_lp_filter(spec_c, mp.param['pre_filter_start'], mp.param['pre_filter_stop'])
167
+ else:
168
+ gp = 1
169
+ for b in range(mp.param['pre_filter_start'] + 1, mp.param['pre_filter_stop']):
170
+ g = math.pow(10, -(b - mp.param['pre_filter_start']) * (3.5 - gp) / 20.0)
171
+ gp = g
172
+ spec_c[:, b, :] *= g
173
+
174
+ return np.asfortranarray(spec_c)
175
+
176
+ def spectrogram_to_image(spec, mode='magnitude'):
177
+ if mode == 'magnitude':
178
+ if np.iscomplexobj(spec):
179
+ y = np.abs(spec)
180
+ else:
181
+ y = spec
182
+ y = np.log10(y ** 2 + 1e-8)
183
+ elif mode == 'phase':
184
+ if np.iscomplexobj(spec):
185
+ y = np.angle(spec)
186
+ else:
187
+ y = spec
188
+
189
+ y -= y.min()
190
+ y *= 255 / y.max()
191
+ img = np.uint8(y)
192
+
193
+ if y.ndim == 3:
194
+ img = img.transpose(1, 2, 0)
195
+ img = np.concatenate([
196
+ np.max(img, axis=2, keepdims=True), img
197
+ ], axis=2)
198
+
199
+ return img
200
+
201
+ def reduce_vocal_aggressively(X, y, softmask):
202
+ v = X - y
203
+ y_mag_tmp = np.abs(y)
204
+ v_mag_tmp = np.abs(v)
205
+
206
+ v_mask = v_mag_tmp > y_mag_tmp
207
+ y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf)
208
+
209
+ return y_mag * np.exp(1.j * np.angle(y))
210
+
211
+ def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32):
212
+ mask = y_mask
213
+
214
+ try:
215
+ if min_range < fade_size * 2:
216
+ raise ValueError('min_range must be >= fade_size * 2')
217
+
218
+ idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
219
+ start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
220
+ end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
221
+ artifact_idx = np.where(end_idx - start_idx > min_range)[0]
222
+ weight = np.zeros_like(y_mask)
223
+ if len(artifact_idx) > 0:
224
+ start_idx = start_idx[artifact_idx]
225
+ end_idx = end_idx[artifact_idx]
226
+ old_e = None
227
+ for s, e in zip(start_idx, end_idx):
228
+ if old_e is not None and s - old_e < fade_size:
229
+ s = old_e - fade_size * 2
230
+
231
+ if s != 0:
232
+ weight[:, :, s:s + fade_size] = np.linspace(0, 1, fade_size)
233
+ else:
234
+ s -= fade_size
235
+
236
+ if e != y_mask.shape[2]:
237
+ weight[:, :, e - fade_size:e] = np.linspace(1, 0, fade_size)
238
+ else:
239
+ e += fade_size
240
+
241
+ weight[:, :, s + fade_size:e - fade_size] = 1
242
+ old_e = e
243
+
244
+ v_mask = 1 - y_mask
245
+ y_mask += weight * v_mask
246
+
247
+ mask = y_mask
248
+ except Exception as e:
249
+ error_name = f'{type(e).__name__}'
250
+ traceback_text = ''.join(traceback.format_tb(e.__traceback__))
251
+ message = f'{error_name}: "{e}"\n{traceback_text}"'
252
+ print('Post Process Failed: ', message)
253
+
254
+
255
+ return mask
256
+
257
+ def align_wave_head_and_tail(a, b):
258
+ l = min([a[0].size, b[0].size])
259
+
260
+ return a[:l,:l], b[:l,:l]
261
+
262
+ def spectrogram_to_wave(spec, hop_length, mid_side, mid_side_b2, reverse, clamp=False):
263
+ spec_left = np.asfortranarray(spec[0])
264
+ spec_right = np.asfortranarray(spec[1])
265
+
266
+ wave_left = librosa.istft(spec_left, hop_length=hop_length)
267
+ wave_right = librosa.istft(spec_right, hop_length=hop_length)
268
+
269
+ if reverse:
270
+ return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
271
+ elif mid_side:
272
+ return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
273
+ elif mid_side_b2:
274
+ return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)])
275
+ else:
276
+ return np.asfortranarray([wave_left, wave_right])
277
+
278
+ def spectrogram_to_wave_mt(spec, hop_length, mid_side, reverse, mid_side_b2):
279
+ import threading
280
+
281
+ spec_left = np.asfortranarray(spec[0])
282
+ spec_right = np.asfortranarray(spec[1])
283
+
284
+ def run_thread(**kwargs):
285
+ global wave_left
286
+ wave_left = librosa.istft(**kwargs)
287
+
288
+ thread = threading.Thread(target=run_thread, kwargs={'stft_matrix': spec_left, 'hop_length': hop_length})
289
+ thread.start()
290
+ wave_right = librosa.istft(spec_right, hop_length=hop_length)
291
+ thread.join()
292
+
293
+ if reverse:
294
+ return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
295
+ elif mid_side:
296
+ return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
297
+ elif mid_side_b2:
298
+ return np.asfortranarray([np.add(wave_right / 1.25, .4 * wave_left), np.subtract(wave_left / 1.25, .4 * wave_right)])
299
+ else:
300
+ return np.asfortranarray([wave_left, wave_right])
301
+
302
+ def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None):
303
+ bands_n = len(mp.param['band'])
304
+ offset = 0
305
+
306
+ for d in range(1, bands_n + 1):
307
+ bp = mp.param['band'][d]
308
+ spec_s = np.ndarray(shape=(2, bp['n_fft'] // 2 + 1, spec_m.shape[2]), dtype=complex)
309
+ h = bp['crop_stop'] - bp['crop_start']
310
+ spec_s[:, bp['crop_start']:bp['crop_stop'], :] = spec_m[:, offset:offset+h, :]
311
+
312
+ offset += h
313
+ if d == bands_n: # higher
314
+ if extra_bins_h: # if --high_end_process bypass
315
+ max_bin = bp['n_fft'] // 2
316
+ spec_s[:, max_bin-extra_bins_h:max_bin, :] = extra_bins[:, :extra_bins_h, :]
317
+ if bp['hpf_start'] > 0:
318
+ spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1)
319
+ if bands_n == 1:
320
+ wave = spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse'])
321
+ else:
322
+ wave = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']))
323
+ else:
324
+ sr = mp.param['band'][d+1]['sr']
325
+ if d == 1: # lower
326
+ spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop'])
327
+ wave = librosa.resample(spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']), bp['sr'], sr, res_type=wav_resolution)
328
+ else: # mid
329
+ spec_s = fft_hp_filter(spec_s, bp['hpf_start'], bp['hpf_stop'] - 1)
330
+ spec_s = fft_lp_filter(spec_s, bp['lpf_start'], bp['lpf_stop'])
331
+ wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp['hl'], mp.param['mid_side'], mp.param['mid_side_b2'], mp.param['reverse']))
332
+ wave = librosa.resample(wave2, bp['sr'], sr, res_type=wav_resolution)
333
+
334
+ return wave
335
+
336
+ def fft_lp_filter(spec, bin_start, bin_stop):
337
+ g = 1.0
338
+ for b in range(bin_start, bin_stop):
339
+ g -= 1 / (bin_stop - bin_start)
340
+ spec[:, b, :] = g * spec[:, b, :]
341
+
342
+ spec[:, bin_stop:, :] *= 0
343
+
344
+ return spec
345
+
346
+ def fft_hp_filter(spec, bin_start, bin_stop):
347
+ g = 1.0
348
+ for b in range(bin_start, bin_stop, -1):
349
+ g -= 1 / (bin_start - bin_stop)
350
+ spec[:, b, :] = g * spec[:, b, :]
351
+
352
+ spec[:, 0:bin_stop+1, :] *= 0
353
+
354
+ return spec
355
+
356
+ def mirroring(a, spec_m, input_high_end, mp):
357
+ if 'mirroring' == a:
358
+ mirror = np.flip(np.abs(spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :]), 1)
359
+ mirror = mirror * np.exp(1.j * np.angle(input_high_end))
360
+
361
+ return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)
362
+
363
+ if 'mirroring2' == a:
364
+ mirror = np.flip(np.abs(spec_m[:, mp.param['pre_filter_start']-10-input_high_end.shape[1]:mp.param['pre_filter_start']-10, :]), 1)
365
+ mi = np.multiply(mirror, input_high_end * 1.7)
366
+
367
+ return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi)
368
+
369
+ def adjust_aggr(mask, is_non_accom_stem, aggressiveness):
370
+ aggr = aggressiveness['value']
371
+
372
+ if aggr != 0:
373
+ if is_non_accom_stem:
374
+ aggr = 1 - aggr
375
+
376
+ aggr = [aggr, aggr]
377
+
378
+ if aggressiveness['aggr_correction'] is not None:
379
+ aggr[0] += aggressiveness['aggr_correction']['left']
380
+ aggr[1] += aggressiveness['aggr_correction']['right']
381
+
382
+ for ch in range(2):
383
+ mask[ch, :aggressiveness['split_bin']] = np.power(mask[ch, :aggressiveness['split_bin']], 1 + aggr[ch] / 3)
384
+ mask[ch, aggressiveness['split_bin']:] = np.power(mask[ch, aggressiveness['split_bin']:], 1 + aggr[ch])
385
+
386
+ # if is_non_accom_stem:
387
+ # mask = (1.0 - mask)
388
+
389
+ return mask
390
+
391
+ def stft(wave, nfft, hl):
392
+ wave_left = np.asfortranarray(wave[0])
393
+ wave_right = np.asfortranarray(wave[1])
394
+ spec_left = librosa.stft(wave_left, nfft, hop_length=hl)
395
+ spec_right = librosa.stft(wave_right, nfft, hop_length=hl)
396
+ spec = np.asfortranarray([spec_left, spec_right])
397
+
398
+ return spec
399
+
400
+ def istft(spec, hl):
401
+ spec_left = np.asfortranarray(spec[0])
402
+ spec_right = np.asfortranarray(spec[1])
403
+ wave_left = librosa.istft(spec_left, hop_length=hl)
404
+ wave_right = librosa.istft(spec_right, hop_length=hl)
405
+ wave = np.asfortranarray([wave_left, wave_right])
406
+
407
+ return wave
408
+
409
+ def spec_effects(wave, algorithm='Default', value=None):
410
+ spec = [stft(wave[0],2048,1024), stft(wave[1],2048,1024)]
411
+ if algorithm == 'Min_Mag':
412
+ v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
413
+ wave = istft(v_spec_m,1024)
414
+ elif algorithm == 'Max_Mag':
415
+ v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
416
+ wave = istft(v_spec_m,1024)
417
+ elif algorithm == 'Default':
418
+ wave = (wave[1] * value) + (wave[0] * (1-value))
419
+ elif algorithm == 'Invert_p':
420
+ X_mag = np.abs(spec[0])
421
+ y_mag = np.abs(spec[1])
422
+ max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
423
+ v_spec = spec[1] - max_mag * np.exp(1.j * np.angle(spec[0]))
424
+ wave = istft(v_spec,1024)
425
+
426
+ return wave
427
+
428
+ def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024):
429
+ wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length)
430
+
431
+ if wave.ndim == 1:
432
+ wave = np.asfortranarray([wave,wave])
433
+
434
+ return wave
435
+
436
+ def wave_to_spectrogram_no_mp(wave):
437
+
438
+ spec = librosa.stft(wave, n_fft=2048, hop_length=1024)
439
+
440
+ if spec.ndim == 1:
441
+ spec = np.asfortranarray([spec,spec])
442
+
443
+ return spec
444
+
445
+ def invert_audio(specs, invert_p=True):
446
+
447
+ ln = min([specs[0].shape[2], specs[1].shape[2]])
448
+ specs[0] = specs[0][:,:,:ln]
449
+ specs[1] = specs[1][:,:,:ln]
450
+
451
+ if invert_p:
452
+ X_mag = np.abs(specs[0])
453
+ y_mag = np.abs(specs[1])
454
+ max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
455
+ v_spec = specs[1] - max_mag * np.exp(1.j * np.angle(specs[0]))
456
+ else:
457
+ specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
458
+ v_spec = specs[0] - specs[1]
459
+
460
+ return v_spec
461
+
462
+ def invert_stem(mixture, stem):
463
+
464
+ mixture = wave_to_spectrogram_no_mp(mixture)
465
+ stem = wave_to_spectrogram_no_mp(stem)
466
+ output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem]))
467
+
468
+ return -output.T
469
+
470
+ def ensembling(a, specs):
471
+ for i in range(1, len(specs)):
472
+ if i == 1:
473
+ spec = specs[0]
474
+
475
+ ln = min([spec.shape[2], specs[i].shape[2]])
476
+ spec = spec[:,:,:ln]
477
+ specs[i] = specs[i][:,:,:ln]
478
+
479
+ if MIN_SPEC == a:
480
+ spec = np.where(np.abs(specs[i]) <= np.abs(spec), specs[i], spec)
481
+ if MAX_SPEC == a:
482
+ spec = np.where(np.abs(specs[i]) >= np.abs(spec), specs[i], spec)
483
+ if AVERAGE == a:
484
+ spec = np.where(np.abs(specs[i]) == np.abs(spec), specs[i], spec)
485
+
486
+ return spec
487
+
488
+ def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path):
489
+
490
+ wavs_ = []
491
+
492
+ if algorithm == AVERAGE:
493
+ output = average_audio(audio_input)
494
+ samplerate = 44100
495
+ else:
496
+ specs = []
497
+
498
+ for i in range(len(audio_input)):
499
+ wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
500
+ wavs_.append(wave)
501
+ spec = wave_to_spectrogram_no_mp(wave)
502
+ specs.append(spec)
503
+
504
+ wave_shapes = [w.shape[1] for w in wavs_]
505
+ target_shape = wavs_[wave_shapes.index(max(wave_shapes))]
506
+
507
+ output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs))
508
+ output = to_shape(output, target_shape.shape)
509
+
510
+ sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set)
511
+
512
+ def to_shape(x, target_shape):
513
+ padding_list = []
514
+ for x_dim, target_dim in zip(x.shape, target_shape):
515
+ pad_value = (target_dim - x_dim)
516
+ pad_tuple = ((0, pad_value))
517
+ padding_list.append(pad_tuple)
518
+
519
+ return np.pad(x, tuple(padding_list), mode='constant')
520
+
521
+ def to_shape_minimize(x: np.ndarray, target_shape):
522
+
523
+ padding_list = []
524
+ for x_dim, target_dim in zip(x.shape, target_shape):
525
+ pad_value = (target_dim - x_dim)
526
+ pad_tuple = ((0, pad_value))
527
+ padding_list.append(pad_tuple)
528
+
529
+ return np.pad(x, tuple(padding_list), mode='constant')
530
+
531
+ def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False):
532
+
533
+ wav, sr = librosa.load(audio_file, sr=44100, mono=False)
534
+
535
+ if wav.ndim == 1:
536
+ wav = np.asfortranarray([wav,wav])
537
+
538
+ if is_pitch:
539
+ wav_1 = pyrb.pitch_shift(wav[0], sr, rate, rbargs=None)
540
+ wav_2 = pyrb.pitch_shift(wav[1], sr, rate, rbargs=None)
541
+ else:
542
+ wav_1 = pyrb.time_stretch(wav[0], sr, rate, rbargs=None)
543
+ wav_2 = pyrb.time_stretch(wav[1], sr, rate, rbargs=None)
544
+
545
+ if wav_1.shape > wav_2.shape:
546
+ wav_2 = to_shape(wav_2, wav_1.shape)
547
+ if wav_1.shape < wav_2.shape:
548
+ wav_1 = to_shape(wav_1, wav_2.shape)
549
+
550
+ wav_mix = np.asfortranarray([wav_1, wav_2])
551
+
552
+ sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set)
553
+ save_format(export_path)
554
+
555
+ def average_audio(audio):
556
+
557
+ waves = []
558
+ wave_shapes = []
559
+ final_waves = []
560
+
561
+ for i in range(len(audio)):
562
+ wave = librosa.load(audio[i], sr=44100, mono=False)
563
+ waves.append(wave[0])
564
+ wave_shapes.append(wave[0].shape[1])
565
+
566
+ wave_shapes_index = wave_shapes.index(max(wave_shapes))
567
+ target_shape = waves[wave_shapes_index]
568
+ waves.pop(wave_shapes_index)
569
+ final_waves.append(target_shape)
570
+
571
+ for n_array in waves:
572
+ wav_target = to_shape(n_array, target_shape.shape)
573
+ final_waves.append(wav_target)
574
+
575
+ waves = sum(final_waves)
576
+ waves = waves/len(audio)
577
+
578
+ return waves
579
+
580
+ def average_dual_sources(wav_1, wav_2, value):
581
+
582
+ if wav_1.shape > wav_2.shape:
583
+ wav_2 = to_shape(wav_2, wav_1.shape)
584
+ if wav_1.shape < wav_2.shape:
585
+ wav_1 = to_shape(wav_1, wav_2.shape)
586
+
587
+ wave = (wav_1 * value) + (wav_2 * (1-value))
588
+
589
+ return wave
590
+
591
+ def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray):
592
+
593
+ if wav_1.shape > wav_2.shape:
594
+ wav_2 = to_shape(wav_2, wav_1.shape)
595
+ if wav_1.shape < wav_2.shape:
596
+ ln = min([wav_1.shape[1], wav_2.shape[1]])
597
+ wav_2 = wav_2[:,:ln]
598
+
599
+ ln = min([wav_1.shape[1], wav_2.shape[1]])
600
+ wav_1 = wav_1[:,:ln]
601
+ wav_2 = wav_2[:,:ln]
602
+
603
+ return wav_2
604
+
605
+ def align_audio(file1, file2, file2_aligned, file_subtracted, wav_type_set, is_normalization, command_Text, progress_bar_main_var, save_format):
606
+ def get_diff(a, b):
607
+ corr = np.correlate(a, b, "full")
608
+ diff = corr.argmax() - (b.shape[0] - 1)
609
+ return diff
610
+
611
+ progress_bar_main_var.set(10)
612
+
613
+ # read tracks
614
+ wav1, sr1 = librosa.load(file1, sr=44100, mono=False)
615
+ wav2, sr2 = librosa.load(file2, sr=44100, mono=False)
616
+ wav1 = wav1.transpose()
617
+ wav2 = wav2.transpose()
618
+
619
+ command_Text(f"Audio file shapes: {wav1.shape} / {wav2.shape}\n")
620
+
621
+ wav2_org = wav2.copy()
622
+ progress_bar_main_var.set(20)
623
+
624
+ command_Text("Processing files... \n")
625
+
626
+ # pick random position and get diff
627
+
628
+ counts = {} # counting up for each diff value
629
+ progress = 20
630
+
631
+ check_range = 64
632
+
633
+ base = (64 / check_range)
634
+
635
+ for i in range(check_range):
636
+ index = int(random.uniform(44100 * 2, min(wav1.shape[0], wav2.shape[0]) - 44100 * 2))
637
+ shift = int(random.uniform(-22050,+22050))
638
+ samp1 = wav1[index :index +44100, 0] # currently use left channel
639
+ samp2 = wav2[index+shift:index+shift+44100, 0]
640
+ progress += 1 * base
641
+ progress_bar_main_var.set(progress)
642
+ diff = get_diff(samp1, samp2)
643
+ diff -= shift
644
+
645
+ if abs(diff) < 22050:
646
+ if not diff in counts:
647
+ counts[diff] = 0
648
+ counts[diff] += 1
649
+
650
+ # use max counted diff value
651
+ max_count = 0
652
+ est_diff = 0
653
+ for diff in counts.keys():
654
+ if counts[diff] > max_count:
655
+ max_count = counts[diff]
656
+ est_diff = diff
657
+
658
+ command_Text(f"Estimated difference is {est_diff} (count: {max_count})\n")
659
+
660
+ progress_bar_main_var.set(90)
661
+
662
+ audio_files = []
663
+
664
+ def save_aligned_audio(wav2_aligned):
665
+ command_Text(f"Aligned File 2 with File 1.\n")
666
+ command_Text(f"Saving files... ")
667
+ sf.write(file2_aligned, normalize(wav2_aligned, is_normalization), sr2, subtype=wav_type_set)
668
+ save_format(file2_aligned)
669
+ min_len = min(wav1.shape[0], wav2_aligned.shape[0])
670
+ wav_sub = wav1[:min_len] - wav2_aligned[:min_len]
671
+ audio_files.append(file2_aligned)
672
+ return min_len, wav_sub
673
+
674
+ # make aligned track 2
675
+ if est_diff > 0:
676
+ wav2_aligned = np.append(np.zeros((est_diff, 2)), wav2_org, axis=0)
677
+ min_len, wav_sub = save_aligned_audio(wav2_aligned)
678
+ elif est_diff < 0:
679
+ wav2_aligned = wav2_org[-est_diff:]
680
+ min_len, wav_sub = save_aligned_audio(wav2_aligned)
681
+ else:
682
+ command_Text(f"Audio files already aligned.\n")
683
+ command_Text(f"Saving inverted track... ")
684
+ min_len = min(wav1.shape[0], wav2.shape[0])
685
+ wav_sub = wav1[:min_len] - wav2[:min_len]
686
+
687
+ wav_sub = np.clip(wav_sub, -1, +1)
688
+
689
+ sf.write(file_subtracted, normalize(wav_sub, is_normalization), sr1, subtype=wav_type_set)
690
+ save_format(file_subtracted)
691
+
692
+ progress_bar_main_var.set(95)
lib_v5/vr_network/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # VR init.
lib_v5/vr_network/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes). View file
 
lib_v5/vr_network/__pycache__/layers.cpython-310.pyc ADDED
Binary file (4.47 kB). View file
 
lib_v5/vr_network/__pycache__/layers_new.cpython-310.pyc ADDED
Binary file (4.44 kB). View file
 
lib_v5/vr_network/__pycache__/model_param_init.cpython-310.pyc ADDED
Binary file (1.62 kB). View file
 
lib_v5/vr_network/__pycache__/nets.cpython-310.pyc ADDED
Binary file (4.39 kB). View file
 
lib_v5/vr_network/__pycache__/nets_new.cpython-310.pyc ADDED
Binary file (4 kB). View file
 
lib_v5/vr_network/layers.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from lib_v5 import spec_utils
6
+
7
+ class Conv2DBNActiv(nn.Module):
8
+
9
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10
+ super(Conv2DBNActiv, self).__init__()
11
+ self.conv = nn.Sequential(
12
+ nn.Conv2d(
13
+ nin, nout,
14
+ kernel_size=ksize,
15
+ stride=stride,
16
+ padding=pad,
17
+ dilation=dilation,
18
+ bias=False),
19
+ nn.BatchNorm2d(nout),
20
+ activ()
21
+ )
22
+
23
+ def __call__(self, x):
24
+ return self.conv(x)
25
+
26
+ class SeperableConv2DBNActiv(nn.Module):
27
+
28
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
29
+ super(SeperableConv2DBNActiv, self).__init__()
30
+ self.conv = nn.Sequential(
31
+ nn.Conv2d(
32
+ nin, nin,
33
+ kernel_size=ksize,
34
+ stride=stride,
35
+ padding=pad,
36
+ dilation=dilation,
37
+ groups=nin,
38
+ bias=False),
39
+ nn.Conv2d(
40
+ nin, nout,
41
+ kernel_size=1,
42
+ bias=False),
43
+ nn.BatchNorm2d(nout),
44
+ activ()
45
+ )
46
+
47
+ def __call__(self, x):
48
+ return self.conv(x)
49
+
50
+
51
+ class Encoder(nn.Module):
52
+
53
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
54
+ super(Encoder, self).__init__()
55
+ self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
56
+ self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
57
+
58
+ def __call__(self, x):
59
+ skip = self.conv1(x)
60
+ h = self.conv2(skip)
61
+
62
+ return h, skip
63
+
64
+
65
+ class Decoder(nn.Module):
66
+
67
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
68
+ super(Decoder, self).__init__()
69
+ self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
70
+ self.dropout = nn.Dropout2d(0.1) if dropout else None
71
+
72
+ def __call__(self, x, skip=None):
73
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
74
+ if skip is not None:
75
+ skip = spec_utils.crop_center(skip, x)
76
+ x = torch.cat([x, skip], dim=1)
77
+ h = self.conv(x)
78
+
79
+ if self.dropout is not None:
80
+ h = self.dropout(h)
81
+
82
+ return h
83
+
84
+
85
+ class ASPPModule(nn.Module):
86
+
87
+ def __init__(self, nn_architecture, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU):
88
+ super(ASPPModule, self).__init__()
89
+ self.conv1 = nn.Sequential(
90
+ nn.AdaptiveAvgPool2d((1, None)),
91
+ Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
92
+ )
93
+
94
+ self.nn_architecture = nn_architecture
95
+ self.six_layer = [129605]
96
+ self.seven_layer = [537238, 537227, 33966]
97
+
98
+ extra_conv = SeperableConv2DBNActiv(
99
+ nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
100
+
101
+ self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
102
+ self.conv3 = SeperableConv2DBNActiv(
103
+ nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
104
+ self.conv4 = SeperableConv2DBNActiv(
105
+ nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
106
+ self.conv5 = SeperableConv2DBNActiv(
107
+ nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
108
+
109
+ if self.nn_architecture in self.six_layer:
110
+ self.conv6 = extra_conv
111
+ nin_x = 6
112
+ elif self.nn_architecture in self.seven_layer:
113
+ self.conv6 = extra_conv
114
+ self.conv7 = extra_conv
115
+ nin_x = 7
116
+ else:
117
+ nin_x = 5
118
+
119
+ self.bottleneck = nn.Sequential(
120
+ Conv2DBNActiv(nin * nin_x, nout, 1, 1, 0, activ=activ),
121
+ nn.Dropout2d(0.1)
122
+ )
123
+
124
+ def forward(self, x):
125
+ _, _, h, w = x.size()
126
+ feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
127
+ feat2 = self.conv2(x)
128
+ feat3 = self.conv3(x)
129
+ feat4 = self.conv4(x)
130
+ feat5 = self.conv5(x)
131
+
132
+ if self.nn_architecture in self.six_layer:
133
+ feat6 = self.conv6(x)
134
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6), dim=1)
135
+ elif self.nn_architecture in self.seven_layer:
136
+ feat6 = self.conv6(x)
137
+ feat7 = self.conv7(x)
138
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1)
139
+ else:
140
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
141
+
142
+ bottle = self.bottleneck(out)
143
+ return bottle
lib_v5/vr_network/layers_new.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from lib_v5 import spec_utils
6
+
7
+ class Conv2DBNActiv(nn.Module):
8
+
9
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10
+ super(Conv2DBNActiv, self).__init__()
11
+ self.conv = nn.Sequential(
12
+ nn.Conv2d(
13
+ nin, nout,
14
+ kernel_size=ksize,
15
+ stride=stride,
16
+ padding=pad,
17
+ dilation=dilation,
18
+ bias=False),
19
+ nn.BatchNorm2d(nout),
20
+ activ()
21
+ )
22
+
23
+ def __call__(self, x):
24
+ return self.conv(x)
25
+
26
+ class Encoder(nn.Module):
27
+
28
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
29
+ super(Encoder, self).__init__()
30
+ self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
31
+ self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
32
+
33
+ def __call__(self, x):
34
+ h = self.conv1(x)
35
+ h = self.conv2(h)
36
+
37
+ return h
38
+
39
+
40
+ class Decoder(nn.Module):
41
+
42
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
43
+ super(Decoder, self).__init__()
44
+ self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
45
+ # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
46
+ self.dropout = nn.Dropout2d(0.1) if dropout else None
47
+
48
+ def __call__(self, x, skip=None):
49
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
50
+
51
+ if skip is not None:
52
+ skip = spec_utils.crop_center(skip, x)
53
+ x = torch.cat([x, skip], dim=1)
54
+
55
+ h = self.conv1(x)
56
+ # h = self.conv2(h)
57
+
58
+ if self.dropout is not None:
59
+ h = self.dropout(h)
60
+
61
+ return h
62
+
63
+
64
+ class ASPPModule(nn.Module):
65
+
66
+ def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
67
+ super(ASPPModule, self).__init__()
68
+ self.conv1 = nn.Sequential(
69
+ nn.AdaptiveAvgPool2d((1, None)),
70
+ Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
71
+ )
72
+ self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
73
+ self.conv3 = Conv2DBNActiv(
74
+ nin, nout, 3, 1, dilations[0], dilations[0], activ=activ
75
+ )
76
+ self.conv4 = Conv2DBNActiv(
77
+ nin, nout, 3, 1, dilations[1], dilations[1], activ=activ
78
+ )
79
+ self.conv5 = Conv2DBNActiv(
80
+ nin, nout, 3, 1, dilations[2], dilations[2], activ=activ
81
+ )
82
+ self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
83
+ self.dropout = nn.Dropout2d(0.1) if dropout else None
84
+
85
+ def forward(self, x):
86
+ _, _, h, w = x.size()
87
+ feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
88
+ feat2 = self.conv2(x)
89
+ feat3 = self.conv3(x)
90
+ feat4 = self.conv4(x)
91
+ feat5 = self.conv5(x)
92
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
93
+ out = self.bottleneck(out)
94
+
95
+ if self.dropout is not None:
96
+ out = self.dropout(out)
97
+
98
+ return out
99
+
100
+
101
+ class LSTMModule(nn.Module):
102
+
103
+ def __init__(self, nin_conv, nin_lstm, nout_lstm):
104
+ super(LSTMModule, self).__init__()
105
+ self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
106
+ self.lstm = nn.LSTM(
107
+ input_size=nin_lstm,
108
+ hidden_size=nout_lstm // 2,
109
+ bidirectional=True
110
+ )
111
+ self.dense = nn.Sequential(
112
+ nn.Linear(nout_lstm, nin_lstm),
113
+ nn.BatchNorm1d(nin_lstm),
114
+ nn.ReLU()
115
+ )
116
+
117
+ def forward(self, x):
118
+ N, _, nbins, nframes = x.size()
119
+ h = self.conv(x)[:, 0] # N, nbins, nframes
120
+ h = h.permute(2, 0, 1) # nframes, N, nbins
121
+ h, _ = self.lstm(h)
122
+ h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins
123
+ h = h.reshape(nframes, N, 1, nbins)
124
+ h = h.permute(1, 2, 3, 0)
125
+
126
+ return h
lib_v5/vr_network/model_param_init.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pathlib
3
+
4
+ default_param = {}
5
+ default_param['bins'] = 768
6
+ default_param['unstable_bins'] = 9 # training only
7
+ default_param['reduction_bins'] = 762 # training only
8
+ default_param['sr'] = 44100
9
+ default_param['pre_filter_start'] = 757
10
+ default_param['pre_filter_stop'] = 768
11
+ default_param['band'] = {}
12
+
13
+
14
+ default_param['band'][1] = {
15
+ 'sr': 11025,
16
+ 'hl': 128,
17
+ 'n_fft': 960,
18
+ 'crop_start': 0,
19
+ 'crop_stop': 245,
20
+ 'lpf_start': 61, # inference only
21
+ 'res_type': 'polyphase'
22
+ }
23
+
24
+ default_param['band'][2] = {
25
+ 'sr': 44100,
26
+ 'hl': 512,
27
+ 'n_fft': 1536,
28
+ 'crop_start': 24,
29
+ 'crop_stop': 547,
30
+ 'hpf_start': 81, # inference only
31
+ 'res_type': 'sinc_best'
32
+ }
33
+
34
+
35
+ def int_keys(d):
36
+ r = {}
37
+ for k, v in d:
38
+ if k.isdigit():
39
+ k = int(k)
40
+ r[k] = v
41
+ return r
42
+
43
+
44
+ class ModelParameters(object):
45
+ def __init__(self, config_path=''):
46
+ if '.pth' == pathlib.Path(config_path).suffix:
47
+ import zipfile
48
+
49
+ with zipfile.ZipFile(config_path, 'r') as zip:
50
+ self.param = json.loads(zip.read('param.json'), object_pairs_hook=int_keys)
51
+ elif '.json' == pathlib.Path(config_path).suffix:
52
+ with open(config_path, 'r') as f:
53
+ self.param = json.loads(f.read(), object_pairs_hook=int_keys)
54
+ else:
55
+ self.param = default_param
56
+
57
+ for k in ['mid_side', 'mid_side_b', 'mid_side_b2', 'stereo_w', 'stereo_n', 'reverse']:
58
+ if not k in self.param:
59
+ self.param[k] = False
lib_v5/vr_network/modelparams/1band_sr16000_hl512.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 1024,
3
+ "unstable_bins": 0,
4
+ "reduction_bins": 0,
5
+ "band": {
6
+ "1": {
7
+ "sr": 16000,
8
+ "hl": 512,
9
+ "n_fft": 2048,
10
+ "crop_start": 0,
11
+ "crop_stop": 1024,
12
+ "hpf_start": -1,
13
+ "res_type": "sinc_best"
14
+ }
15
+ },
16
+ "sr": 16000,
17
+ "pre_filter_start": 1023,
18
+ "pre_filter_stop": 1024
19
+ }
lib_v5/vr_network/modelparams/1band_sr32000_hl512.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 1024,
3
+ "unstable_bins": 0,
4
+ "reduction_bins": 0,
5
+ "band": {
6
+ "1": {
7
+ "sr": 32000,
8
+ "hl": 512,
9
+ "n_fft": 2048,
10
+ "crop_start": 0,
11
+ "crop_stop": 1024,
12
+ "hpf_start": -1,
13
+ "res_type": "kaiser_fast"
14
+ }
15
+ },
16
+ "sr": 32000,
17
+ "pre_filter_start": 1000,
18
+ "pre_filter_stop": 1021
19
+ }
lib_v5/vr_network/modelparams/1band_sr33075_hl384.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 1024,
3
+ "unstable_bins": 0,
4
+ "reduction_bins": 0,
5
+ "band": {
6
+ "1": {
7
+ "sr": 33075,
8
+ "hl": 384,
9
+ "n_fft": 2048,
10
+ "crop_start": 0,
11
+ "crop_stop": 1024,
12
+ "hpf_start": -1,
13
+ "res_type": "sinc_best"
14
+ }
15
+ },
16
+ "sr": 33075,
17
+ "pre_filter_start": 1000,
18
+ "pre_filter_stop": 1021
19
+ }
lib_v5/vr_network/modelparams/1band_sr44100_hl1024.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 1024,
3
+ "unstable_bins": 0,
4
+ "reduction_bins": 0,
5
+ "band": {
6
+ "1": {
7
+ "sr": 44100,
8
+ "hl": 1024,
9
+ "n_fft": 2048,
10
+ "crop_start": 0,
11
+ "crop_stop": 1024,
12
+ "hpf_start": -1,
13
+ "res_type": "sinc_best"
14
+ }
15
+ },
16
+ "sr": 44100,
17
+ "pre_filter_start": 1023,
18
+ "pre_filter_stop": 1024
19
+ }
lib_v5/vr_network/modelparams/1band_sr44100_hl256.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 256,
3
+ "unstable_bins": 0,
4
+ "reduction_bins": 0,
5
+ "band": {
6
+ "1": {
7
+ "sr": 44100,
8
+ "hl": 256,
9
+ "n_fft": 512,
10
+ "crop_start": 0,
11
+ "crop_stop": 256,
12
+ "hpf_start": -1,
13
+ "res_type": "sinc_best"
14
+ }
15
+ },
16
+ "sr": 44100,
17
+ "pre_filter_start": 256,
18
+ "pre_filter_stop": 256
19
+ }
lib_v5/vr_network/modelparams/1band_sr44100_hl512.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 1024,
3
+ "unstable_bins": 0,
4
+ "reduction_bins": 0,
5
+ "band": {
6
+ "1": {
7
+ "sr": 44100,
8
+ "hl": 512,
9
+ "n_fft": 2048,
10
+ "crop_start": 0,
11
+ "crop_stop": 1024,
12
+ "hpf_start": -1,
13
+ "res_type": "sinc_best"
14
+ }
15
+ },
16
+ "sr": 44100,
17
+ "pre_filter_start": 1023,
18
+ "pre_filter_stop": 1024
19
+ }
lib_v5/vr_network/modelparams/1band_sr44100_hl512_cut.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 1024,
3
+ "unstable_bins": 0,
4
+ "reduction_bins": 0,
5
+ "band": {
6
+ "1": {
7
+ "sr": 44100,
8
+ "hl": 512,
9
+ "n_fft": 2048,
10
+ "crop_start": 0,
11
+ "crop_stop": 700,
12
+ "hpf_start": -1,
13
+ "res_type": "sinc_best"
14
+ }
15
+ },
16
+ "sr": 44100,
17
+ "pre_filter_start": 1023,
18
+ "pre_filter_stop": 700
19
+ }
lib_v5/vr_network/modelparams/1band_sr44100_hl512_nf1024.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 512,
3
+ "unstable_bins": 0,
4
+ "reduction_bins": 0,
5
+ "band": {
6
+ "1": {
7
+ "sr": 44100,
8
+ "hl": 512,
9
+ "n_fft": 1024,
10
+ "crop_start": 0,
11
+ "crop_stop": 512,
12
+ "hpf_start": -1,
13
+ "res_type": "sinc_best"
14
+ }
15
+ },
16
+ "sr": 44100,
17
+ "pre_filter_start": 511,
18
+ "pre_filter_stop": 512
19
+ }
lib_v5/vr_network/modelparams/2band_32000.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 768,
3
+ "unstable_bins": 7,
4
+ "reduction_bins": 705,
5
+ "band": {
6
+ "1": {
7
+ "sr": 6000,
8
+ "hl": 66,
9
+ "n_fft": 512,
10
+ "crop_start": 0,
11
+ "crop_stop": 240,
12
+ "lpf_start": 60,
13
+ "lpf_stop": 118,
14
+ "res_type": "sinc_fastest"
15
+ },
16
+ "2": {
17
+ "sr": 32000,
18
+ "hl": 352,
19
+ "n_fft": 1024,
20
+ "crop_start": 22,
21
+ "crop_stop": 505,
22
+ "hpf_start": 44,
23
+ "hpf_stop": 23,
24
+ "res_type": "sinc_medium"
25
+ }
26
+ },
27
+ "sr": 32000,
28
+ "pre_filter_start": 710,
29
+ "pre_filter_stop": 731
30
+ }
lib_v5/vr_network/modelparams/2band_44100_lofi.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 512,
3
+ "unstable_bins": 7,
4
+ "reduction_bins": 510,
5
+ "band": {
6
+ "1": {
7
+ "sr": 11025,
8
+ "hl": 160,
9
+ "n_fft": 768,
10
+ "crop_start": 0,
11
+ "crop_stop": 192,
12
+ "lpf_start": 41,
13
+ "lpf_stop": 139,
14
+ "res_type": "sinc_fastest"
15
+ },
16
+ "2": {
17
+ "sr": 44100,
18
+ "hl": 640,
19
+ "n_fft": 1024,
20
+ "crop_start": 10,
21
+ "crop_stop": 320,
22
+ "hpf_start": 47,
23
+ "hpf_stop": 15,
24
+ "res_type": "sinc_medium"
25
+ }
26
+ },
27
+ "sr": 44100,
28
+ "pre_filter_start": 510,
29
+ "pre_filter_stop": 512
30
+ }
lib_v5/vr_network/modelparams/2band_48000.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 768,
3
+ "unstable_bins": 7,
4
+ "reduction_bins": 705,
5
+ "band": {
6
+ "1": {
7
+ "sr": 6000,
8
+ "hl": 66,
9
+ "n_fft": 512,
10
+ "crop_start": 0,
11
+ "crop_stop": 240,
12
+ "lpf_start": 60,
13
+ "lpf_stop": 240,
14
+ "res_type": "sinc_fastest"
15
+ },
16
+ "2": {
17
+ "sr": 48000,
18
+ "hl": 528,
19
+ "n_fft": 1536,
20
+ "crop_start": 22,
21
+ "crop_stop": 505,
22
+ "hpf_start": 82,
23
+ "hpf_stop": 22,
24
+ "res_type": "sinc_medium"
25
+ }
26
+ },
27
+ "sr": 48000,
28
+ "pre_filter_start": 710,
29
+ "pre_filter_stop": 731
30
+ }
lib_v5/vr_network/modelparams/3band_44100.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bins": 768,
3
+ "unstable_bins": 5,
4
+ "reduction_bins": 733,
5
+ "band": {
6
+ "1": {
7
+ "sr": 11025,
8
+ "hl": 128,
9
+ "n_fft": 768,
10
+ "crop_start": 0,
11
+ "crop_stop": 278,
12
+ "lpf_start": 28,
13
+ "lpf_stop": 140,
14
+ "res_type": "polyphase"
15
+ },
16
+ "2": {
17
+ "sr": 22050,
18
+ "hl": 256,
19
+ "n_fft": 768,
20
+ "crop_start": 14,
21
+ "crop_stop": 322,
22
+ "hpf_start": 70,
23
+ "hpf_stop": 14,
24
+ "lpf_start": 283,
25
+ "lpf_stop": 314,
26
+ "res_type": "polyphase"
27
+ },
28
+ "3": {
29
+ "sr": 44100,
30
+ "hl": 512,
31
+ "n_fft": 768,
32
+ "crop_start": 131,
33
+ "crop_stop": 313,
34
+ "hpf_start": 154,
35
+ "hpf_stop": 141,
36
+ "res_type": "sinc_medium"
37
+ }
38
+ },
39
+ "sr": 44100,
40
+ "pre_filter_start": 757,
41
+ "pre_filter_stop": 768
42
+ }