Politrees committed on
Commit e976963
1 Parent(s): ecd9a9f

Upload 29 files

.gitignore ADDED
@@ -0,0 +1,162 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ gradio==4.29.0
+ tensorboardX
+ einops
+ local-attention
+ pedalboard==0.7.7
+
+ fairseq==0.12.2
+ faiss-cpu==1.7.3
+ ffmpeg-python>=0.2.0
+ praat-parselmouth>=0.4.2
+ pyworld==0.3.4
+ torchcrepe==0.0.20
rvc_models/MODELS.txt ADDED
@@ -0,0 +1,2 @@
+ RVC Models can be added as a folder here. Each folder should contain the model file (.pth extension), and an index file (.index extension).
+ For example, a folder called Maya, containing 2 files, Maya.pth and added_IVF1905_Flat_nprobe_Maya_v2.index.
song_output/OUTPUT.txt ADDED
@@ -0,0 +1 @@
+ Output is stored in this folder, where directory names represent the YouTube IDs from the original song.
src/CoverGenLite.py ADDED
@@ -0,0 +1,195 @@
1
+ import os
2
+ import shutil
3
+ import urllib.request
4
+ import zipfile
5
+ import gdown
6
+ import gradio as gr
7
+
8
+ from main import song_cover_pipeline
9
+ from audio_effects import add_audio_effects
10
+ from modules.model_management import ignore_files, update_models_list, extract_zip, download_from_url, upload_zip_model
11
+ from modules.ui_updates import show_hop_slider, update_f0_method, update_button_text, update_button_text_voc, update_button_text_inst
12
+ from modules.file_processing import process_file_upload
13
+
14
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
15
+ rvc_models_dir = os.path.join(BASE_DIR, 'rvc_models')
16
+ output_dir = os.path.join(BASE_DIR, 'song_output')
17
+
18
+
19
+ if __name__ == '__main__':
20
+ voice_models = ignore_files(rvc_models_dir)
21
+
22
+ with gr.Blocks(title='CoverGen Lite - Politrees (v0.2)', theme=gr.themes.Soft(primary_hue="green", secondary_hue="green", neutral_hue="neutral", spacing_size="sm", radius_size="lg")) as app:
23
+ with gr.Tab("Велком/Контакты"):
24
+ gr.HTML("<center><h1>Добро пожаловать в CoverGen Lite - Politrees (v0.2)</h1></center>")
25
+ with gr.Row():
26
+ with gr.Column(variant='panel'):
27
+ gr.HTML("<center><h2><a href='https://www.youtube.com/channel/UCHb3fZEVxUisnqLqCrEM8ZA'>YouTube: Politrees</a></h2></center>")
28
+ gr.HTML("<center><h2><a href='https://vk.com/artem__bebroy'>ВКонтакте (страница)</a></h2></center>")
29
+ with gr.Column(variant='panel'):
30
+ gr.HTML("<center><h2><a href='https://t.me/pol1trees'>Telegram Канал</a></h2></center>")
31
+ gr.HTML("<center><h2><a href='https://t.me/+GMTP7hZqY0E4OGRi'>Telegram Чат</a></h2></center>")
32
+ with gr.Column(variant='panel'):
33
+ gr.HTML("<center><h2><a href='https://github.com/Bebra777228/Pol-Litres-RVC'>GitHub проекта</a></h2></center>")
34
+
35
+ with gr.Tab("Преобразование голоса"):
36
+ with gr.Row(equal_height=False):
37
+ with gr.Column(scale=1, variant='panel'):
38
+ with gr.Group():
39
+ rvc_model = gr.Dropdown(voice_models, label='Модели голоса')
40
+ ref_btn = gr.Button('Обновить список моделей', variant='primary')
41
+ with gr.Group():
42
+ pitch = gr.Slider(-24, 24, value=0, step=0.5, label='Изменение тона голоса', info='-24 - мужской голос || 24 - женский голос')
43
+
44
+ with gr.Column(scale=2, variant='panel'):
45
+ with gr.Group():
46
+ local_file = gr.Audio(label='Аудио-файл', interactive=False, show_download_button=False)
47
+ uploaded_file = gr.UploadButton(label='Загрузить аудио-файл', file_types=['audio'], variant='primary')
48
+ uploaded_file.upload(process_file_upload, inputs=[uploaded_file], outputs=[local_file])
49
+ uploaded_file.upload(update_button_text, outputs=[uploaded_file])
50
+
51
+ with gr.Group():
52
+ with gr.Row(variant='panel'):
53
+ generate_btn = gr.Button("Генерировать", variant='primary', scale=1)
54
+ converted_voice = gr.Audio(label='Преобразованный голос', scale=5)
55
+ output_format = gr.Dropdown(['mp3', 'flac', 'wav'], value='mp3', label='Формат файла', scale=0.1, allow_custom_value=False, filterable=False)
56
+
57
+ with gr.Accordion('Настройки преобразования голоса', open=False):
58
+ with gr.Group():
59
+ with gr.Column(variant='panel'):
60
+ use_hybrid_methods = gr.Checkbox(label="Использовать гибридные методы", value=False)
61
+ f0_method = gr.Dropdown(['rmvpe+', 'fcpe', 'rmvpe', 'mangio-crepe', 'crepe'], value='rmvpe+', label='Метод выделения тона', allow_custom_value=False, filterable=False)
62
+ use_hybrid_methods.change(update_f0_method, inputs=use_hybrid_methods, outputs=f0_method)
63
+ crepe_hop_length = gr.Slider(8, 512, value=128, step=8, visible=False, label='Длина шага Crepe')
64
+ f0_method.change(show_hop_slider, inputs=f0_method, outputs=crepe_hop_length)
65
+ with gr.Column(variant='panel'):
66
+ index_rate = gr.Slider(0, 1, value=0, label='Влияние индекса', info='Контролирует степень влияния индексного файла на результат анализа. Более высокое значение увеличивает влияние индексного файла, но может усилить артефакты в аудио. Выбор более низкого значения может помочь снизить артефакты.')
67
+ filter_radius = gr.Slider(0, 7, value=3, step=1, label='Радиус фильтра', info='Управляет радиусом фильтрации результатов анализа тона. Если значение фильтрации равняется или превышает три, применяется медианная фильтрация для уменьшения шума дыхания в аудиозаписи.')
68
+ rms_mix_rate = gr.Slider(0, 1, value=0.25, step=0.01, label='Скорость смешивания RMS', info='Контролирует степень смешивания выходного сигнала с его оболочкой громкости. Значение близкое к 1 увеличивает использование оболочки громкости выходного сигнала, что может улучшить качество звука.')
69
+ protect = gr.Slider(0, 0.5, value=0.33, step=0.01, label='Защита согласных', info='Контролирует степень защиты отдельных согласных и звуков дыхания от электроакустических разрывов и других артефактов. Максимальное значение 0,5 обеспечивает наибольшую защиту, но может увеличить эффект индексирования, который может негативно влиять на качество звука. Уменьшение значения может уменьшить степень защиты, но снизить эффект индексирования.')
70
+
71
+ ref_btn.click(update_models_list, None, outputs=rvc_model)
72
+ generate_btn.click(song_cover_pipeline,
73
+ inputs=[uploaded_file, rvc_model, pitch, index_rate, filter_radius, rms_mix_rate, f0_method, crepe_hop_length, protect, output_format],
74
+ outputs=[converted_voice])
75
+
76
+ with gr.Tab('Объединение/Обработка'):
77
+ with gr.Row(equal_height=False):
78
+ with gr.Column(variant='panel'):
79
+ with gr.Group():
80
+ vocal_audio = gr.Audio(label='Вокал', interactive=False, show_download_button=False)
81
+ upload_vocal_audio = gr.UploadButton(label='Загрузить вокал', file_types=['audio'], variant='primary')
82
+ upload_vocal_audio.upload(process_file_upload, inputs=[upload_vocal_audio], outputs=[vocal_audio])
83
+ upload_vocal_audio.upload(update_button_text_voc, outputs=[upload_vocal_audio])
84
+
85
+ with gr.Column(variant='panel'):
86
+ with gr.Group():
87
+ instrumental_audio = gr.Audio(label='Инструментал', interactive=False, show_download_button=False)
88
+ upload_instrumental_audio = gr.UploadButton(label='Загрузить инструментал', file_types=['audio'], variant='primary')
89
+ upload_instrumental_audio.upload(process_file_upload, inputs=[upload_instrumental_audio], outputs=[instrumental_audio])
90
+ upload_instrumental_audio.upload(update_button_text_inst, outputs=[upload_instrumental_audio])
91
+
92
+ with gr.Group():
93
+ with gr.Row(variant='panel'):
94
+ process_btn = gr.Button("Обработать", variant='primary', scale=1)
95
+ ai_cover = gr.Audio(label='Ai-Cover', scale=5)
96
+ output_format = gr.Dropdown(['mp3', 'flac', 'wav'], value='mp3', label='Формат файла', scale=0.1, allow_custom_value=False, filterable=False)
97
+
98
+ with gr.Accordion('Настройки сведения аудио', open=False):
99
+ gr.HTML('<center><h2>Изменение громкости</h2></center>')
100
+ with gr.Row(variant='panel'):
101
+ vocal_gain = gr.Slider(-10, 10, value=0, step=1, label='Вокал', scale=1)
102
+ instrumental_gain = gr.Slider(-10, 10, value=0, step=1, label='Инструментал', scale=1)
103
+ clear_btn = gr.Button("Сбросить все эффекты", scale=0.1)
104
+
105
+ with gr.Accordion('Эффекты', open=False):
106
+ with gr.Accordion('Реверберация', open=False):
107
+ with gr.Group():
108
+ with gr.Column(variant='panel'):
109
+ with gr.Row():
110
+ reverb_rm_size = gr.Slider(0, 1, value=0.15, label='Размер комнаты', info='Этот параметр отвечает за размер виртуального помещения, в котором будет звучать реверберация. Большее значение означает больший размер комнаты и более длительное звучание реверберации.')
111
+ reverb_width = gr.Slider(0, 1, value=1.0, label='Ширина реверберации', info='Этот параметр отвечает за ширину звучания реверберации. Чем выше значение, тем шире будет звучание реверберации.')
112
+ with gr.Row():
113
+ reverb_wet = gr.Slider(0, 1, value=0.1, label='Уровень влажности', info='Этот параметр отвечает за уровень реверберации. Чем выше значение, тем сильнее будет слышен эффект реверберации и тем дольше будет звучать «хвост».')
114
+ reverb_dry = gr.Slider(0, 1, value=0.8, label='Уровень сухости', info='Этот параметр отвечает за уровень исходного звука без реверберации. Чем меньше значение, тем тише звук ai вокала. Если значение будет на 0, то исходный звук полностью исчезнет.')
115
+ with gr.Row():
116
+ reverb_damping = gr.Slider(0, 1, value=0.7, label='Уровень демпфирования', info='Этот параметр отвечает за поглощение высоких частот в реверберации. Чем выше его значение, тем сильнее будет поглощение частот и тем менее будет «яркий» звук реверберации.')
117
+
118
+ with gr.Accordion('Хорус', open=False):
119
+ with gr.Group():
120
+ with gr.Column(variant='panel'):
121
+ with gr.Row():
122
+ chorus_rate_hz = gr.Slider(0.1, 10, value=0, label='Скорость хоруса', info='Этот параметр отвечает за скорость колебаний эффекта хоруса в герцах. Чем выше значение, тем быстрее будут колебаться звуки.')
123
+ chorus_depth = gr.Slider(0, 1, value=0, label='Глубина хоруса', info='Этот параметр отвечает за глубину эффекта хоруса. Чем выше значение, тем сильнее будет эффект хоруса.')
124
+ with gr.Row():
125
+ chorus_centre_delay_ms = gr.Slider(0, 50, value=0, label='Задержка центра (мс)', info='Этот параметр отвечает за задержку центрального сигнала эффекта хоруса в миллисекундах. Чем выше значение, тем дольше будет задержка.')
126
+ chorus_feedback = gr.Slider(0, 1, value=0, label='Обратная связь', info='Этот параметр отвечает за уровень обратной связи эффекта хоруса. Чем выше значение, тем сильнее будет эффект обратной связи.')
127
+ with gr.Row():
128
+ chorus_mix = gr.Slider(0, 1, value=0, label='Смешение', info='Этот параметр отвечает за уровень смешивания оригинального сигнала и эффекта хоруса. Чем выше значение, тем сильнее будет эффект хоруса.')
129
+
130
+ with gr.Accordion('Обработка', open=False):
131
+ with gr.Accordion('Компрессор', open=False):
132
+ with gr.Row(variant='panel'):
133
+ compressor_ratio = gr.Slider(1, 20, value=4, label='Соотношение', info='Этот параметр контролирует количество применяемого сжатия аудио. Большее значение означает большее сжатие, которое уменьшает динамический диапазон аудио, делая громкие части более тихими и тихие части более громкими.')
134
+ compressor_threshold = gr.Slider(-60, 0, value=-16, label='Порог', info='Этот параметр устанавливает порог, при превышении которого начинает действовать компрессор. Компрессор сжимает громкие звуки, чтобы сделать звук более ровным. Чем ниже порог, тем большее количество звуков будет подвергнуто компрессии.')
135
+
136
+ with gr.Accordion('Фильтры', open=False):
137
+ with gr.Row(variant='panel'):
138
+ low_shelf_gain = gr.Slider(-20, 20, value=0, label='Фильтр нижних частот', info='Этот параметр контролирует усиление (громкость) низких частот. Положительное значение усиливает низкие частоты, делая звук более басским. Отрицательное значение ослабляет низкие частоты, делая звук более тонким.')
139
+ high_shelf_gain = gr.Slider(-20, 20, value=0, label='Фильтр высоких частот', info='Этот параметр контролирует усиление высоких частот. Положительное значение усиливает высокие частоты, делая звук более ярким. Отрицательное значение ослабляет высокие частоты, делая звук более тусклым.')
140
+
141
+ with gr.Accordion('Подавление шума', open=False):
142
+ with gr.Group():
143
+ with gr.Column(variant='panel'):
144
+ with gr.Row():
145
+ noise_gate_threshold = gr.Slider(-60, 0, value=-30, label='Порог', info='Этот параметр устанавливает пороговое значение в децибелах, ниже которого сигнал считается шумом. Когда сигнал опускается ниже этого порога, шумовой шлюз активируется и уменьшает громкость сигнала.')
146
+ noise_gate_ratio = gr.Slider(1, 20, value=6, label='Соотношение', info='Этот параметр устанавливает уровень подавления шума. Большее значение означает более сильное подавление шума.')
147
+ with gr.Row():
148
+ noise_gate_attack = gr.Slider(0, 100, value=10, label='Время атаки (мс)', info='Этот параметр контролирует скорость, с которой шумовой шлюз открывается, когда звук становится достаточно громким. Большее значение означает, что шлюз открывается медленнее.')
149
+ noise_gate_release = gr.Slider(0, 1000, value=100, label='Время спада (мс)', info='Этот параметр контролирует скорость, с которой шумовой шлюз закрывается, когда звук становится достаточно тихим. Большее значение означает, что шлюз закрывается медленнее.')
150
+
151
+ process_btn.click(add_audio_effects,
152
+ inputs=[upload_vocal_audio, upload_instrumental_audio, reverb_rm_size, reverb_wet, reverb_dry, reverb_damping,
153
+ reverb_width, low_shelf_gain, high_shelf_gain, compressor_ratio, compressor_threshold,
154
+ noise_gate_threshold, noise_gate_ratio, noise_gate_attack, noise_gate_release,
155
+ chorus_rate_hz, chorus_depth, chorus_centre_delay_ms, chorus_feedback, chorus_mix,
156
+ output_format, vocal_gain, instrumental_gain],
157
+ outputs=[ai_cover])
158
+
159
+ default_values = [0, 0, 0.15, 1.0, 0.1, 0.8, 0.7, 0, 0, 0, 0, 0, 4, -16, 0, 0, -30, 6, 10, 100]
160
+ clear_btn.click(lambda: default_values,
161
+ outputs=[vocal_gain, instrumental_gain, reverb_rm_size, reverb_width, reverb_wet, reverb_dry, reverb_damping,
162
+ chorus_rate_hz, chorus_depth, chorus_centre_delay_ms, chorus_feedback, chorus_mix,
163
+ compressor_ratio, compressor_threshold, low_shelf_gain, high_shelf_gain, noise_gate_threshold,
164
+ noise_gate_ratio, noise_gate_attack, noise_gate_release])
165
+
166
+ with gr.Tab('Загрузка модели'):
167
+ with gr.Tab('Загрузить по ссылке'):
168
+ with gr.Row():
169
+ with gr.Column(variant='panel'):
170
+ gr.HTML("<center><h3>Вставьте в поле ниже ссылку от <a href='https://huggingface.co/' target='_blank'>HuggingFace</a>, <a href='https://pixeldrain.com/' target='_blank'>Pixeldrain</a> или <a href='https://drive.google.com/' target='_blank'>Google Drive</a></h3></center>")
171
+ model_zip_link = gr.Text(label='Ссылка на загрузку модели')
172
+ with gr.Column(variant='panel'):
173
+ with gr.Group():
174
+ model_name = gr.Text(label='Имя модели', info='Дайте вашей загружаемой модели уникальное имя, отличное от других голосовых моделей.')
175
+ download_btn = gr.Button('Загрузить модель', variant='primary')
176
+
177
+ dl_output_message = gr.Text(label='Сообщение вывода', interactive=False)
178
+ download_btn.click(download_from_url, inputs=[model_zip_link, model_name], outputs=dl_output_message)
179
+
180
+ with gr.Tab('Загрузить локально'):
181
+ with gr.Row():
182
+ with gr.Column(variant='panel'):
183
+ zip_file = gr.File(label='Zip-файл', file_types=['.zip'], file_count='single')
184
+ with gr.Column(variant='panel'):
185
+ gr.HTML("<h3>1. Найдите и скачайте файлы: .pth и необязательный файл .index</h3>")
186
+ gr.HTML("<h3>2. Закиньте файл(-ы) в ZIP-архив и поместите его в область загрузки</h3>")
187
+ gr.HTML('<h3>3. Дождитесь полной загрузки ZIP-архива в интерфейс</h3>')
188
+ with gr.Group():
189
+ local_model_name = gr.Text(label='Имя модели', info='Дайте вашей загружаемой модели уникальное имя, отличное от других голосовых моделей.')
190
+ model_upload_button = gr.Button('Загрузить модель', variant='primary')
191
+
192
+ local_upload_output_message = gr.Text(label='Сообщение вывода', interactive=False)
193
+ model_upload_button.click(upload_zip_model, inputs=[zip_file, local_model_name], outputs=local_upload_output_message)
194
+
195
+ app.launch(share=True, quiet=True)
src/audio_effects.py ADDED
@@ -0,0 +1,74 @@
+ import os
+ import librosa
+ import numpy as np
+ import gradio as gr
+ import soundfile as sf
+ from pedalboard import (
+     Pedalboard, Reverb, Compressor, HighpassFilter,
+     LowShelfFilter, HighShelfFilter, NoiseGate, Chorus
+ )
+ from pedalboard.io import AudioFile
+ from pydub import AudioSegment
+
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+ def display_progress(percent, message, progress=gr.Progress()):
+     progress(percent, desc=message)
+
+ def combine_audio(vocal_path, instrumental_path, output_path, vocal_gain, instrumental_gain, output_format):
+     vocal_format = vocal_path.split('.')[-1]
+     instrumental_format = instrumental_path.split('.')[-1]
+
+     vocal = AudioSegment.from_file(vocal_path, format=vocal_format)
+     instrumental = AudioSegment.from_file(instrumental_path, format=instrumental_format)
+
+     vocal += vocal_gain
+     instrumental += instrumental_gain
+
+     combined = vocal.overlay(instrumental)
+     combined.export(output_path, format=output_format)
+
+ def add_audio_effects(vocal_audio_path, instrumental_audio_path, reverb_rm_size, reverb_wet, reverb_dry, reverb_damping, reverb_width,
+                       low_shelf_gain, high_shelf_gain, compressor_ratio, compressor_threshold, noise_gate_threshold, noise_gate_ratio,
+                       noise_gate_attack, noise_gate_release, chorus_rate_hz, chorus_depth, chorus_centre_delay_ms, chorus_feedback,
+                       chorus_mix, output_format, vocal_gain, instrumental_gain, progress=gr.Progress()):
+
+     if not vocal_audio_path or not instrumental_audio_path:
+         raise ValueError("Оба пути к аудиофайлам должны быть заполнены.")
+
+     display_progress(0.2, "Применение аудиоэффектов к вокалу...", progress)
+     board = Pedalboard(
+         [
+             HighpassFilter(),
+             Compressor(ratio=compressor_ratio, threshold_db=compressor_threshold),
+             NoiseGate(threshold_db=noise_gate_threshold, ratio=noise_gate_ratio, attack_ms=noise_gate_attack, release_ms=noise_gate_release),
+             Reverb(room_size=reverb_rm_size, dry_level=reverb_dry, wet_level=reverb_wet, damping=reverb_damping, width=reverb_width),
+             LowShelfFilter(gain_db=low_shelf_gain),
+             HighShelfFilter(gain_db=high_shelf_gain),
+             Chorus(rate_hz=chorus_rate_hz, depth=chorus_depth, centre_delay_ms=chorus_centre_delay_ms, feedback=chorus_feedback, mix=chorus_mix),
+         ]
+     )
+
+     vocal_output_path = f'Vocal_Effects.wav'
+     with AudioFile(vocal_audio_path) as f:
+         with AudioFile(vocal_output_path, 'w', f.samplerate, 2) as o:
+             while f.tell() < f.frames:
+                 chunk = f.read(int(f.samplerate))
+                 chunk = np.tile(chunk, (2, 1)).T
+                 effected = board(chunk, f.samplerate, reset=False)
+                 o.write(effected)
+
+     display_progress(0.5, "Объединение вокала и инструментальной части...", progress)
+     output_dir = os.path.join(BASE_DIR, 'processed_output')
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+     combined_output_path = os.path.join(output_dir, f'AiCover_combined.{output_format}')
+
+     if os.path.exists(combined_output_path):
+         os.remove(combined_output_path)
+
+     combine_audio(vocal_output_path, instrumental_audio_path, combined_output_path, vocal_gain, instrumental_gain, output_format)
+
+     display_progress(1.0, "Готово!", progress)
+
+     return combined_output_path
src/configs/32k.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "seed": 1234,
5
+ "epochs": 20000,
6
+ "learning_rate": 1e-4,
7
+ "betas": [0.8, 0.99],
8
+ "eps": 1e-9,
9
+ "batch_size": 4,
10
+ "fp16_run": false,
11
+ "lr_decay": 0.999875,
12
+ "segment_size": 12800,
13
+ "init_lr_ratio": 1,
14
+ "warmup_epochs": 0,
15
+ "c_mel": 45,
16
+ "c_kl": 1.0
17
+ },
18
+ "data": {
19
+ "max_wav_value": 32768.0,
20
+ "sampling_rate": 32000,
21
+ "filter_length": 1024,
22
+ "hop_length": 320,
23
+ "win_length": 1024,
24
+ "n_mel_channels": 80,
25
+ "mel_fmin": 0.0,
26
+ "mel_fmax": null
27
+ },
28
+ "model": {
29
+ "inter_channels": 192,
30
+ "hidden_channels": 192,
31
+ "filter_channels": 768,
32
+ "n_heads": 2,
33
+ "n_layers": 6,
34
+ "kernel_size": 3,
35
+ "p_dropout": 0,
36
+ "resblock": "1",
37
+ "resblock_kernel_sizes": [3,7,11],
38
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
39
+ "upsample_rates": [10,4,2,2,2],
40
+ "upsample_initial_channel": 512,
41
+ "upsample_kernel_sizes": [16,16,4,4,4],
42
+ "use_spectral_norm": false,
43
+ "gin_channels": 256,
44
+ "spk_embed_dim": 109
45
+ }
46
+ }
src/configs/32k_v2.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "seed": 1234,
5
+ "epochs": 20000,
6
+ "learning_rate": 1e-4,
7
+ "betas": [0.8, 0.99],
8
+ "eps": 1e-9,
9
+ "batch_size": 4,
10
+ "fp16_run": true,
11
+ "lr_decay": 0.999875,
12
+ "segment_size": 12800,
13
+ "init_lr_ratio": 1,
14
+ "warmup_epochs": 0,
15
+ "c_mel": 45,
16
+ "c_kl": 1.0
17
+ },
18
+ "data": {
19
+ "max_wav_value": 32768.0,
20
+ "sampling_rate": 32000,
21
+ "filter_length": 1024,
22
+ "hop_length": 320,
23
+ "win_length": 1024,
24
+ "n_mel_channels": 80,
25
+ "mel_fmin": 0.0,
26
+ "mel_fmax": null
27
+ },
28
+ "model": {
29
+ "inter_channels": 192,
30
+ "hidden_channels": 192,
31
+ "filter_channels": 768,
32
+ "n_heads": 2,
33
+ "n_layers": 6,
34
+ "kernel_size": 3,
35
+ "p_dropout": 0,
36
+ "resblock": "1",
37
+ "resblock_kernel_sizes": [3,7,11],
38
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
39
+ "upsample_rates": [10,8,2,2],
40
+ "upsample_initial_channel": 512,
41
+ "upsample_kernel_sizes": [20,16,4,4],
42
+ "use_spectral_norm": false,
43
+ "gin_channels": 256,
44
+ "spk_embed_dim": 109
45
+ }
46
+ }
src/configs/40k.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "seed": 1234,
5
+ "epochs": 20000,
6
+ "learning_rate": 1e-4,
7
+ "betas": [0.8, 0.99],
8
+ "eps": 1e-9,
9
+ "batch_size": 4,
10
+ "fp16_run": false,
11
+ "lr_decay": 0.999875,
12
+ "segment_size": 12800,
13
+ "init_lr_ratio": 1,
14
+ "warmup_epochs": 0,
15
+ "c_mel": 45,
16
+ "c_kl": 1.0
17
+ },
18
+ "data": {
19
+ "max_wav_value": 32768.0,
20
+ "sampling_rate": 40000,
21
+ "filter_length": 2048,
22
+ "hop_length": 400,
23
+ "win_length": 2048,
24
+ "n_mel_channels": 125,
25
+ "mel_fmin": 0.0,
26
+ "mel_fmax": null
27
+ },
28
+ "model": {
29
+ "inter_channels": 192,
30
+ "hidden_channels": 192,
31
+ "filter_channels": 768,
32
+ "n_heads": 2,
33
+ "n_layers": 6,
34
+ "kernel_size": 3,
35
+ "p_dropout": 0,
36
+ "resblock": "1",
37
+ "resblock_kernel_sizes": [3,7,11],
38
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
39
+ "upsample_rates": [10,10,2,2],
40
+ "upsample_initial_channel": 512,
41
+ "upsample_kernel_sizes": [16,16,4,4],
42
+ "use_spectral_norm": false,
43
+ "gin_channels": 256,
44
+ "spk_embed_dim": 109
45
+ }
46
+ }
src/configs/48k.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "seed": 1234,
5
+ "epochs": 20000,
6
+ "learning_rate": 1e-4,
7
+ "betas": [0.8, 0.99],
8
+ "eps": 1e-9,
9
+ "batch_size": 4,
10
+ "fp16_run": false,
11
+ "lr_decay": 0.999875,
12
+ "segment_size": 11520,
13
+ "init_lr_ratio": 1,
14
+ "warmup_epochs": 0,
15
+ "c_mel": 45,
16
+ "c_kl": 1.0
17
+ },
18
+ "data": {
19
+ "max_wav_value": 32768.0,
20
+ "sampling_rate": 48000,
21
+ "filter_length": 2048,
22
+ "hop_length": 480,
23
+ "win_length": 2048,
24
+ "n_mel_channels": 128,
25
+ "mel_fmin": 0.0,
26
+ "mel_fmax": null
27
+ },
28
+ "model": {
29
+ "inter_channels": 192,
30
+ "hidden_channels": 192,
31
+ "filter_channels": 768,
32
+ "n_heads": 2,
33
+ "n_layers": 6,
34
+ "kernel_size": 3,
35
+ "p_dropout": 0,
36
+ "resblock": "1",
37
+ "resblock_kernel_sizes": [3,7,11],
38
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
39
+ "upsample_rates": [10,6,2,2,2],
40
+ "upsample_initial_channel": 512,
41
+ "upsample_kernel_sizes": [16,16,4,4,4],
42
+ "use_spectral_norm": false,
43
+ "gin_channels": 256,
44
+ "spk_embed_dim": 109
45
+ }
46
+ }
src/configs/48k_v2.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "seed": 1234,
5
+ "epochs": 20000,
6
+ "learning_rate": 1e-4,
7
+ "betas": [0.8, 0.99],
8
+ "eps": 1e-9,
9
+ "batch_size": 4,
10
+ "fp16_run": true,
11
+ "lr_decay": 0.999875,
12
+ "segment_size": 17280,
13
+ "init_lr_ratio": 1,
14
+ "warmup_epochs": 0,
15
+ "c_mel": 45,
16
+ "c_kl": 1.0
17
+ },
18
+ "data": {
19
+ "max_wav_value": 32768.0,
20
+ "sampling_rate": 48000,
21
+ "filter_length": 2048,
22
+ "hop_length": 480,
23
+ "win_length": 2048,
24
+ "n_mel_channels": 128,
25
+ "mel_fmin": 0.0,
26
+ "mel_fmax": null
27
+ },
28
+ "model": {
29
+ "inter_channels": 192,
30
+ "hidden_channels": 192,
31
+ "filter_channels": 768,
32
+ "n_heads": 2,
33
+ "n_layers": 6,
34
+ "kernel_size": 3,
35
+ "p_dropout": 0,
36
+ "resblock": "1",
37
+ "resblock_kernel_sizes": [3,7,11],
38
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
39
+ "upsample_rates": [12,10,2,2],
40
+ "upsample_initial_channel": 512,
41
+ "upsample_kernel_sizes": [24,20,4,4],
42
+ "use_spectral_norm": false,
43
+ "gin_channels": 256,
44
+ "spk_embed_dim": 109
45
+ }
46
+ }
src/download_models.py ADDED
@@ -0,0 +1,29 @@
+ from pathlib import Path
+ import requests
+
+ RVC_other_DOWNLOAD_LINK = 'https://huggingface.co/Politrees/all_RVC-pretrained_and_other/resolve/main/other/'
+ RVC_hubert_DOWNLOAD_LINK = 'https://huggingface.co/Politrees/all_RVC-pretrained_and_other/resolve/main/HuBERTs/'
+
+ BASE_DIR = Path(__file__).resolve().parent.parent
+ rvc_models_dir = BASE_DIR / 'rvc_models'
+
+
+ def dl_model(link, model_name, dir_name):
+     with requests.get(f'{link}{model_name}') as r:
+         r.raise_for_status()
+         with open(dir_name / model_name, 'wb') as f:
+             for chunk in r.iter_content(chunk_size=8192):
+                 f.write(chunk)
+
+ if __name__ == '__main__':
+     rvc_other_names = ['rmvpe.pt', 'fcpe.pt']
+     for model in rvc_other_names:
+         print(f'Downloading {model}...')
+         dl_model(RVC_other_DOWNLOAD_LINK, model, rvc_models_dir)
+
+     rvc_hubert_names = ['hubert_base.pt']
+     for model in rvc_hubert_names:
+         print(f'Downloading {model}...')
+         dl_model(RVC_hubert_DOWNLOAD_LINK, model, rvc_models_dir)
+
+     print('All models downloaded!')
src/infer_pack/attentions.py ADDED
@@ -0,0 +1,417 @@
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from infer_pack import commons
9
+ from infer_pack import modules
10
+ from infer_pack.modules import LayerNorm
11
+
12
+
13
+ class Encoder(nn.Module):
14
+ def __init__(
15
+ self,
16
+ hidden_channels,
17
+ filter_channels,
18
+ n_heads,
19
+ n_layers,
20
+ kernel_size=1,
21
+ p_dropout=0.0,
22
+ window_size=10,
23
+ **kwargs
24
+ ):
25
+ super().__init__()
26
+ self.hidden_channels = hidden_channels
27
+ self.filter_channels = filter_channels
28
+ self.n_heads = n_heads
29
+ self.n_layers = n_layers
30
+ self.kernel_size = kernel_size
31
+ self.p_dropout = p_dropout
32
+ self.window_size = window_size
33
+
34
+ self.drop = nn.Dropout(p_dropout)
35
+ self.attn_layers = nn.ModuleList()
36
+ self.norm_layers_1 = nn.ModuleList()
37
+ self.ffn_layers = nn.ModuleList()
38
+ self.norm_layers_2 = nn.ModuleList()
39
+ for i in range(self.n_layers):
40
+ self.attn_layers.append(
41
+ MultiHeadAttention(
42
+ hidden_channels,
43
+ hidden_channels,
44
+ n_heads,
45
+ p_dropout=p_dropout,
46
+ window_size=window_size,
47
+ )
48
+ )
49
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
50
+ self.ffn_layers.append(
51
+ FFN(
52
+ hidden_channels,
53
+ hidden_channels,
54
+ filter_channels,
55
+ kernel_size,
56
+ p_dropout=p_dropout,
57
+ )
58
+ )
59
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
60
+
61
+ def forward(self, x, x_mask):
62
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
63
+ x = x * x_mask
64
+ for i in range(self.n_layers):
65
+ y = self.attn_layers[i](x, x, attn_mask)
66
+ y = self.drop(y)
67
+ x = self.norm_layers_1[i](x + y)
68
+
69
+ y = self.ffn_layers[i](x, x_mask)
70
+ y = self.drop(y)
71
+ x = self.norm_layers_2[i](x + y)
72
+ x = x * x_mask
73
+ return x
74
+
75
+
76
+ class Decoder(nn.Module):
77
+ def __init__(
78
+ self,
79
+ hidden_channels,
80
+ filter_channels,
81
+ n_heads,
82
+ n_layers,
83
+ kernel_size=1,
84
+ p_dropout=0.0,
85
+ proximal_bias=False,
86
+ proximal_init=True,
87
+ **kwargs
88
+ ):
89
+ super().__init__()
90
+ self.hidden_channels = hidden_channels
91
+ self.filter_channels = filter_channels
92
+ self.n_heads = n_heads
93
+ self.n_layers = n_layers
94
+ self.kernel_size = kernel_size
95
+ self.p_dropout = p_dropout
96
+ self.proximal_bias = proximal_bias
97
+ self.proximal_init = proximal_init
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.self_attn_layers = nn.ModuleList()
101
+ self.norm_layers_0 = nn.ModuleList()
102
+ self.encdec_attn_layers = nn.ModuleList()
103
+ self.norm_layers_1 = nn.ModuleList()
104
+ self.ffn_layers = nn.ModuleList()
105
+ self.norm_layers_2 = nn.ModuleList()
106
+ for i in range(self.n_layers):
107
+ self.self_attn_layers.append(
108
+ MultiHeadAttention(
109
+ hidden_channels,
110
+ hidden_channels,
111
+ n_heads,
112
+ p_dropout=p_dropout,
113
+ proximal_bias=proximal_bias,
114
+ proximal_init=proximal_init,
115
+ )
116
+ )
117
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
118
+ self.encdec_attn_layers.append(
119
+ MultiHeadAttention(
120
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
121
+ )
122
+ )
123
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
124
+ self.ffn_layers.append(
125
+ FFN(
126
+ hidden_channels,
127
+ hidden_channels,
128
+ filter_channels,
129
+ kernel_size,
130
+ p_dropout=p_dropout,
131
+ causal=True,
132
+ )
133
+ )
134
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
135
+
136
+ def forward(self, x, x_mask, h, h_mask):
137
+ """
138
+ x: decoder input
139
+ h: encoder output
140
+ """
141
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
142
+ device=x.device, dtype=x.dtype
143
+ )
144
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
145
+ x = x * x_mask
146
+ for i in range(self.n_layers):
147
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
148
+ y = self.drop(y)
149
+ x = self.norm_layers_0[i](x + y)
150
+
151
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
152
+ y = self.drop(y)
153
+ x = self.norm_layers_1[i](x + y)
154
+
155
+ y = self.ffn_layers[i](x, x_mask)
156
+ y = self.drop(y)
157
+ x = self.norm_layers_2[i](x + y)
158
+ x = x * x_mask
159
+ return x
160
+
161
+
162
+ class MultiHeadAttention(nn.Module):
163
+ def __init__(
164
+ self,
165
+ channels,
166
+ out_channels,
167
+ n_heads,
168
+ p_dropout=0.0,
169
+ window_size=None,
170
+ heads_share=True,
171
+ block_length=None,
172
+ proximal_bias=False,
173
+ proximal_init=False,
174
+ ):
175
+ super().__init__()
176
+ assert channels % n_heads == 0
177
+
178
+ self.channels = channels
179
+ self.out_channels = out_channels
180
+ self.n_heads = n_heads
181
+ self.p_dropout = p_dropout
182
+ self.window_size = window_size
183
+ self.heads_share = heads_share
184
+ self.block_length = block_length
185
+ self.proximal_bias = proximal_bias
186
+ self.proximal_init = proximal_init
187
+ self.attn = None
188
+
189
+ self.k_channels = channels // n_heads
190
+ self.conv_q = nn.Conv1d(channels, channels, 1)
191
+ self.conv_k = nn.Conv1d(channels, channels, 1)
192
+ self.conv_v = nn.Conv1d(channels, channels, 1)
193
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
194
+ self.drop = nn.Dropout(p_dropout)
195
+
196
+ if window_size is not None:
197
+ n_heads_rel = 1 if heads_share else n_heads
198
+ rel_stddev = self.k_channels**-0.5
199
+ self.emb_rel_k = nn.Parameter(
200
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
201
+ * rel_stddev
202
+ )
203
+ self.emb_rel_v = nn.Parameter(
204
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
205
+ * rel_stddev
206
+ )
207
+
208
+ nn.init.xavier_uniform_(self.conv_q.weight)
209
+ nn.init.xavier_uniform_(self.conv_k.weight)
210
+ nn.init.xavier_uniform_(self.conv_v.weight)
211
+ if proximal_init:
212
+ with torch.no_grad():
213
+ self.conv_k.weight.copy_(self.conv_q.weight)
214
+ self.conv_k.bias.copy_(self.conv_q.bias)
215
+
216
+ def forward(self, x, c, attn_mask=None):
217
+ q = self.conv_q(x)
218
+ k = self.conv_k(c)
219
+ v = self.conv_v(c)
220
+
221
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
222
+
223
+ x = self.conv_o(x)
224
+ return x
225
+
226
+ def attention(self, query, key, value, mask=None):
227
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
228
+ b, d, t_s, t_t = (*key.size(), query.size(2))
229
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
230
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
232
+
233
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
234
+ if self.window_size is not None:
235
+ assert (
236
+ t_s == t_t
237
+ ), "Relative attention is only available for self-attention."
238
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
239
+ rel_logits = self._matmul_with_relative_keys(
240
+ query / math.sqrt(self.k_channels), key_relative_embeddings
241
+ )
242
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
243
+ scores = scores + scores_local
244
+ if self.proximal_bias:
245
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
246
+ scores = scores + self._attention_bias_proximal(t_s).to(
247
+ device=scores.device, dtype=scores.dtype
248
+ )
249
+ if mask is not None:
250
+ scores = scores.masked_fill(mask == 0, -1e4)
251
+ if self.block_length is not None:
252
+ assert (
253
+ t_s == t_t
254
+ ), "Local attention is only available for self-attention."
255
+ block_mask = (
256
+ torch.ones_like(scores)
257
+ .triu(-self.block_length)
258
+ .tril(self.block_length)
259
+ )
260
+ scores = scores.masked_fill(block_mask == 0, -1e4)
261
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
262
+ p_attn = self.drop(p_attn)
263
+ output = torch.matmul(p_attn, value)
264
+ if self.window_size is not None:
265
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
266
+ value_relative_embeddings = self._get_relative_embeddings(
267
+ self.emb_rel_v, t_s
268
+ )
269
+ output = output + self._matmul_with_relative_values(
270
+ relative_weights, value_relative_embeddings
271
+ )
272
+ output = (
273
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
274
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
275
+ return output, p_attn
276
+
277
+ def _matmul_with_relative_values(self, x, y):
278
+ """
279
+ x: [b, h, l, m]
280
+ y: [h or 1, m, d]
281
+ ret: [b, h, l, d]
282
+ """
283
+ ret = torch.matmul(x, y.unsqueeze(0))
284
+ return ret
285
+
286
+ def _matmul_with_relative_keys(self, x, y):
287
+ """
288
+ x: [b, h, l, d]
289
+ y: [h or 1, m, d]
290
+ ret: [b, h, l, m]
291
+ """
292
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
293
+ return ret
294
+
295
+ def _get_relative_embeddings(self, relative_embeddings, length):
296
+ max_relative_position = 2 * self.window_size + 1
297
+ # Pad first before slice to avoid using cond ops.
298
+ pad_length = max(length - (self.window_size + 1), 0)
299
+ slice_start_position = max((self.window_size + 1) - length, 0)
300
+ slice_end_position = slice_start_position + 2 * length - 1
301
+ if pad_length > 0:
302
+ padded_relative_embeddings = F.pad(
303
+ relative_embeddings,
304
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
305
+ )
306
+ else:
307
+ padded_relative_embeddings = relative_embeddings
308
+ used_relative_embeddings = padded_relative_embeddings[
309
+ :, slice_start_position:slice_end_position
310
+ ]
311
+ return used_relative_embeddings
312
+
313
+ def _relative_position_to_absolute_position(self, x):
314
+ """
315
+ x: [b, h, l, 2*l-1]
316
+ ret: [b, h, l, l]
317
+ """
318
+ batch, heads, length, _ = x.size()
319
+ # Concat columns of pad to shift from relative to absolute indexing.
320
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
321
+
322
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
323
+ x_flat = x.view([batch, heads, length * 2 * length])
324
+ x_flat = F.pad(
325
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
326
+ )
327
+
328
+ # Reshape and slice out the padded elements.
329
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
330
+ :, :, :length, length - 1 :
331
+ ]
332
+ return x_final
333
+
334
+ def _absolute_position_to_relative_position(self, x):
335
+ """
336
+ x: [b, h, l, l]
337
+ ret: [b, h, l, 2*l-1]
338
+ """
339
+ batch, heads, length, _ = x.size()
340
+ # padd along column
341
+ x = F.pad(
342
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
343
+ )
344
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
345
+ # add 0's in the beginning that will skew the elements after reshape
346
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
347
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
348
+ return x_final
349
+
350
+ def _attention_bias_proximal(self, length):
351
+ """Bias for self-attention to encourage attention to close positions.
352
+ Args:
353
+ length: an integer scalar.
354
+ Returns:
355
+ a Tensor with shape [1, 1, length, length]
356
+ """
357
+ r = torch.arange(length, dtype=torch.float32)
358
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
359
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
+
361
+
362
+ class FFN(nn.Module):
363
+ def __init__(
364
+ self,
365
+ in_channels,
366
+ out_channels,
367
+ filter_channels,
368
+ kernel_size,
369
+ p_dropout=0.0,
370
+ activation=None,
371
+ causal=False,
372
+ ):
373
+ super().__init__()
374
+ self.in_channels = in_channels
375
+ self.out_channels = out_channels
376
+ self.filter_channels = filter_channels
377
+ self.kernel_size = kernel_size
378
+ self.p_dropout = p_dropout
379
+ self.activation = activation
380
+ self.causal = causal
381
+
382
+ if causal:
383
+ self.padding = self._causal_padding
384
+ else:
385
+ self.padding = self._same_padding
386
+
387
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
388
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
389
+ self.drop = nn.Dropout(p_dropout)
390
+
391
+ def forward(self, x, x_mask):
392
+ x = self.conv_1(self.padding(x * x_mask))
393
+ if self.activation == "gelu":
394
+ x = x * torch.sigmoid(1.702 * x)
395
+ else:
396
+ x = torch.relu(x)
397
+ x = self.drop(x)
398
+ x = self.conv_2(self.padding(x * x_mask))
399
+ return x * x_mask
400
+
401
+ def _causal_padding(self, x):
402
+ if self.kernel_size == 1:
403
+ return x
404
+ pad_l = self.kernel_size - 1
405
+ pad_r = 0
406
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
407
+ x = F.pad(x, commons.convert_pad_shape(padding))
408
+ return x
409
+
410
+ def _same_padding(self, x):
411
+ if self.kernel_size == 1:
412
+ return x
413
+ pad_l = (self.kernel_size - 1) // 2
414
+ pad_r = self.kernel_size // 2
415
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
416
+ x = F.pad(x, commons.convert_pad_shape(padding))
417
+ return x
src/infer_pack/commons.py ADDED
@@ -0,0 +1,166 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size * dilation - dilation) / 2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
25
+ """KL(P||Q)"""
26
+ kl = (logs_q - logs_p) - 0.5
27
+ kl += (
28
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
29
+ )
30
+ return kl
31
+
32
+
33
+ def rand_gumbel(shape):
34
+ """Sample from the Gumbel distribution, protect from overflows."""
35
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
36
+ return -torch.log(-torch.log(uniform_samples))
37
+
38
+
39
+ def rand_gumbel_like(x):
40
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
41
+ return g
42
+
43
+
44
+ def slice_segments(x, ids_str, segment_size=4):
45
+ ret = torch.zeros_like(x[:, :, :segment_size])
46
+ for i in range(x.size(0)):
47
+ idx_str = ids_str[i]
48
+ idx_end = idx_str + segment_size
49
+ ret[i] = x[i, :, idx_str:idx_end]
50
+ return ret
51
+
52
+
53
+ def slice_segments2(x, ids_str, segment_size=4):
54
+ ret = torch.zeros_like(x[:, :segment_size])
55
+ for i in range(x.size(0)):
56
+ idx_str = ids_str[i]
57
+ idx_end = idx_str + segment_size
58
+ ret[i] = x[i, idx_str:idx_end]
59
+ return ret
60
+
61
+
62
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
63
+ b, d, t = x.size()
64
+ if x_lengths is None:
65
+ x_lengths = t
66
+ ids_str_max = x_lengths - segment_size + 1
67
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
68
+ ret = slice_segments(x, ids_str, segment_size)
69
+ return ret, ids_str
70
+
71
+
72
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
73
+ position = torch.arange(length, dtype=torch.float)
74
+ num_timescales = channels // 2
75
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
76
+ num_timescales - 1
77
+ )
78
+ inv_timescales = min_timescale * torch.exp(
79
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
80
+ )
81
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
82
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
83
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
84
+ signal = signal.view(1, channels, length)
85
+ return signal
86
+
87
+
88
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
89
+ b, channels, length = x.size()
90
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
91
+ return x + signal.to(dtype=x.dtype, device=x.device)
92
+
93
+
94
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
95
+ b, channels, length = x.size()
96
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
97
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
98
+
99
+
100
+ def subsequent_mask(length):
101
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
102
+ return mask
103
+
104
+
105
+ @torch.jit.script
106
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
107
+ n_channels_int = n_channels[0]
108
+ in_act = input_a + input_b
109
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
110
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
111
+ acts = t_act * s_act
112
+ return acts
113
+
114
+
115
+ def convert_pad_shape(pad_shape):
116
+ l = pad_shape[::-1]
117
+ pad_shape = [item for sublist in l for item in sublist]
118
+ return pad_shape
119
+
120
+
121
+ def shift_1d(x):
122
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
123
+ return x
124
+
125
+
126
+ def sequence_mask(length, max_length=None):
127
+ if max_length is None:
128
+ max_length = length.max()
129
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
130
+ return x.unsqueeze(0) < length.unsqueeze(1)
131
+
132
+
133
+ def generate_path(duration, mask):
134
+ """
135
+ duration: [b, 1, t_x]
136
+ mask: [b, 1, t_y, t_x]
137
+ """
138
+ device = duration.device
139
+
140
+ b, _, t_y, t_x = mask.shape
141
+ cum_duration = torch.cumsum(duration, -1)
142
+
143
+ cum_duration_flat = cum_duration.view(b * t_x)
144
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
145
+ path = path.view(b, t_x, t_y)
146
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
147
+ path = path.unsqueeze(1).transpose(2, 3) * mask
148
+ return path
149
+
150
+
151
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
152
+ if isinstance(parameters, torch.Tensor):
153
+ parameters = [parameters]
154
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
155
+ norm_type = float(norm_type)
156
+ if clip_value is not None:
157
+ clip_value = float(clip_value)
158
+
159
+ total_norm = 0
160
+ for p in parameters:
161
+ param_norm = p.grad.data.norm(norm_type)
162
+ total_norm += param_norm.item() ** norm_type
163
+ if clip_value is not None:
164
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
165
+ total_norm = total_norm ** (1.0 / norm_type)
166
+ return total_norm
src/infer_pack/models.py ADDED
@@ -0,0 +1,1124 @@
1
+ import math, pdb, os
2
+ from time import time as ttime
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from infer_pack import modules
7
+ from infer_pack import attentions
8
+ from infer_pack import commons
9
+ from infer_pack.commons import init_weights, get_padding
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from infer_pack.commons import init_weights
13
+ import numpy as np
14
+ from infer_pack import commons
15
+
16
+
17
+ class TextEncoder256(nn.Module):
18
+ def __init__(
19
+ self,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ f0=True,
28
+ ):
29
+ super().__init__()
30
+ self.out_channels = out_channels
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.emb_phone = nn.Linear(256, hidden_channels)
38
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
+ if f0 == True:
40
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
+ self.encoder = attentions.Encoder(
42
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
+ )
44
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
+
46
+ def forward(self, phone, pitch, lengths):
47
+ if pitch == None:
48
+ x = self.emb_phone(phone)
49
+ else:
50
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
+ x = self.lrelu(x)
53
+ x = torch.transpose(x, 1, -1) # [b, h, t]
54
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
+ x.dtype
56
+ )
57
+ x = self.encoder(x * x_mask, x_mask)
58
+ stats = self.proj(x) * x_mask
59
+
60
+ m, logs = torch.split(stats, self.out_channels, dim=1)
61
+ return m, logs, x_mask
62
+
63
+
64
+ class TextEncoder768(nn.Module):
65
+ def __init__(
66
+ self,
67
+ out_channels,
68
+ hidden_channels,
69
+ filter_channels,
70
+ n_heads,
71
+ n_layers,
72
+ kernel_size,
73
+ p_dropout,
74
+ f0=True,
75
+ ):
76
+ super().__init__()
77
+ self.out_channels = out_channels
78
+ self.hidden_channels = hidden_channels
79
+ self.filter_channels = filter_channels
80
+ self.n_heads = n_heads
81
+ self.n_layers = n_layers
82
+ self.kernel_size = kernel_size
83
+ self.p_dropout = p_dropout
84
+ self.emb_phone = nn.Linear(768, hidden_channels)
85
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
+ if f0 == True:
87
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
+ self.encoder = attentions.Encoder(
89
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
+ )
91
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
92
+
93
+ def forward(self, phone, pitch, lengths):
94
+ if pitch == None:
95
+ x = self.emb_phone(phone)
96
+ else:
97
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
+ x = self.lrelu(x)
100
+ x = torch.transpose(x, 1, -1) # [b, h, t]
101
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
+ x.dtype
103
+ )
104
+ x = self.encoder(x * x_mask, x_mask)
105
+ stats = self.proj(x) * x_mask
106
+
107
+ m, logs = torch.split(stats, self.out_channels, dim=1)
108
+ return m, logs, x_mask
109
+
110
+
111
+ class ResidualCouplingBlock(nn.Module):
112
+ def __init__(
113
+ self,
114
+ channels,
115
+ hidden_channels,
116
+ kernel_size,
117
+ dilation_rate,
118
+ n_layers,
119
+ n_flows=4,
120
+ gin_channels=0,
121
+ ):
122
+ super().__init__()
123
+ self.channels = channels
124
+ self.hidden_channels = hidden_channels
125
+ self.kernel_size = kernel_size
126
+ self.dilation_rate = dilation_rate
127
+ self.n_layers = n_layers
128
+ self.n_flows = n_flows
129
+ self.gin_channels = gin_channels
130
+
131
+ self.flows = nn.ModuleList()
132
+ for i in range(n_flows):
133
+ self.flows.append(
134
+ modules.ResidualCouplingLayer(
135
+ channels,
136
+ hidden_channels,
137
+ kernel_size,
138
+ dilation_rate,
139
+ n_layers,
140
+ gin_channels=gin_channels,
141
+ mean_only=True,
142
+ )
143
+ )
144
+ self.flows.append(modules.Flip())
145
+
146
+ def forward(self, x, x_mask, g=None, reverse=False):
147
+ if not reverse:
148
+ for flow in self.flows:
149
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
150
+ else:
151
+ for flow in reversed(self.flows):
152
+ x = flow(x, x_mask, g=g, reverse=reverse)
153
+ return x
154
+
155
+ def remove_weight_norm(self):
156
+ for i in range(self.n_flows):
157
+ self.flows[i * 2].remove_weight_norm()
158
+
159
+
160
+ class PosteriorEncoder(nn.Module):
161
+ def __init__(
162
+ self,
163
+ in_channels,
164
+ out_channels,
165
+ hidden_channels,
166
+ kernel_size,
167
+ dilation_rate,
168
+ n_layers,
169
+ gin_channels=0,
170
+ ):
171
+ super().__init__()
172
+ self.in_channels = in_channels
173
+ self.out_channels = out_channels
174
+ self.hidden_channels = hidden_channels
175
+ self.kernel_size = kernel_size
176
+ self.dilation_rate = dilation_rate
177
+ self.n_layers = n_layers
178
+ self.gin_channels = gin_channels
179
+
180
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
181
+ self.enc = modules.WN(
182
+ hidden_channels,
183
+ kernel_size,
184
+ dilation_rate,
185
+ n_layers,
186
+ gin_channels=gin_channels,
187
+ )
188
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
189
+
190
+ def forward(self, x, x_lengths, g=None):
191
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
192
+ x.dtype
193
+ )
194
+ x = self.pre(x) * x_mask
195
+ x = self.enc(x, x_mask, g=g)
196
+ stats = self.proj(x) * x_mask
197
+ m, logs = torch.split(stats, self.out_channels, dim=1)
198
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
199
+ return z, m, logs, x_mask
200
+
201
+ def remove_weight_norm(self):
202
+ self.enc.remove_weight_norm()
203
+
204
+
205
+ class Generator(torch.nn.Module):
206
+ def __init__(
207
+ self,
208
+ initial_channel,
209
+ resblock,
210
+ resblock_kernel_sizes,
211
+ resblock_dilation_sizes,
212
+ upsample_rates,
213
+ upsample_initial_channel,
214
+ upsample_kernel_sizes,
215
+ gin_channels=0,
216
+ ):
217
+ super(Generator, self).__init__()
218
+ self.num_kernels = len(resblock_kernel_sizes)
219
+ self.num_upsamples = len(upsample_rates)
220
+ self.conv_pre = Conv1d(
221
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
222
+ )
223
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
224
+
225
+ self.ups = nn.ModuleList()
226
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
227
+ self.ups.append(
228
+ weight_norm(
229
+ ConvTranspose1d(
230
+ upsample_initial_channel // (2**i),
231
+ upsample_initial_channel // (2 ** (i + 1)),
232
+ k,
233
+ u,
234
+ padding=(k - u) // 2,
235
+ )
236
+ )
237
+ )
238
+
239
+ self.resblocks = nn.ModuleList()
240
+ for i in range(len(self.ups)):
241
+ ch = upsample_initial_channel // (2 ** (i + 1))
242
+ for j, (k, d) in enumerate(
243
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
244
+ ):
245
+ self.resblocks.append(resblock(ch, k, d))
246
+
247
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
248
+ self.ups.apply(init_weights)
249
+
250
+ if gin_channels != 0:
251
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
252
+
253
+ def forward(self, x, g=None):
254
+ x = self.conv_pre(x)
255
+ if g is not None:
256
+ x = x + self.cond(g)
257
+
258
+ for i in range(self.num_upsamples):
259
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
260
+ x = self.ups[i](x)
261
+ xs = None
262
+ for j in range(self.num_kernels):
263
+ if xs is None:
264
+ xs = self.resblocks[i * self.num_kernels + j](x)
265
+ else:
266
+ xs += self.resblocks[i * self.num_kernels + j](x)
267
+ x = xs / self.num_kernels
268
+ x = F.leaky_relu(x)
269
+ x = self.conv_post(x)
270
+ x = torch.tanh(x)
271
+
272
+ return x
273
+
274
+ def remove_weight_norm(self):
275
+ for l in self.ups:
276
+ remove_weight_norm(l)
277
+ for l in self.resblocks:
278
+ l.remove_weight_norm()
279
+
280
+
281
+ class SineGen(torch.nn.Module):
282
+ """Definition of sine generator
283
+ SineGen(samp_rate, harmonic_num = 0,
284
+ sine_amp = 0.1, noise_std = 0.003,
285
+ voiced_threshold = 0,
286
+ flag_for_pulse=False)
287
+ samp_rate: sampling rate in Hz
288
+ harmonic_num: number of harmonic overtones (default 0)
289
+ sine_amp: amplitude of sine-waveform (default 0.1)
290
+ noise_std: std of Gaussian noise (default 0.003)
291
+ voiced_threshold: F0 threshold for U/V classification (default 0)
292
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
293
+ Note: when flag_for_pulse is True, the first time step of a voiced
294
+ segment is always sin(np.pi) or cos(0)
295
+ """
296
+
297
+ def __init__(
298
+ self,
299
+ samp_rate,
300
+ harmonic_num=0,
301
+ sine_amp=0.1,
302
+ noise_std=0.003,
303
+ voiced_threshold=0,
304
+ flag_for_pulse=False,
305
+ ):
306
+ super(SineGen, self).__init__()
307
+ self.sine_amp = sine_amp
308
+ self.noise_std = noise_std
309
+ self.harmonic_num = harmonic_num
310
+ self.dim = self.harmonic_num + 1
311
+ self.sampling_rate = samp_rate
312
+ self.voiced_threshold = voiced_threshold
313
+
314
+ def _f02uv(self, f0):
315
+ # generate uv signal
316
+ uv = torch.ones_like(f0)
317
+ uv = uv * (f0 > self.voiced_threshold)
318
+ return uv
319
+
320
+ def forward(self, f0, upp):
321
+ """sine_tensor, uv = forward(f0)
322
+ input F0: tensor(batchsize=1, length, dim=1)
323
+ f0 for unvoiced steps should be 0
324
+ output sine_tensor: tensor(batchsize=1, length, dim)
325
+ output uv: tensor(batchsize=1, length, 1)
326
+ """
327
+ with torch.no_grad():
328
+ f0 = f0[:, None].transpose(1, 2)
329
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
330
+ # fundamental component
331
+ f0_buf[:, :, 0] = f0[:, :, 0]
332
+ for idx in np.arange(self.harmonic_num):
333
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
334
+ idx + 2
335
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
336
+ rad_values = (f0_buf / self.sampling_rate) % 1  ### the % 1 means the per-harmonic products cannot be optimized away in post-processing
337
+ rand_ini = torch.rand(
338
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
339
+ )
340
+ rand_ini[:, 0] = 0
341
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
342
+ tmp_over_one = torch.cumsum(rad_values, 1)  # % 1  ##### taking % 1 here would keep the cumsum below from being optimized further
343
+ tmp_over_one *= upp
344
+ tmp_over_one = F.interpolate(
345
+ tmp_over_one.transpose(2, 1),
346
+ scale_factor=upp,
347
+ mode="linear",
348
+ align_corners=True,
349
+ ).transpose(2, 1)
350
+ rad_values = F.interpolate(
351
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
352
+ ).transpose(
353
+ 2, 1
354
+ ) #######
355
+ tmp_over_one %= 1
356
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
357
+ cumsum_shift = torch.zeros_like(rad_values)
358
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
359
+ sine_waves = torch.sin(
360
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
361
+ )
362
+ sine_waves = sine_waves * self.sine_amp
363
+ uv = self._f02uv(f0)
364
+ uv = F.interpolate(
365
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
366
+ ).transpose(2, 1)
367
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
368
+ noise = noise_amp * torch.randn_like(sine_waves)
369
+ sine_waves = sine_waves * uv + noise
370
+ return sine_waves, uv, noise
371
+
372
+
373
+ class SourceModuleHnNSF(torch.nn.Module):
374
+ """SourceModule for hn-nsf
375
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
376
+ add_noise_std=0.003, voiced_threshod=0)
377
+ sampling_rate: sampling_rate in Hz
378
+ harmonic_num: number of harmonic above F0 (default: 0)
379
+ sine_amp: amplitude of sine source signal (default: 0.1)
380
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
381
+ note that amplitude of noise in unvoiced is decided
382
+ by sine_amp
383
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
384
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
385
+ F0_sampled (batchsize, length, 1)
386
+ Sine_source (batchsize, length, 1)
387
+ noise_source (batchsize, length 1)
388
+ uv (batchsize, length, 1)
389
+ """
390
+
391
+ def __init__(
392
+ self,
393
+ sampling_rate,
394
+ harmonic_num=0,
395
+ sine_amp=0.1,
396
+ add_noise_std=0.003,
397
+ voiced_threshod=0,
398
+ is_half=True,
399
+ ):
400
+ super(SourceModuleHnNSF, self).__init__()
401
+
402
+ self.sine_amp = sine_amp
403
+ self.noise_std = add_noise_std
404
+ self.is_half = is_half
405
+ # to produce sine waveforms
406
+ self.l_sin_gen = SineGen(
407
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
408
+ )
409
+
410
+ # to merge source harmonics into a single excitation
411
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
412
+ self.l_tanh = torch.nn.Tanh()
413
+
414
+ def forward(self, x, upp=None):
415
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
416
+ if self.is_half:
417
+ sine_wavs = sine_wavs.half()
418
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
419
+ return sine_merge, None, None # noise, uv
420
+
421
+
422
+ class GeneratorNSF(torch.nn.Module):
423
+ def __init__(
424
+ self,
425
+ initial_channel,
426
+ resblock,
427
+ resblock_kernel_sizes,
428
+ resblock_dilation_sizes,
429
+ upsample_rates,
430
+ upsample_initial_channel,
431
+ upsample_kernel_sizes,
432
+ gin_channels,
433
+ sr,
434
+ is_half=False,
435
+ ):
436
+ super(GeneratorNSF, self).__init__()
437
+ self.num_kernels = len(resblock_kernel_sizes)
438
+ self.num_upsamples = len(upsample_rates)
439
+
440
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
441
+ self.m_source = SourceModuleHnNSF(
442
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
443
+ )
444
+ self.noise_convs = nn.ModuleList()
445
+ self.conv_pre = Conv1d(
446
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
447
+ )
448
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
449
+
450
+ self.ups = nn.ModuleList()
451
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
452
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
453
+ self.ups.append(
454
+ weight_norm(
455
+ ConvTranspose1d(
456
+ upsample_initial_channel // (2**i),
457
+ upsample_initial_channel // (2 ** (i + 1)),
458
+ k,
459
+ u,
460
+ padding=(k - u) // 2,
461
+ )
462
+ )
463
+ )
464
+ if i + 1 < len(upsample_rates):
465
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
466
+ self.noise_convs.append(
467
+ Conv1d(
468
+ 1,
469
+ c_cur,
470
+ kernel_size=stride_f0 * 2,
471
+ stride=stride_f0,
472
+ padding=stride_f0 // 2,
473
+ )
474
+ )
475
+ else:
476
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
477
+
478
+ self.resblocks = nn.ModuleList()
479
+ for i in range(len(self.ups)):
480
+ ch = upsample_initial_channel // (2 ** (i + 1))
481
+ for j, (k, d) in enumerate(
482
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
483
+ ):
484
+ self.resblocks.append(resblock(ch, k, d))
485
+
486
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
487
+ self.ups.apply(init_weights)
488
+
489
+ if gin_channels != 0:
490
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
491
+
492
+ self.upp = np.prod(upsample_rates)
493
+
494
+ def forward(self, x, f0, g=None):
495
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
496
+ har_source = har_source.transpose(1, 2)
497
+ x = self.conv_pre(x)
498
+ if g is not None:
499
+ x = x + self.cond(g)
500
+
501
+ for i in range(self.num_upsamples):
502
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
503
+ x = self.ups[i](x)
504
+ x_source = self.noise_convs[i](har_source)
505
+ x = x + x_source
506
+ xs = None
507
+ for j in range(self.num_kernels):
508
+ if xs is None:
509
+ xs = self.resblocks[i * self.num_kernels + j](x)
510
+ else:
511
+ xs += self.resblocks[i * self.num_kernels + j](x)
512
+ x = xs / self.num_kernels
513
+ x = F.leaky_relu(x)
514
+ x = self.conv_post(x)
515
+ x = torch.tanh(x)
516
+ return x
517
+
518
+ def remove_weight_norm(self):
519
+ for l in self.ups:
520
+ remove_weight_norm(l)
521
+ for l in self.resblocks:
522
+ l.remove_weight_norm()
523
+
524
+
525
+ sr2sr = {
526
+ "32k": 32000,
527
+ "40k": 40000,
528
+ "48k": 48000,
529
+ }
530
+
531
+
532
+ class SynthesizerTrnMs256NSFsid(nn.Module):
533
+ def __init__(
534
+ self,
535
+ spec_channels,
536
+ segment_size,
537
+ inter_channels,
538
+ hidden_channels,
539
+ filter_channels,
540
+ n_heads,
541
+ n_layers,
542
+ kernel_size,
543
+ p_dropout,
544
+ resblock,
545
+ resblock_kernel_sizes,
546
+ resblock_dilation_sizes,
547
+ upsample_rates,
548
+ upsample_initial_channel,
549
+ upsample_kernel_sizes,
550
+ spk_embed_dim,
551
+ gin_channels,
552
+ sr,
553
+ **kwargs
554
+ ):
555
+ super().__init__()
556
+ if type(sr) == type("strr"):
557
+ sr = sr2sr[sr]
558
+ self.spec_channels = spec_channels
559
+ self.inter_channels = inter_channels
560
+ self.hidden_channels = hidden_channels
561
+ self.filter_channels = filter_channels
562
+ self.n_heads = n_heads
563
+ self.n_layers = n_layers
564
+ self.kernel_size = kernel_size
565
+ self.p_dropout = p_dropout
566
+ self.resblock = resblock
567
+ self.resblock_kernel_sizes = resblock_kernel_sizes
568
+ self.resblock_dilation_sizes = resblock_dilation_sizes
569
+ self.upsample_rates = upsample_rates
570
+ self.upsample_initial_channel = upsample_initial_channel
571
+ self.upsample_kernel_sizes = upsample_kernel_sizes
572
+ self.segment_size = segment_size
573
+ self.gin_channels = gin_channels
574
+ # self.hop_length = hop_length#
575
+ self.spk_embed_dim = spk_embed_dim
576
+ self.enc_p = TextEncoder256(
577
+ inter_channels,
578
+ hidden_channels,
579
+ filter_channels,
580
+ n_heads,
581
+ n_layers,
582
+ kernel_size,
583
+ p_dropout,
584
+ )
585
+ self.dec = GeneratorNSF(
586
+ inter_channels,
587
+ resblock,
588
+ resblock_kernel_sizes,
589
+ resblock_dilation_sizes,
590
+ upsample_rates,
591
+ upsample_initial_channel,
592
+ upsample_kernel_sizes,
593
+ gin_channels=gin_channels,
594
+ sr=sr,
595
+ is_half=kwargs["is_half"],
596
+ )
597
+ self.enc_q = PosteriorEncoder(
598
+ spec_channels,
599
+ inter_channels,
600
+ hidden_channels,
601
+ 5,
602
+ 1,
603
+ 16,
604
+ gin_channels=gin_channels,
605
+ )
606
+ self.flow = ResidualCouplingBlock(
607
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
608
+ )
609
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
610
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
611
+
612
+ def remove_weight_norm(self):
613
+ self.dec.remove_weight_norm()
614
+ self.flow.remove_weight_norm()
615
+ self.enc_q.remove_weight_norm()
616
+
617
+ def forward(
618
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
619
+ ): # here ds is the speaker id, [bs, 1]
620
+ # print(1,pitch.shape)#[bs,t]
621
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1] ## the trailing 1 is t, broadcast over time
622
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
623
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
624
+ z_p = self.flow(z, y_mask, g=g)
625
+ z_slice, ids_slice = commons.rand_slice_segments(
626
+ z, y_lengths, self.segment_size
627
+ )
628
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
629
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
630
+ # print(-2,pitchf.shape,z_slice.shape)
631
+ o = self.dec(z_slice, pitchf, g=g)
632
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
633
+
634
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
635
+ g = self.emb_g(sid).unsqueeze(-1)
636
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
637
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
638
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
639
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
640
+ return o, x_mask, (z, z_p, m_p, logs_p)
641
+
642
+
643
+ class SynthesizerTrnMs768NSFsid(nn.Module):
644
+ def __init__(
645
+ self,
646
+ spec_channels,
647
+ segment_size,
648
+ inter_channels,
649
+ hidden_channels,
650
+ filter_channels,
651
+ n_heads,
652
+ n_layers,
653
+ kernel_size,
654
+ p_dropout,
655
+ resblock,
656
+ resblock_kernel_sizes,
657
+ resblock_dilation_sizes,
658
+ upsample_rates,
659
+ upsample_initial_channel,
660
+ upsample_kernel_sizes,
661
+ spk_embed_dim,
662
+ gin_channels,
663
+ sr,
664
+ **kwargs
665
+ ):
666
+ super().__init__()
667
+ if type(sr) == type("strr"):
668
+ sr = sr2sr[sr]
669
+ self.spec_channels = spec_channels
670
+ self.inter_channels = inter_channels
671
+ self.hidden_channels = hidden_channels
672
+ self.filter_channels = filter_channels
673
+ self.n_heads = n_heads
674
+ self.n_layers = n_layers
675
+ self.kernel_size = kernel_size
676
+ self.p_dropout = p_dropout
677
+ self.resblock = resblock
678
+ self.resblock_kernel_sizes = resblock_kernel_sizes
679
+ self.resblock_dilation_sizes = resblock_dilation_sizes
680
+ self.upsample_rates = upsample_rates
681
+ self.upsample_initial_channel = upsample_initial_channel
682
+ self.upsample_kernel_sizes = upsample_kernel_sizes
683
+ self.segment_size = segment_size
684
+ self.gin_channels = gin_channels
685
+ # self.hop_length = hop_length#
686
+ self.spk_embed_dim = spk_embed_dim
687
+ self.enc_p = TextEncoder768(
688
+ inter_channels,
689
+ hidden_channels,
690
+ filter_channels,
691
+ n_heads,
692
+ n_layers,
693
+ kernel_size,
694
+ p_dropout,
695
+ )
696
+ self.dec = GeneratorNSF(
697
+ inter_channels,
698
+ resblock,
699
+ resblock_kernel_sizes,
700
+ resblock_dilation_sizes,
701
+ upsample_rates,
702
+ upsample_initial_channel,
703
+ upsample_kernel_sizes,
704
+ gin_channels=gin_channels,
705
+ sr=sr,
706
+ is_half=kwargs["is_half"],
707
+ )
708
+ self.enc_q = PosteriorEncoder(
709
+ spec_channels,
710
+ inter_channels,
711
+ hidden_channels,
712
+ 5,
713
+ 1,
714
+ 16,
715
+ gin_channels=gin_channels,
716
+ )
717
+ self.flow = ResidualCouplingBlock(
718
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
719
+ )
720
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
721
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
722
+
723
+ def remove_weight_norm(self):
724
+ self.dec.remove_weight_norm()
725
+ self.flow.remove_weight_norm()
726
+ self.enc_q.remove_weight_norm()
727
+
728
+ def forward(
729
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
730
+ ): # here ds is the speaker id, [bs, 1]
731
+ # print(1,pitch.shape)#[bs,t]
732
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1] ## the trailing 1 is t, broadcast over time
733
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
734
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
735
+ z_p = self.flow(z, y_mask, g=g)
736
+ z_slice, ids_slice = commons.rand_slice_segments(
737
+ z, y_lengths, self.segment_size
738
+ )
739
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
740
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
741
+ # print(-2,pitchf.shape,z_slice.shape)
742
+ o = self.dec(z_slice, pitchf, g=g)
743
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
744
+
745
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
746
+ g = self.emb_g(sid).unsqueeze(-1)
747
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
748
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
749
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
750
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
751
+ return o, x_mask, (z, z_p, m_p, logs_p)
752
+
753
+
754
+ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
755
+ def __init__(
756
+ self,
757
+ spec_channels,
758
+ segment_size,
759
+ inter_channels,
760
+ hidden_channels,
761
+ filter_channels,
762
+ n_heads,
763
+ n_layers,
764
+ kernel_size,
765
+ p_dropout,
766
+ resblock,
767
+ resblock_kernel_sizes,
768
+ resblock_dilation_sizes,
769
+ upsample_rates,
770
+ upsample_initial_channel,
771
+ upsample_kernel_sizes,
772
+ spk_embed_dim,
773
+ gin_channels,
774
+ sr=None,
775
+ **kwargs
776
+ ):
777
+ super().__init__()
778
+ self.spec_channels = spec_channels
779
+ self.inter_channels = inter_channels
780
+ self.hidden_channels = hidden_channels
781
+ self.filter_channels = filter_channels
782
+ self.n_heads = n_heads
783
+ self.n_layers = n_layers
784
+ self.kernel_size = kernel_size
785
+ self.p_dropout = p_dropout
786
+ self.resblock = resblock
787
+ self.resblock_kernel_sizes = resblock_kernel_sizes
788
+ self.resblock_dilation_sizes = resblock_dilation_sizes
789
+ self.upsample_rates = upsample_rates
790
+ self.upsample_initial_channel = upsample_initial_channel
791
+ self.upsample_kernel_sizes = upsample_kernel_sizes
792
+ self.segment_size = segment_size
793
+ self.gin_channels = gin_channels
794
+ # self.hop_length = hop_length#
795
+ self.spk_embed_dim = spk_embed_dim
796
+ self.enc_p = TextEncoder256(
797
+ inter_channels,
798
+ hidden_channels,
799
+ filter_channels,
800
+ n_heads,
801
+ n_layers,
802
+ kernel_size,
803
+ p_dropout,
804
+ f0=False,
805
+ )
806
+ self.dec = Generator(
807
+ inter_channels,
808
+ resblock,
809
+ resblock_kernel_sizes,
810
+ resblock_dilation_sizes,
811
+ upsample_rates,
812
+ upsample_initial_channel,
813
+ upsample_kernel_sizes,
814
+ gin_channels=gin_channels,
815
+ )
816
+ self.enc_q = PosteriorEncoder(
817
+ spec_channels,
818
+ inter_channels,
819
+ hidden_channels,
820
+ 5,
821
+ 1,
822
+ 16,
823
+ gin_channels=gin_channels,
824
+ )
825
+ self.flow = ResidualCouplingBlock(
826
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
827
+ )
828
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
829
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
830
+
831
+ def remove_weight_norm(self):
832
+ self.dec.remove_weight_norm()
833
+ self.flow.remove_weight_norm()
834
+ self.enc_q.remove_weight_norm()
835
+
836
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # here ds is the speaker id, [bs, 1]
837
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1] ## the trailing 1 is t, broadcast over time
838
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
839
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
840
+ z_p = self.flow(z, y_mask, g=g)
841
+ z_slice, ids_slice = commons.rand_slice_segments(
842
+ z, y_lengths, self.segment_size
843
+ )
844
+ o = self.dec(z_slice, g=g)
845
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
846
+
847
+ def infer(self, phone, phone_lengths, sid, max_len=None):
848
+ g = self.emb_g(sid).unsqueeze(-1)
849
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
850
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
851
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
852
+ o = self.dec((z * x_mask)[:, :, :max_len], g=g)
853
+ return o, x_mask, (z, z_p, m_p, logs_p)
854
+
855
+
856
+ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
857
+ def __init__(
858
+ self,
859
+ spec_channels,
860
+ segment_size,
861
+ inter_channels,
862
+ hidden_channels,
863
+ filter_channels,
864
+ n_heads,
865
+ n_layers,
866
+ kernel_size,
867
+ p_dropout,
868
+ resblock,
869
+ resblock_kernel_sizes,
870
+ resblock_dilation_sizes,
871
+ upsample_rates,
872
+ upsample_initial_channel,
873
+ upsample_kernel_sizes,
874
+ spk_embed_dim,
875
+ gin_channels,
876
+ sr=None,
877
+ **kwargs
878
+ ):
879
+ super().__init__()
880
+ self.spec_channels = spec_channels
881
+ self.inter_channels = inter_channels
882
+ self.hidden_channels = hidden_channels
883
+ self.filter_channels = filter_channels
884
+ self.n_heads = n_heads
885
+ self.n_layers = n_layers
886
+ self.kernel_size = kernel_size
887
+ self.p_dropout = p_dropout
888
+ self.resblock = resblock
889
+ self.resblock_kernel_sizes = resblock_kernel_sizes
890
+ self.resblock_dilation_sizes = resblock_dilation_sizes
891
+ self.upsample_rates = upsample_rates
892
+ self.upsample_initial_channel = upsample_initial_channel
893
+ self.upsample_kernel_sizes = upsample_kernel_sizes
894
+ self.segment_size = segment_size
895
+ self.gin_channels = gin_channels
896
+ # self.hop_length = hop_length#
897
+ self.spk_embed_dim = spk_embed_dim
898
+ self.enc_p = TextEncoder768(
899
+ inter_channels,
900
+ hidden_channels,
901
+ filter_channels,
902
+ n_heads,
903
+ n_layers,
904
+ kernel_size,
905
+ p_dropout,
906
+ f0=False,
907
+ )
908
+ self.dec = Generator(
909
+ inter_channels,
910
+ resblock,
911
+ resblock_kernel_sizes,
912
+ resblock_dilation_sizes,
913
+ upsample_rates,
914
+ upsample_initial_channel,
915
+ upsample_kernel_sizes,
916
+ gin_channels=gin_channels,
917
+ )
918
+ self.enc_q = PosteriorEncoder(
919
+ spec_channels,
920
+ inter_channels,
921
+ hidden_channels,
922
+ 5,
923
+ 1,
924
+ 16,
925
+ gin_channels=gin_channels,
926
+ )
927
+ self.flow = ResidualCouplingBlock(
928
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
929
+ )
930
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
931
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
932
+
933
+ def remove_weight_norm(self):
934
+ self.dec.remove_weight_norm()
935
+ self.flow.remove_weight_norm()
936
+ self.enc_q.remove_weight_norm()
937
+
938
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # here ds is the speaker id, [bs, 1]
939
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1] ## the trailing 1 is t, broadcast over time
940
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
941
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
942
+ z_p = self.flow(z, y_mask, g=g)
943
+ z_slice, ids_slice = commons.rand_slice_segments(
944
+ z, y_lengths, self.segment_size
945
+ )
946
+ o = self.dec(z_slice, g=g)
947
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
948
+
949
+ def infer(self, phone, phone_lengths, sid, max_len=None):
950
+ g = self.emb_g(sid).unsqueeze(-1)
951
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
952
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
953
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
954
+ o = self.dec((z * x_mask)[:, :, :max_len], g=g)
955
+ return o, x_mask, (z, z_p, m_p, logs_p)
956
+
957
+
958
+ class MultiPeriodDiscriminator(torch.nn.Module):
959
+ def __init__(self, use_spectral_norm=False):
960
+ super(MultiPeriodDiscriminator, self).__init__()
961
+ periods = [2, 3, 5, 7, 11, 17]
962
+ # periods = [3, 5, 7, 11, 17, 23, 37]
963
+
964
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
965
+ discs = discs + [
966
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
967
+ ]
968
+ self.discriminators = nn.ModuleList(discs)
969
+
970
+ def forward(self, y, y_hat):
971
+ y_d_rs = [] #
972
+ y_d_gs = []
973
+ fmap_rs = []
974
+ fmap_gs = []
975
+ for i, d in enumerate(self.discriminators):
976
+ y_d_r, fmap_r = d(y)
977
+ y_d_g, fmap_g = d(y_hat)
978
+ # for j in range(len(fmap_r)):
979
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
980
+ y_d_rs.append(y_d_r)
981
+ y_d_gs.append(y_d_g)
982
+ fmap_rs.append(fmap_r)
983
+ fmap_gs.append(fmap_g)
984
+
985
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
986
+
987
+
988
+ class MultiPeriodDiscriminatorV2(torch.nn.Module):
989
+ def __init__(self, use_spectral_norm=False):
990
+ super(MultiPeriodDiscriminatorV2, self).__init__()
991
+ # periods = [2, 3, 5, 7, 11, 17]
992
+ periods = [2, 3, 5, 7, 11, 17, 23, 37]
993
+
994
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
995
+ discs = discs + [
996
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
997
+ ]
998
+ self.discriminators = nn.ModuleList(discs)
999
+
1000
+ def forward(self, y, y_hat):
1001
+ y_d_rs = [] #
1002
+ y_d_gs = []
1003
+ fmap_rs = []
1004
+ fmap_gs = []
1005
+ for i, d in enumerate(self.discriminators):
1006
+ y_d_r, fmap_r = d(y)
1007
+ y_d_g, fmap_g = d(y_hat)
1008
+ # for j in range(len(fmap_r)):
1009
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1010
+ y_d_rs.append(y_d_r)
1011
+ y_d_gs.append(y_d_g)
1012
+ fmap_rs.append(fmap_r)
1013
+ fmap_gs.append(fmap_g)
1014
+
1015
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1016
+
1017
+
1018
+ class DiscriminatorS(torch.nn.Module):
1019
+ def __init__(self, use_spectral_norm=False):
1020
+ super(DiscriminatorS, self).__init__()
1021
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1022
+ self.convs = nn.ModuleList(
1023
+ [
1024
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
1025
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
1026
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
1027
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
1028
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
1029
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
1030
+ ]
1031
+ )
1032
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
1033
+
1034
+ def forward(self, x):
1035
+ fmap = []
1036
+
1037
+ for l in self.convs:
1038
+ x = l(x)
1039
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1040
+ fmap.append(x)
1041
+ x = self.conv_post(x)
1042
+ fmap.append(x)
1043
+ x = torch.flatten(x, 1, -1)
1044
+
1045
+ return x, fmap
1046
+
1047
+
1048
+ class DiscriminatorP(torch.nn.Module):
1049
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
1050
+ super(DiscriminatorP, self).__init__()
1051
+ self.period = period
1052
+ self.use_spectral_norm = use_spectral_norm
1053
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1054
+ self.convs = nn.ModuleList(
1055
+ [
1056
+ norm_f(
1057
+ Conv2d(
1058
+ 1,
1059
+ 32,
1060
+ (kernel_size, 1),
1061
+ (stride, 1),
1062
+ padding=(get_padding(kernel_size, 1), 0),
1063
+ )
1064
+ ),
1065
+ norm_f(
1066
+ Conv2d(
1067
+ 32,
1068
+ 128,
1069
+ (kernel_size, 1),
1070
+ (stride, 1),
1071
+ padding=(get_padding(kernel_size, 1), 0),
1072
+ )
1073
+ ),
1074
+ norm_f(
1075
+ Conv2d(
1076
+ 128,
1077
+ 512,
1078
+ (kernel_size, 1),
1079
+ (stride, 1),
1080
+ padding=(get_padding(kernel_size, 1), 0),
1081
+ )
1082
+ ),
1083
+ norm_f(
1084
+ Conv2d(
1085
+ 512,
1086
+ 1024,
1087
+ (kernel_size, 1),
1088
+ (stride, 1),
1089
+ padding=(get_padding(kernel_size, 1), 0),
1090
+ )
1091
+ ),
1092
+ norm_f(
1093
+ Conv2d(
1094
+ 1024,
1095
+ 1024,
1096
+ (kernel_size, 1),
1097
+ 1,
1098
+ padding=(get_padding(kernel_size, 1), 0),
1099
+ )
1100
+ ),
1101
+ ]
1102
+ )
1103
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
1104
+
1105
+ def forward(self, x):
1106
+ fmap = []
1107
+
1108
+ # 1d to 2d
1109
+ b, c, t = x.shape
1110
+ if t % self.period != 0: # pad first
1111
+ n_pad = self.period - (t % self.period)
1112
+ x = F.pad(x, (0, n_pad), "reflect")
1113
+ t = t + n_pad
1114
+ x = x.view(b, c, t // self.period, self.period)
1115
+
1116
+ for l in self.convs:
1117
+ x = l(x)
1118
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1119
+ fmap.append(x)
1120
+ x = self.conv_post(x)
1121
+ fmap.append(x)
1122
+ x = torch.flatten(x, 1, -1)
1123
+
1124
+ return x, fmap
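
The file above (src/infer_pack/models.py) only defines the model classes; the rest of the repo instantiates them for training or inference. As a rough orientation, the following is a minimal, hypothetical usage sketch for SynthesizerTrnMs256NSFsid. The hyperparameters are assumed typical v1 "40k" values, the import assumes src/ is on the Python path, and the random tensors merely stand in for real HuBERT features, coarse pitch and F0; this is not the repo's actual inference pipeline.

import torch
from infer_pack.models import SynthesizerTrnMs256NSFsid

# Assumed v1 "40k" hyperparameters; the real values come from the checkpoint's config.
net_g = SynthesizerTrnMs256NSFsid(
    1025, 32, 192, 192, 768, 2, 6, 3, 0,
    "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    [10, 10, 2, 2], 512, [16, 16, 4, 4],
    109, 256, sr="40k", is_half=False,
).eval()

T = 200                                    # number of feature frames
phone = torch.randn(1, T, 256)             # 256-dim content features (placeholder for HuBERT output)
phone_lengths = torch.LongTensor([T])
pitch = torch.randint(0, 256, (1, T))      # coarse pitch indices for emb_pitch
nsff0 = 100.0 + 300.0 * torch.rand(1, T)   # F0 in Hz for the NSF source module
sid = torch.LongTensor([0])                # speaker id looked up in emb_g

with torch.no_grad():
    audio, x_mask, _ = net_g.infer(phone, phone_lengths, pitch, nsff0, sid)
# audio: [1, 1, T * prod(upsample_rates)], i.e. [1, 1, 80000] samples with these settings

The decoder upsamples by prod(upsample_rates) (400 here), so each feature frame yields 400 waveform samples; the remove_weight_norm() methods defined above can be called after loading trained weights to fold weight normalization into the convolutions.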
src/infer_pack/models_onnx.py ADDED
@@ -0,0 +1,818 @@
1
+ import math, pdb, os
2
+ from time import time as ttime
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from infer_pack import modules
7
+ from infer_pack import attentions
8
+ from infer_pack import commons
9
+ from infer_pack.commons import init_weights, get_padding
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from infer_pack.commons import init_weights
13
+ import numpy as np
14
+ from infer_pack import commons
15
+
16
+
17
+ class TextEncoder256(nn.Module):
18
+ def __init__(
19
+ self,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ f0=True,
28
+ ):
29
+ super().__init__()
30
+ self.out_channels = out_channels
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.emb_phone = nn.Linear(256, hidden_channels)
38
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
+ if f0 == True:
40
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
+ self.encoder = attentions.Encoder(
42
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
+ )
44
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
+
46
+ def forward(self, phone, pitch, lengths):
47
+ if pitch == None:
48
+ x = self.emb_phone(phone)
49
+ else:
50
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
+ x = self.lrelu(x)
53
+ x = torch.transpose(x, 1, -1) # [b, h, t]
54
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
+ x.dtype
56
+ )
57
+ x = self.encoder(x * x_mask, x_mask)
58
+ stats = self.proj(x) * x_mask
59
+
60
+ m, logs = torch.split(stats, self.out_channels, dim=1)
61
+ return m, logs, x_mask
62
+
63
+
64
+ class TextEncoder768(nn.Module):
65
+ def __init__(
66
+ self,
67
+ out_channels,
68
+ hidden_channels,
69
+ filter_channels,
70
+ n_heads,
71
+ n_layers,
72
+ kernel_size,
73
+ p_dropout,
74
+ f0=True,
75
+ ):
76
+ super().__init__()
77
+ self.out_channels = out_channels
78
+ self.hidden_channels = hidden_channels
79
+ self.filter_channels = filter_channels
80
+ self.n_heads = n_heads
81
+ self.n_layers = n_layers
82
+ self.kernel_size = kernel_size
83
+ self.p_dropout = p_dropout
84
+ self.emb_phone = nn.Linear(768, hidden_channels)
85
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
+ if f0 == True:
87
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
+ self.encoder = attentions.Encoder(
89
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
+ )
91
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
92
+
93
+ def forward(self, phone, pitch, lengths):
94
+ if pitch == None:
95
+ x = self.emb_phone(phone)
96
+ else:
97
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
+ x = self.lrelu(x)
100
+ x = torch.transpose(x, 1, -1) # [b, h, t]
101
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
+ x.dtype
103
+ )
104
+ x = self.encoder(x * x_mask, x_mask)
105
+ stats = self.proj(x) * x_mask
106
+
107
+ m, logs = torch.split(stats, self.out_channels, dim=1)
108
+ return m, logs, x_mask
109
+
110
+
111
+ class ResidualCouplingBlock(nn.Module):
112
+ def __init__(
113
+ self,
114
+ channels,
115
+ hidden_channels,
116
+ kernel_size,
117
+ dilation_rate,
118
+ n_layers,
119
+ n_flows=4,
120
+ gin_channels=0,
121
+ ):
122
+ super().__init__()
123
+ self.channels = channels
124
+ self.hidden_channels = hidden_channels
125
+ self.kernel_size = kernel_size
126
+ self.dilation_rate = dilation_rate
127
+ self.n_layers = n_layers
128
+ self.n_flows = n_flows
129
+ self.gin_channels = gin_channels
130
+
131
+ self.flows = nn.ModuleList()
132
+ for i in range(n_flows):
133
+ self.flows.append(
134
+ modules.ResidualCouplingLayer(
135
+ channels,
136
+ hidden_channels,
137
+ kernel_size,
138
+ dilation_rate,
139
+ n_layers,
140
+ gin_channels=gin_channels,
141
+ mean_only=True,
142
+ )
143
+ )
144
+ self.flows.append(modules.Flip())
145
+
146
+ def forward(self, x, x_mask, g=None, reverse=False):
147
+ if not reverse:
148
+ for flow in self.flows:
149
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
150
+ else:
151
+ for flow in reversed(self.flows):
152
+ x = flow(x, x_mask, g=g, reverse=reverse)
153
+ return x
154
+
155
+ def remove_weight_norm(self):
156
+ for i in range(self.n_flows):
157
+ self.flows[i * 2].remove_weight_norm()
158
+
159
+
160
+ class PosteriorEncoder(nn.Module):
161
+ def __init__(
162
+ self,
163
+ in_channels,
164
+ out_channels,
165
+ hidden_channels,
166
+ kernel_size,
167
+ dilation_rate,
168
+ n_layers,
169
+ gin_channels=0,
170
+ ):
171
+ super().__init__()
172
+ self.in_channels = in_channels
173
+ self.out_channels = out_channels
174
+ self.hidden_channels = hidden_channels
175
+ self.kernel_size = kernel_size
176
+ self.dilation_rate = dilation_rate
177
+ self.n_layers = n_layers
178
+ self.gin_channels = gin_channels
179
+
180
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
181
+ self.enc = modules.WN(
182
+ hidden_channels,
183
+ kernel_size,
184
+ dilation_rate,
185
+ n_layers,
186
+ gin_channels=gin_channels,
187
+ )
188
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
189
+
190
+ def forward(self, x, x_lengths, g=None):
191
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
192
+ x.dtype
193
+ )
194
+ x = self.pre(x) * x_mask
195
+ x = self.enc(x, x_mask, g=g)
196
+ stats = self.proj(x) * x_mask
197
+ m, logs = torch.split(stats, self.out_channels, dim=1)
198
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
199
+ return z, m, logs, x_mask
200
+
201
+ def remove_weight_norm(self):
202
+ self.enc.remove_weight_norm()
203
+
204
+
205
+ class Generator(torch.nn.Module):
206
+ def __init__(
207
+ self,
208
+ initial_channel,
209
+ resblock,
210
+ resblock_kernel_sizes,
211
+ resblock_dilation_sizes,
212
+ upsample_rates,
213
+ upsample_initial_channel,
214
+ upsample_kernel_sizes,
215
+ gin_channels=0,
216
+ ):
217
+ super(Generator, self).__init__()
218
+ self.num_kernels = len(resblock_kernel_sizes)
219
+ self.num_upsamples = len(upsample_rates)
220
+ self.conv_pre = Conv1d(
221
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
222
+ )
223
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
224
+
225
+ self.ups = nn.ModuleList()
226
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
227
+ self.ups.append(
228
+ weight_norm(
229
+ ConvTranspose1d(
230
+ upsample_initial_channel // (2**i),
231
+ upsample_initial_channel // (2 ** (i + 1)),
232
+ k,
233
+ u,
234
+ padding=(k - u) // 2,
235
+ )
236
+ )
237
+ )
238
+
239
+ self.resblocks = nn.ModuleList()
240
+ for i in range(len(self.ups)):
241
+ ch = upsample_initial_channel // (2 ** (i + 1))
242
+ for j, (k, d) in enumerate(
243
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
244
+ ):
245
+ self.resblocks.append(resblock(ch, k, d))
246
+
247
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
248
+ self.ups.apply(init_weights)
249
+
250
+ if gin_channels != 0:
251
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
252
+
253
+ def forward(self, x, g=None):
254
+ x = self.conv_pre(x)
255
+ if g is not None:
256
+ x = x + self.cond(g)
257
+
258
+ for i in range(self.num_upsamples):
259
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
260
+ x = self.ups[i](x)
261
+ xs = None
262
+ for j in range(self.num_kernels):
263
+ if xs is None:
264
+ xs = self.resblocks[i * self.num_kernels + j](x)
265
+ else:
266
+ xs += self.resblocks[i * self.num_kernels + j](x)
267
+ x = xs / self.num_kernels
268
+ x = F.leaky_relu(x)
269
+ x = self.conv_post(x)
270
+ x = torch.tanh(x)
271
+
272
+ return x
273
+
274
+ def remove_weight_norm(self):
275
+ for l in self.ups:
276
+ remove_weight_norm(l)
277
+ for l in self.resblocks:
278
+ l.remove_weight_norm()
279
+
280
+
281
+ class SineGen(torch.nn.Module):
282
+ """Definition of sine generator
283
+ SineGen(samp_rate, harmonic_num = 0,
284
+ sine_amp = 0.1, noise_std = 0.003,
285
+ voiced_threshold = 0,
286
+ flag_for_pulse=False)
287
+ samp_rate: sampling rate in Hz
288
+ harmonic_num: number of harmonic overtones (default 0)
289
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
290
+ noise_std: std of Gaussian noise (default 0.003)
291
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
292
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
293
+ Note: when flag_for_pulse is True, the first time step of a voiced
294
+ segment is always sin(np.pi) or cos(0)
295
+ """
296
+
297
+ def __init__(
298
+ self,
299
+ samp_rate,
300
+ harmonic_num=0,
301
+ sine_amp=0.1,
302
+ noise_std=0.003,
303
+ voiced_threshold=0,
304
+ flag_for_pulse=False,
305
+ ):
306
+ super(SineGen, self).__init__()
307
+ self.sine_amp = sine_amp
308
+ self.noise_std = noise_std
309
+ self.harmonic_num = harmonic_num
310
+ self.dim = self.harmonic_num + 1
311
+ self.sampling_rate = samp_rate
312
+ self.voiced_threshold = voiced_threshold
313
+
314
+ def _f02uv(self, f0):
315
+ # generate uv signal
316
+ uv = torch.ones_like(f0)
317
+ uv = uv * (f0 > self.voiced_threshold)
318
+ return uv
319
+
320
+ def forward(self, f0, upp):
321
+ """sine_tensor, uv = forward(f0)
322
+ input F0: tensor(batchsize=1, length, dim=1)
323
+ f0 for unvoiced steps should be 0
324
+ output sine_tensor: tensor(batchsize=1, length, dim)
325
+ output uv: tensor(batchsize=1, length, 1)
326
+ """
327
+ with torch.no_grad():
328
+ f0 = f0[:, None].transpose(1, 2)
329
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
330
+ # fundamental component
331
+ f0_buf[:, :, 0] = f0[:, :, 0]
332
+ for idx in np.arange(self.harmonic_num):
333
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
334
+ idx + 2
335
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
336
+ rad_values = (f0_buf / self.sampling_rate) % 1  ### the % 1 means the per-harmonic products cannot be optimized away in post-processing
337
+ rand_ini = torch.rand(
338
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
339
+ )
340
+ rand_ini[:, 0] = 0
341
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
342
+ tmp_over_one = torch.cumsum(rad_values, 1)  # % 1  ##### taking % 1 here would keep the cumsum below from being optimized further
343
+ tmp_over_one *= upp
344
+ tmp_over_one = F.interpolate(
345
+ tmp_over_one.transpose(2, 1),
346
+ scale_factor=upp,
347
+ mode="linear",
348
+ align_corners=True,
349
+ ).transpose(2, 1)
350
+ rad_values = F.interpolate(
351
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
352
+ ).transpose(
353
+ 2, 1
354
+ ) #######
355
+ tmp_over_one %= 1
356
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
357
+ cumsum_shift = torch.zeros_like(rad_values)
358
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
359
+ sine_waves = torch.sin(
360
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
361
+ )
362
+ sine_waves = sine_waves * self.sine_amp
363
+ uv = self._f02uv(f0)
364
+ uv = F.interpolate(
365
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
366
+ ).transpose(2, 1)
367
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
368
+ noise = noise_amp * torch.randn_like(sine_waves)
369
+ sine_waves = sine_waves * uv + noise
370
+ return sine_waves, uv, noise
371
+
372
+
373
+ class SourceModuleHnNSF(torch.nn.Module):
374
+ """SourceModule for hn-nsf
375
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
376
+ add_noise_std=0.003, voiced_threshod=0)
377
+ sampling_rate: sampling_rate in Hz
378
+ harmonic_num: number of harmonic above F0 (default: 0)
379
+ sine_amp: amplitude of sine source signal (default: 0.1)
380
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
381
+ note that amplitude of noise in unvoiced is decided
382
+ by sine_amp
383
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
384
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
385
+ F0_sampled (batchsize, length, 1)
386
+ Sine_source (batchsize, length, 1)
387
+ noise_source (batchsize, length 1)
388
+ uv (batchsize, length, 1)
389
+ """
390
+
391
+ def __init__(
392
+ self,
393
+ sampling_rate,
394
+ harmonic_num=0,
395
+ sine_amp=0.1,
396
+ add_noise_std=0.003,
397
+ voiced_threshod=0,
398
+ is_half=True,
399
+ ):
400
+ super(SourceModuleHnNSF, self).__init__()
401
+
402
+ self.sine_amp = sine_amp
403
+ self.noise_std = add_noise_std
404
+ self.is_half = is_half
405
+ # to produce sine waveforms
406
+ self.l_sin_gen = SineGen(
407
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
408
+ )
409
+
410
+ # to merge source harmonics into a single excitation
411
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
412
+ self.l_tanh = torch.nn.Tanh()
413
+
414
+ def forward(self, x, upp=None):
415
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
416
+ if self.is_half:
417
+ sine_wavs = sine_wavs.half()
418
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
419
+ return sine_merge, None, None # noise, uv
420
+
421
+
422
+ class GeneratorNSF(torch.nn.Module):
423
+ def __init__(
424
+ self,
425
+ initial_channel,
426
+ resblock,
427
+ resblock_kernel_sizes,
428
+ resblock_dilation_sizes,
429
+ upsample_rates,
430
+ upsample_initial_channel,
431
+ upsample_kernel_sizes,
432
+ gin_channels,
433
+ sr,
434
+ is_half=False,
435
+ ):
436
+ super(GeneratorNSF, self).__init__()
437
+ self.num_kernels = len(resblock_kernel_sizes)
438
+ self.num_upsamples = len(upsample_rates)
439
+
440
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
441
+ self.m_source = SourceModuleHnNSF(
442
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
443
+ )
444
+ self.noise_convs = nn.ModuleList()
445
+ self.conv_pre = Conv1d(
446
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
447
+ )
448
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
449
+
450
+ self.ups = nn.ModuleList()
451
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
452
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
453
+ self.ups.append(
454
+ weight_norm(
455
+ ConvTranspose1d(
456
+ upsample_initial_channel // (2**i),
457
+ upsample_initial_channel // (2 ** (i + 1)),
458
+ k,
459
+ u,
460
+ padding=(k - u) // 2,
461
+ )
462
+ )
463
+ )
464
+ if i + 1 < len(upsample_rates):
465
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
466
+ self.noise_convs.append(
467
+ Conv1d(
468
+ 1,
469
+ c_cur,
470
+ kernel_size=stride_f0 * 2,
471
+ stride=stride_f0,
472
+ padding=stride_f0 // 2,
473
+ )
474
+ )
475
+ else:
476
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
477
+
478
+ self.resblocks = nn.ModuleList()
479
+ for i in range(len(self.ups)):
480
+ ch = upsample_initial_channel // (2 ** (i + 1))
481
+ for j, (k, d) in enumerate(
482
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
483
+ ):
484
+ self.resblocks.append(resblock(ch, k, d))
485
+
486
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
487
+ self.ups.apply(init_weights)
488
+
489
+ if gin_channels != 0:
490
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
491
+
492
+ self.upp = np.prod(upsample_rates)
493
+
494
+ def forward(self, x, f0, g=None):
495
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
496
+ har_source = har_source.transpose(1, 2)
497
+ x = self.conv_pre(x)
498
+ if g is not None:
499
+ x = x + self.cond(g)
500
+
501
+ for i in range(self.num_upsamples):
502
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
503
+ x = self.ups[i](x)
504
+ x_source = self.noise_convs[i](har_source)
505
+ x = x + x_source
506
+ xs = None
507
+ for j in range(self.num_kernels):
508
+ if xs is None:
509
+ xs = self.resblocks[i * self.num_kernels + j](x)
510
+ else:
511
+ xs += self.resblocks[i * self.num_kernels + j](x)
512
+ x = xs / self.num_kernels
513
+ x = F.leaky_relu(x)
514
+ x = self.conv_post(x)
515
+ x = torch.tanh(x)
516
+ return x
517
+
518
+ def remove_weight_norm(self):
519
+ for l in self.ups:
520
+ remove_weight_norm(l)
521
+ for l in self.resblocks:
522
+ l.remove_weight_norm()
523
+
524
+
525
+ sr2sr = {
526
+ "32k": 32000,
527
+ "40k": 40000,
528
+ "48k": 48000,
529
+ }
530
+
531
+
532
+ class SynthesizerTrnMsNSFsidM(nn.Module):
533
+ def __init__(
534
+ self,
535
+ spec_channels,
536
+ segment_size,
537
+ inter_channels,
538
+ hidden_channels,
539
+ filter_channels,
540
+ n_heads,
541
+ n_layers,
542
+ kernel_size,
543
+ p_dropout,
544
+ resblock,
545
+ resblock_kernel_sizes,
546
+ resblock_dilation_sizes,
547
+ upsample_rates,
548
+ upsample_initial_channel,
549
+ upsample_kernel_sizes,
550
+ spk_embed_dim,
551
+ gin_channels,
552
+ sr,
553
+ **kwargs
554
+ ):
555
+ super().__init__()
556
+ if type(sr) == type("strr"):
557
+ sr = sr2sr[sr]
558
+ self.spec_channels = spec_channels
559
+ self.inter_channels = inter_channels
560
+ self.hidden_channels = hidden_channels
561
+ self.filter_channels = filter_channels
562
+ self.n_heads = n_heads
563
+ self.n_layers = n_layers
564
+ self.kernel_size = kernel_size
565
+ self.p_dropout = p_dropout
566
+ self.resblock = resblock
567
+ self.resblock_kernel_sizes = resblock_kernel_sizes
568
+ self.resblock_dilation_sizes = resblock_dilation_sizes
569
+ self.upsample_rates = upsample_rates
570
+ self.upsample_initial_channel = upsample_initial_channel
571
+ self.upsample_kernel_sizes = upsample_kernel_sizes
572
+ self.segment_size = segment_size
573
+ self.gin_channels = gin_channels
574
+ # self.hop_length = hop_length#
575
+ self.spk_embed_dim = spk_embed_dim
576
+ if self.gin_channels == 256:
577
+ self.enc_p = TextEncoder256(
578
+ inter_channels,
579
+ hidden_channels,
580
+ filter_channels,
581
+ n_heads,
582
+ n_layers,
583
+ kernel_size,
584
+ p_dropout,
585
+ )
586
+ else:
587
+ self.enc_p = TextEncoder768(
588
+ inter_channels,
589
+ hidden_channels,
590
+ filter_channels,
591
+ n_heads,
592
+ n_layers,
593
+ kernel_size,
594
+ p_dropout,
595
+ )
596
+ self.dec = GeneratorNSF(
597
+ inter_channels,
598
+ resblock,
599
+ resblock_kernel_sizes,
600
+ resblock_dilation_sizes,
601
+ upsample_rates,
602
+ upsample_initial_channel,
603
+ upsample_kernel_sizes,
604
+ gin_channels=gin_channels,
605
+ sr=sr,
606
+ is_half=kwargs["is_half"],
607
+ )
608
+ self.enc_q = PosteriorEncoder(
609
+ spec_channels,
610
+ inter_channels,
611
+ hidden_channels,
612
+ 5,
613
+ 1,
614
+ 16,
615
+ gin_channels=gin_channels,
616
+ )
617
+ self.flow = ResidualCouplingBlock(
618
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
619
+ )
620
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
621
+ self.speaker_map = None
622
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
623
+
624
+ def remove_weight_norm(self):
625
+ self.dec.remove_weight_norm()
626
+ self.flow.remove_weight_norm()
627
+ self.enc_q.remove_weight_norm()
628
+
629
+ def construct_spkmixmap(self, n_speaker):
630
+ self.speaker_map = torch.zeros((n_speaker, 1, 1, self.gin_channels))
631
+ for i in range(n_speaker):
632
+ self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
633
+ self.speaker_map = self.speaker_map.unsqueeze(0)
634
+
635
+ def forward(self, phone, phone_lengths, pitch, nsff0, g, rnd, max_len=None):
636
+ if self.speaker_map is not None: # [N, S] * [S, B, 1, H]
637
+ g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1]
638
+ g = g * self.speaker_map # [N, S, B, 1, H]
639
+ g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
640
+ g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
641
+ else:
642
+ g = g.unsqueeze(0)
643
+ g = self.emb_g(g).transpose(1, 2)
644
+
645
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
646
+ z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
647
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
648
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
649
+ return o
650
+
651
+
652
+ class MultiPeriodDiscriminator(torch.nn.Module):
653
+ def __init__(self, use_spectral_norm=False):
654
+ super(MultiPeriodDiscriminator, self).__init__()
655
+ periods = [2, 3, 5, 7, 11, 17]
656
+ # periods = [3, 5, 7, 11, 17, 23, 37]
657
+
658
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
659
+ discs = discs + [
660
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
661
+ ]
662
+ self.discriminators = nn.ModuleList(discs)
663
+
664
+ def forward(self, y, y_hat):
665
+ y_d_rs = [] #
666
+ y_d_gs = []
667
+ fmap_rs = []
668
+ fmap_gs = []
669
+ for i, d in enumerate(self.discriminators):
670
+ y_d_r, fmap_r = d(y)
671
+ y_d_g, fmap_g = d(y_hat)
672
+ # for j in range(len(fmap_r)):
673
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
674
+ y_d_rs.append(y_d_r)
675
+ y_d_gs.append(y_d_g)
676
+ fmap_rs.append(fmap_r)
677
+ fmap_gs.append(fmap_g)
678
+
679
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
680
+
681
+
682
+ class MultiPeriodDiscriminatorV2(torch.nn.Module):
683
+ def __init__(self, use_spectral_norm=False):
684
+ super(MultiPeriodDiscriminatorV2, self).__init__()
685
+ # periods = [2, 3, 5, 7, 11, 17]
686
+ periods = [2, 3, 5, 7, 11, 17, 23, 37]
687
+
688
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
689
+ discs = discs + [
690
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
691
+ ]
692
+ self.discriminators = nn.ModuleList(discs)
693
+
694
+ def forward(self, y, y_hat):
695
+ y_d_rs = [] #
696
+ y_d_gs = []
697
+ fmap_rs = []
698
+ fmap_gs = []
699
+ for i, d in enumerate(self.discriminators):
700
+ y_d_r, fmap_r = d(y)
701
+ y_d_g, fmap_g = d(y_hat)
702
+ # for j in range(len(fmap_r)):
703
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
704
+ y_d_rs.append(y_d_r)
705
+ y_d_gs.append(y_d_g)
706
+ fmap_rs.append(fmap_r)
707
+ fmap_gs.append(fmap_g)
708
+
709
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
710
+
711
+
712
+ class DiscriminatorS(torch.nn.Module):
713
+ def __init__(self, use_spectral_norm=False):
714
+ super(DiscriminatorS, self).__init__()
715
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
716
+ self.convs = nn.ModuleList(
717
+ [
718
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
719
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
720
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
721
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
722
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
723
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
724
+ ]
725
+ )
726
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
727
+
728
+ def forward(self, x):
729
+ fmap = []
730
+
731
+ for l in self.convs:
732
+ x = l(x)
733
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
734
+ fmap.append(x)
735
+ x = self.conv_post(x)
736
+ fmap.append(x)
737
+ x = torch.flatten(x, 1, -1)
738
+
739
+ return x, fmap
740
+
741
+
742
+ class DiscriminatorP(torch.nn.Module):
743
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
744
+ super(DiscriminatorP, self).__init__()
745
+ self.period = period
746
+ self.use_spectral_norm = use_spectral_norm
747
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
748
+ self.convs = nn.ModuleList(
749
+ [
750
+ norm_f(
751
+ Conv2d(
752
+ 1,
753
+ 32,
754
+ (kernel_size, 1),
755
+ (stride, 1),
756
+ padding=(get_padding(kernel_size, 1), 0),
757
+ )
758
+ ),
759
+ norm_f(
760
+ Conv2d(
761
+ 32,
762
+ 128,
763
+ (kernel_size, 1),
764
+ (stride, 1),
765
+ padding=(get_padding(kernel_size, 1), 0),
766
+ )
767
+ ),
768
+ norm_f(
769
+ Conv2d(
770
+ 128,
771
+ 512,
772
+ (kernel_size, 1),
773
+ (stride, 1),
774
+ padding=(get_padding(kernel_size, 1), 0),
775
+ )
776
+ ),
777
+ norm_f(
778
+ Conv2d(
779
+ 512,
780
+ 1024,
781
+ (kernel_size, 1),
782
+ (stride, 1),
783
+ padding=(get_padding(kernel_size, 1), 0),
784
+ )
785
+ ),
786
+ norm_f(
787
+ Conv2d(
788
+ 1024,
789
+ 1024,
790
+ (kernel_size, 1),
791
+ 1,
792
+ padding=(get_padding(kernel_size, 1), 0),
793
+ )
794
+ ),
795
+ ]
796
+ )
797
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
798
+
799
+ def forward(self, x):
800
+ fmap = []
801
+
802
+ # 1d to 2d
803
+ b, c, t = x.shape
804
+ if t % self.period != 0: # pad first
805
+ n_pad = self.period - (t % self.period)
806
+ x = F.pad(x, (0, n_pad), "reflect")
807
+ t = t + n_pad
808
+ x = x.view(b, c, t // self.period, self.period)
809
+
810
+ for l in self.convs:
811
+ x = l(x)
812
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
813
+ fmap.append(x)
814
+ x = self.conv_post(x)
815
+ fmap.append(x)
816
+ x = torch.flatten(x, 1, -1)
817
+
818
+ return x, fmap
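Note: the discriminators above return per-scale logits and feature maps for both the reference and generated waveforms; in training code they would normally feed HiFi-GAN-style adversarial and feature-matching losses. A hedged sketch of those loss helpers (they are not part of this upload):

    import torch

    def discriminator_loss(disc_real_outputs, disc_generated_outputs):
        # least-squares GAN loss summed over every sub-discriminator
        loss = 0.0
        for d_real, d_gen in zip(disc_real_outputs, disc_generated_outputs):
            loss = loss + torch.mean((1.0 - d_real) ** 2) + torch.mean(d_gen**2)
        return loss

    def feature_matching_loss(fmap_real, fmap_generated):
        # L1 distance between intermediate feature maps of real and generated audio
        loss = 0.0
        for maps_r, maps_g in zip(fmap_real, fmap_generated):
            for r, g in zip(maps_r, maps_g):
                loss = loss + torch.mean(torch.abs(r - g))
        return loss

Here y_d_rs / y_d_gs would feed discriminator_loss and fmap_rs / fmap_gs would feed feature_matching_loss.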
src/infer_pack/models_onnx_moess.py ADDED
@@ -0,0 +1,849 @@
1
+ import math, pdb, os
2
+ from time import time as ttime
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from infer_pack import modules
7
+ from infer_pack import attentions
8
+ from infer_pack import commons
9
+ from infer_pack.commons import init_weights, get_padding
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from infer_pack.commons import init_weights
13
+ import numpy as np
14
+ from infer_pack import commons
15
+
16
+
17
+ class TextEncoder256(nn.Module):
18
+ def __init__(
19
+ self,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ f0=True,
28
+ ):
29
+ super().__init__()
30
+ self.out_channels = out_channels
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.emb_phone = nn.Linear(256, hidden_channels)
38
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
+ if f0 == True:
40
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
+ self.encoder = attentions.Encoder(
42
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
+ )
44
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
+
46
+ def forward(self, phone, pitch, lengths):
47
+ if pitch == None:
48
+ x = self.emb_phone(phone)
49
+ else:
50
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
+ x = self.lrelu(x)
53
+ x = torch.transpose(x, 1, -1) # [b, h, t]
54
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
+ x.dtype
56
+ )
57
+ x = self.encoder(x * x_mask, x_mask)
58
+ stats = self.proj(x) * x_mask
59
+
60
+ m, logs = torch.split(stats, self.out_channels, dim=1)
61
+ return m, logs, x_mask
62
+
63
+
64
+ class TextEncoder256Sim(nn.Module):
65
+ def __init__(
66
+ self,
67
+ out_channels,
68
+ hidden_channels,
69
+ filter_channels,
70
+ n_heads,
71
+ n_layers,
72
+ kernel_size,
73
+ p_dropout,
74
+ f0=True,
75
+ ):
76
+ super().__init__()
77
+ self.out_channels = out_channels
78
+ self.hidden_channels = hidden_channels
79
+ self.filter_channels = filter_channels
80
+ self.n_heads = n_heads
81
+ self.n_layers = n_layers
82
+ self.kernel_size = kernel_size
83
+ self.p_dropout = p_dropout
84
+ self.emb_phone = nn.Linear(256, hidden_channels)
85
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
+ if f0 == True:
87
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
+ self.encoder = attentions.Encoder(
89
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
+ )
91
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
92
+
93
+ def forward(self, phone, pitch, lengths):
94
+ if pitch == None:
95
+ x = self.emb_phone(phone)
96
+ else:
97
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
+ x = self.lrelu(x)
100
+ x = torch.transpose(x, 1, -1) # [b, h, t]
101
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
+ x.dtype
103
+ )
104
+ x = self.encoder(x * x_mask, x_mask)
105
+ x = self.proj(x) * x_mask
106
+ return x, x_mask
107
+
108
+
109
+ class ResidualCouplingBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ channels,
113
+ hidden_channels,
114
+ kernel_size,
115
+ dilation_rate,
116
+ n_layers,
117
+ n_flows=4,
118
+ gin_channels=0,
119
+ ):
120
+ super().__init__()
121
+ self.channels = channels
122
+ self.hidden_channels = hidden_channels
123
+ self.kernel_size = kernel_size
124
+ self.dilation_rate = dilation_rate
125
+ self.n_layers = n_layers
126
+ self.n_flows = n_flows
127
+ self.gin_channels = gin_channels
128
+
129
+ self.flows = nn.ModuleList()
130
+ for i in range(n_flows):
131
+ self.flows.append(
132
+ modules.ResidualCouplingLayer(
133
+ channels,
134
+ hidden_channels,
135
+ kernel_size,
136
+ dilation_rate,
137
+ n_layers,
138
+ gin_channels=gin_channels,
139
+ mean_only=True,
140
+ )
141
+ )
142
+ self.flows.append(modules.Flip())
143
+
144
+ def forward(self, x, x_mask, g=None, reverse=False):
145
+ if not reverse:
146
+ for flow in self.flows:
147
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
148
+ else:
149
+ for flow in reversed(self.flows):
150
+ x = flow(x, x_mask, g=g, reverse=reverse)
151
+ return x
152
+
153
+ def remove_weight_norm(self):
154
+ for i in range(self.n_flows):
155
+ self.flows[i * 2].remove_weight_norm()
156
+
157
+
158
+ class PosteriorEncoder(nn.Module):
159
+ def __init__(
160
+ self,
161
+ in_channels,
162
+ out_channels,
163
+ hidden_channels,
164
+ kernel_size,
165
+ dilation_rate,
166
+ n_layers,
167
+ gin_channels=0,
168
+ ):
169
+ super().__init__()
170
+ self.in_channels = in_channels
171
+ self.out_channels = out_channels
172
+ self.hidden_channels = hidden_channels
173
+ self.kernel_size = kernel_size
174
+ self.dilation_rate = dilation_rate
175
+ self.n_layers = n_layers
176
+ self.gin_channels = gin_channels
177
+
178
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
179
+ self.enc = modules.WN(
180
+ hidden_channels,
181
+ kernel_size,
182
+ dilation_rate,
183
+ n_layers,
184
+ gin_channels=gin_channels,
185
+ )
186
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
187
+
188
+ def forward(self, x, x_lengths, g=None):
189
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
190
+ x.dtype
191
+ )
192
+ x = self.pre(x) * x_mask
193
+ x = self.enc(x, x_mask, g=g)
194
+ stats = self.proj(x) * x_mask
195
+ m, logs = torch.split(stats, self.out_channels, dim=1)
196
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
197
+ return z, m, logs, x_mask
198
+
199
+ def remove_weight_norm(self):
200
+ self.enc.remove_weight_norm()
201
+
202
+
203
+ class Generator(torch.nn.Module):
204
+ def __init__(
205
+ self,
206
+ initial_channel,
207
+ resblock,
208
+ resblock_kernel_sizes,
209
+ resblock_dilation_sizes,
210
+ upsample_rates,
211
+ upsample_initial_channel,
212
+ upsample_kernel_sizes,
213
+ gin_channels=0,
214
+ ):
215
+ super(Generator, self).__init__()
216
+ self.num_kernels = len(resblock_kernel_sizes)
217
+ self.num_upsamples = len(upsample_rates)
218
+ self.conv_pre = Conv1d(
219
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
220
+ )
221
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
222
+
223
+ self.ups = nn.ModuleList()
224
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
225
+ self.ups.append(
226
+ weight_norm(
227
+ ConvTranspose1d(
228
+ upsample_initial_channel // (2**i),
229
+ upsample_initial_channel // (2 ** (i + 1)),
230
+ k,
231
+ u,
232
+ padding=(k - u) // 2,
233
+ )
234
+ )
235
+ )
236
+
237
+ self.resblocks = nn.ModuleList()
238
+ for i in range(len(self.ups)):
239
+ ch = upsample_initial_channel // (2 ** (i + 1))
240
+ for j, (k, d) in enumerate(
241
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
242
+ ):
243
+ self.resblocks.append(resblock(ch, k, d))
244
+
245
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
246
+ self.ups.apply(init_weights)
247
+
248
+ if gin_channels != 0:
249
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
250
+
251
+ def forward(self, x, g=None):
252
+ x = self.conv_pre(x)
253
+ if g is not None:
254
+ x = x + self.cond(g)
255
+
256
+ for i in range(self.num_upsamples):
257
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
258
+ x = self.ups[i](x)
259
+ xs = None
260
+ for j in range(self.num_kernels):
261
+ if xs is None:
262
+ xs = self.resblocks[i * self.num_kernels + j](x)
263
+ else:
264
+ xs += self.resblocks[i * self.num_kernels + j](x)
265
+ x = xs / self.num_kernels
266
+ x = F.leaky_relu(x)
267
+ x = self.conv_post(x)
268
+ x = torch.tanh(x)
269
+
270
+ return x
271
+
272
+ def remove_weight_norm(self):
273
+ for l in self.ups:
274
+ remove_weight_norm(l)
275
+ for l in self.resblocks:
276
+ l.remove_weight_norm()
277
+
278
+
279
+ class SineGen(torch.nn.Module):
280
+ """Definition of sine generator
281
+ SineGen(samp_rate, harmonic_num = 0,
282
+ sine_amp = 0.1, noise_std = 0.003,
283
+ voiced_threshold = 0,
284
+ flag_for_pulse=False)
285
+ samp_rate: sampling rate in Hz
286
+ harmonic_num: number of harmonic overtones (default 0)
287
+ sine_amp: amplitude of sine waveform (default 0.1)
288
+ noise_std: std of Gaussian noise (default 0.003)
289
+ voiced_threshold: F0 threshold for U/V classification (default 0)
290
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
291
+ Note: when flag_for_pulse is True, the first time step of a voiced
292
+ segment is always sin(np.pi) or cos(0)
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ samp_rate,
298
+ harmonic_num=0,
299
+ sine_amp=0.1,
300
+ noise_std=0.003,
301
+ voiced_threshold=0,
302
+ flag_for_pulse=False,
303
+ ):
304
+ super(SineGen, self).__init__()
305
+ self.sine_amp = sine_amp
306
+ self.noise_std = noise_std
307
+ self.harmonic_num = harmonic_num
308
+ self.dim = self.harmonic_num + 1
309
+ self.sampling_rate = samp_rate
310
+ self.voiced_threshold = voiced_threshold
311
+
312
+ def _f02uv(self, f0):
313
+ # generate uv signal
314
+ uv = torch.ones_like(f0)
315
+ uv = uv * (f0 > self.voiced_threshold)
316
+ return uv
317
+
318
+ def forward(self, f0, upp):
319
+ """sine_tensor, uv = forward(f0)
320
+ input F0: tensor(batchsize=1, length, dim=1)
321
+ f0 for unvoiced steps should be 0
322
+ output sine_tensor: tensor(batchsize=1, length, dim)
323
+ output uv: tensor(batchsize=1, length, 1)
324
+ """
325
+ with torch.no_grad():
326
+ f0 = f0[:, None].transpose(1, 2)
327
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
328
+ # fundamental component
329
+ f0_buf[:, :, 0] = f0[:, :, 0]
330
+ for idx in np.arange(self.harmonic_num):
331
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
332
+ idx + 2
333
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
334
+ rad_values = (f0_buf / self.sampling_rate) % 1 ### the % 1 means the per-harmonic products cannot be optimized away in post-processing
335
+ rand_ini = torch.rand(
336
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
337
+ )
338
+ rand_ini[:, 0] = 0
339
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
340
+ tmp_over_one = torch.cumsum(rad_values, 1) # % 1 ##### applying % 1 here would keep the later cumsum from being optimized
341
+ tmp_over_one *= upp
342
+ tmp_over_one = F.interpolate(
343
+ tmp_over_one.transpose(2, 1),
344
+ scale_factor=upp,
345
+ mode="linear",
346
+ align_corners=True,
347
+ ).transpose(2, 1)
348
+ rad_values = F.interpolate(
349
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
350
+ ).transpose(
351
+ 2, 1
352
+ ) #######
353
+ tmp_over_one %= 1
354
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
355
+ cumsum_shift = torch.zeros_like(rad_values)
356
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
357
+ sine_waves = torch.sin(
358
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
359
+ )
360
+ sine_waves = sine_waves * self.sine_amp
361
+ uv = self._f02uv(f0)
362
+ uv = F.interpolate(
363
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
364
+ ).transpose(2, 1)
365
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
366
+ noise = noise_amp * torch.randn_like(sine_waves)
367
+ sine_waves = sine_waves * uv + noise
368
+ return sine_waves, uv, noise
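Note: as the class docstring describes, SineGen turns a frame-level F0 track into an upsampled harmonic source plus noise. A small usage sketch, assuming a 40 kHz model and 400 output samples per frame (both values are illustrative):

    import torch

    sine_gen = SineGen(samp_rate=40000, harmonic_num=0)  # assumed rate / overtone count
    f0 = torch.full((1, 50), 220.0)                      # (batch, frames): constant 220 Hz
    sine_waves, uv, noise = sine_gen(f0, 400)            # upp = samples per frame
    print(sine_waves.shape, uv.shape)                    # both (1, 50 * 400, 1)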
369
+
370
+
371
+ class SourceModuleHnNSF(torch.nn.Module):
372
+ """SourceModule for hn-nsf
373
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
374
+ add_noise_std=0.003, voiced_threshod=0)
375
+ sampling_rate: sampling_rate in Hz
376
+ harmonic_num: number of harmonic above F0 (default: 0)
377
+ sine_amp: amplitude of sine source signal (default: 0.1)
378
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
379
+ note that amplitude of noise in unvoiced is decided
380
+ by sine_amp
381
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
382
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
383
+ F0_sampled (batchsize, length, 1)
384
+ Sine_source (batchsize, length, 1)
385
+ noise_source (batchsize, length 1)
386
+ uv (batchsize, length, 1)
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ sampling_rate,
392
+ harmonic_num=0,
393
+ sine_amp=0.1,
394
+ add_noise_std=0.003,
395
+ voiced_threshod=0,
396
+ is_half=True,
397
+ ):
398
+ super(SourceModuleHnNSF, self).__init__()
399
+
400
+ self.sine_amp = sine_amp
401
+ self.noise_std = add_noise_std
402
+ self.is_half = is_half
403
+ # to produce sine waveforms
404
+ self.l_sin_gen = SineGen(
405
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
406
+ )
407
+
408
+ # to merge source harmonics into a single excitation
409
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
410
+ self.l_tanh = torch.nn.Tanh()
411
+
412
+ def forward(self, x, upp=None):
413
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
414
+ if self.is_half:
415
+ sine_wavs = sine_wavs.half()
416
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
417
+ return sine_merge, None, None # noise, uv
418
+
419
+
420
+ class GeneratorNSF(torch.nn.Module):
421
+ def __init__(
422
+ self,
423
+ initial_channel,
424
+ resblock,
425
+ resblock_kernel_sizes,
426
+ resblock_dilation_sizes,
427
+ upsample_rates,
428
+ upsample_initial_channel,
429
+ upsample_kernel_sizes,
430
+ gin_channels,
431
+ sr,
432
+ is_half=False,
433
+ ):
434
+ super(GeneratorNSF, self).__init__()
435
+ self.num_kernels = len(resblock_kernel_sizes)
436
+ self.num_upsamples = len(upsample_rates)
437
+
438
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
439
+ self.m_source = SourceModuleHnNSF(
440
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
441
+ )
442
+ self.noise_convs = nn.ModuleList()
443
+ self.conv_pre = Conv1d(
444
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
445
+ )
446
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
447
+
448
+ self.ups = nn.ModuleList()
449
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
450
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
451
+ self.ups.append(
452
+ weight_norm(
453
+ ConvTranspose1d(
454
+ upsample_initial_channel // (2**i),
455
+ upsample_initial_channel // (2 ** (i + 1)),
456
+ k,
457
+ u,
458
+ padding=(k - u) // 2,
459
+ )
460
+ )
461
+ )
462
+ if i + 1 < len(upsample_rates):
463
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
464
+ self.noise_convs.append(
465
+ Conv1d(
466
+ 1,
467
+ c_cur,
468
+ kernel_size=stride_f0 * 2,
469
+ stride=stride_f0,
470
+ padding=stride_f0 // 2,
471
+ )
472
+ )
473
+ else:
474
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
475
+
476
+ self.resblocks = nn.ModuleList()
477
+ for i in range(len(self.ups)):
478
+ ch = upsample_initial_channel // (2 ** (i + 1))
479
+ for j, (k, d) in enumerate(
480
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
481
+ ):
482
+ self.resblocks.append(resblock(ch, k, d))
483
+
484
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
485
+ self.ups.apply(init_weights)
486
+
487
+ if gin_channels != 0:
488
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
489
+
490
+ self.upp = np.prod(upsample_rates)
491
+
492
+ def forward(self, x, f0, g=None):
493
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
494
+ har_source = har_source.transpose(1, 2)
495
+ x = self.conv_pre(x)
496
+ if g is not None:
497
+ x = x + self.cond(g)
498
+
499
+ for i in range(self.num_upsamples):
500
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
501
+ x = self.ups[i](x)
502
+ x_source = self.noise_convs[i](har_source)
503
+ x = x + x_source
504
+ xs = None
505
+ for j in range(self.num_kernels):
506
+ if xs is None:
507
+ xs = self.resblocks[i * self.num_kernels + j](x)
508
+ else:
509
+ xs += self.resblocks[i * self.num_kernels + j](x)
510
+ x = xs / self.num_kernels
511
+ x = F.leaky_relu(x)
512
+ x = self.conv_post(x)
513
+ x = torch.tanh(x)
514
+ return x
515
+
516
+ def remove_weight_norm(self):
517
+ for l in self.ups:
518
+ remove_weight_norm(l)
519
+ for l in self.resblocks:
520
+ l.remove_weight_norm()
521
+
522
+
523
+ sr2sr = {
524
+ "32k": 32000,
525
+ "40k": 40000,
526
+ "48k": 48000,
527
+ }
528
+
529
+
530
+ class SynthesizerTrnMs256NSFsidM(nn.Module):
531
+ def __init__(
532
+ self,
533
+ spec_channels,
534
+ segment_size,
535
+ inter_channels,
536
+ hidden_channels,
537
+ filter_channels,
538
+ n_heads,
539
+ n_layers,
540
+ kernel_size,
541
+ p_dropout,
542
+ resblock,
543
+ resblock_kernel_sizes,
544
+ resblock_dilation_sizes,
545
+ upsample_rates,
546
+ upsample_initial_channel,
547
+ upsample_kernel_sizes,
548
+ spk_embed_dim,
549
+ gin_channels,
550
+ sr,
551
+ **kwargs
552
+ ):
553
+ super().__init__()
554
+ if type(sr) == type("strr"):
555
+ sr = sr2sr[sr]
556
+ self.spec_channels = spec_channels
557
+ self.inter_channels = inter_channels
558
+ self.hidden_channels = hidden_channels
559
+ self.filter_channels = filter_channels
560
+ self.n_heads = n_heads
561
+ self.n_layers = n_layers
562
+ self.kernel_size = kernel_size
563
+ self.p_dropout = p_dropout
564
+ self.resblock = resblock
565
+ self.resblock_kernel_sizes = resblock_kernel_sizes
566
+ self.resblock_dilation_sizes = resblock_dilation_sizes
567
+ self.upsample_rates = upsample_rates
568
+ self.upsample_initial_channel = upsample_initial_channel
569
+ self.upsample_kernel_sizes = upsample_kernel_sizes
570
+ self.segment_size = segment_size
571
+ self.gin_channels = gin_channels
572
+ # self.hop_length = hop_length#
573
+ self.spk_embed_dim = spk_embed_dim
574
+ self.enc_p = TextEncoder256(
575
+ inter_channels,
576
+ hidden_channels,
577
+ filter_channels,
578
+ n_heads,
579
+ n_layers,
580
+ kernel_size,
581
+ p_dropout,
582
+ )
583
+ self.dec = GeneratorNSF(
584
+ inter_channels,
585
+ resblock,
586
+ resblock_kernel_sizes,
587
+ resblock_dilation_sizes,
588
+ upsample_rates,
589
+ upsample_initial_channel,
590
+ upsample_kernel_sizes,
591
+ gin_channels=gin_channels,
592
+ sr=sr,
593
+ is_half=kwargs["is_half"],
594
+ )
595
+ self.enc_q = PosteriorEncoder(
596
+ spec_channels,
597
+ inter_channels,
598
+ hidden_channels,
599
+ 5,
600
+ 1,
601
+ 16,
602
+ gin_channels=gin_channels,
603
+ )
604
+ self.flow = ResidualCouplingBlock(
605
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
606
+ )
607
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
608
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
609
+
610
+ def remove_weight_norm(self):
611
+ self.dec.remove_weight_norm()
612
+ self.flow.remove_weight_norm()
613
+ self.enc_q.remove_weight_norm()
614
+
615
+ def forward(self, phone, phone_lengths, pitch, nsff0, sid, rnd, max_len=None):
616
+ g = self.emb_g(sid).unsqueeze(-1)
617
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
618
+ z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
619
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
620
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
621
+ return o
622
+
623
+
624
+ class SynthesizerTrnMs256NSFsid_sim(nn.Module):
625
+ """
626
+ Synthesizer for Training
627
+ """
628
+
629
+ def __init__(
630
+ self,
631
+ spec_channels,
632
+ segment_size,
633
+ inter_channels,
634
+ hidden_channels,
635
+ filter_channels,
636
+ n_heads,
637
+ n_layers,
638
+ kernel_size,
639
+ p_dropout,
640
+ resblock,
641
+ resblock_kernel_sizes,
642
+ resblock_dilation_sizes,
643
+ upsample_rates,
644
+ upsample_initial_channel,
645
+ upsample_kernel_sizes,
646
+ spk_embed_dim,
647
+ # hop_length,
648
+ gin_channels=0,
649
+ use_sdp=True,
650
+ **kwargs
651
+ ):
652
+ super().__init__()
653
+ self.spec_channels = spec_channels
654
+ self.inter_channels = inter_channels
655
+ self.hidden_channels = hidden_channels
656
+ self.filter_channels = filter_channels
657
+ self.n_heads = n_heads
658
+ self.n_layers = n_layers
659
+ self.kernel_size = kernel_size
660
+ self.p_dropout = p_dropout
661
+ self.resblock = resblock
662
+ self.resblock_kernel_sizes = resblock_kernel_sizes
663
+ self.resblock_dilation_sizes = resblock_dilation_sizes
664
+ self.upsample_rates = upsample_rates
665
+ self.upsample_initial_channel = upsample_initial_channel
666
+ self.upsample_kernel_sizes = upsample_kernel_sizes
667
+ self.segment_size = segment_size
668
+ self.gin_channels = gin_channels
669
+ # self.hop_length = hop_length#
670
+ self.spk_embed_dim = spk_embed_dim
671
+ self.enc_p = TextEncoder256Sim(
672
+ inter_channels,
673
+ hidden_channels,
674
+ filter_channels,
675
+ n_heads,
676
+ n_layers,
677
+ kernel_size,
678
+ p_dropout,
679
+ )
680
+ self.dec = GeneratorNSF(
681
+ inter_channels,
682
+ resblock,
683
+ resblock_kernel_sizes,
684
+ resblock_dilation_sizes,
685
+ upsample_rates,
686
+ upsample_initial_channel,
687
+ upsample_kernel_sizes,
688
+ gin_channels=gin_channels,
689
+ is_half=kwargs["is_half"],
690
+ )
691
+
692
+ self.flow = ResidualCouplingBlock(
693
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
694
+ )
695
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
696
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
697
+
698
+ def remove_weight_norm(self):
699
+ self.dec.remove_weight_norm()
700
+ self.flow.remove_weight_norm()
701
+ self.enc_q.remove_weight_norm()
702
+
703
+ def forward(
704
+ self, phone, phone_lengths, pitch, pitchf, ds, max_len=None
705
+ ): # y (the spec) is not needed here anymore
706
+ g = self.emb_g(ds.unsqueeze(0)).unsqueeze(-1) # [b, 256, 1]; the trailing 1 is the time axis, broadcast over t
707
+ x, x_mask = self.enc_p(phone, pitch, phone_lengths)
708
+ x = self.flow(x, x_mask, g=g, reverse=True)
709
+ o = self.dec((x * x_mask)[:, :, :max_len], pitchf, g=g)
710
+ return o
711
+
712
+
713
+ class MultiPeriodDiscriminator(torch.nn.Module):
714
+ def __init__(self, use_spectral_norm=False):
715
+ super(MultiPeriodDiscriminator, self).__init__()
716
+ periods = [2, 3, 5, 7, 11, 17]
717
+ # periods = [3, 5, 7, 11, 17, 23, 37]
718
+
719
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
720
+ discs = discs + [
721
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
722
+ ]
723
+ self.discriminators = nn.ModuleList(discs)
724
+
725
+ def forward(self, y, y_hat):
726
+ y_d_rs = [] #
727
+ y_d_gs = []
728
+ fmap_rs = []
729
+ fmap_gs = []
730
+ for i, d in enumerate(self.discriminators):
731
+ y_d_r, fmap_r = d(y)
732
+ y_d_g, fmap_g = d(y_hat)
733
+ # for j in range(len(fmap_r)):
734
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
735
+ y_d_rs.append(y_d_r)
736
+ y_d_gs.append(y_d_g)
737
+ fmap_rs.append(fmap_r)
738
+ fmap_gs.append(fmap_g)
739
+
740
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
741
+
742
+
743
+ class DiscriminatorS(torch.nn.Module):
744
+ def __init__(self, use_spectral_norm=False):
745
+ super(DiscriminatorS, self).__init__()
746
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
747
+ self.convs = nn.ModuleList(
748
+ [
749
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
750
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
751
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
752
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
753
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
754
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
755
+ ]
756
+ )
757
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
758
+
759
+ def forward(self, x):
760
+ fmap = []
761
+
762
+ for l in self.convs:
763
+ x = l(x)
764
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
765
+ fmap.append(x)
766
+ x = self.conv_post(x)
767
+ fmap.append(x)
768
+ x = torch.flatten(x, 1, -1)
769
+
770
+ return x, fmap
771
+
772
+
773
+ class DiscriminatorP(torch.nn.Module):
774
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
775
+ super(DiscriminatorP, self).__init__()
776
+ self.period = period
777
+ self.use_spectral_norm = use_spectral_norm
778
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
779
+ self.convs = nn.ModuleList(
780
+ [
781
+ norm_f(
782
+ Conv2d(
783
+ 1,
784
+ 32,
785
+ (kernel_size, 1),
786
+ (stride, 1),
787
+ padding=(get_padding(kernel_size, 1), 0),
788
+ )
789
+ ),
790
+ norm_f(
791
+ Conv2d(
792
+ 32,
793
+ 128,
794
+ (kernel_size, 1),
795
+ (stride, 1),
796
+ padding=(get_padding(kernel_size, 1), 0),
797
+ )
798
+ ),
799
+ norm_f(
800
+ Conv2d(
801
+ 128,
802
+ 512,
803
+ (kernel_size, 1),
804
+ (stride, 1),
805
+ padding=(get_padding(kernel_size, 1), 0),
806
+ )
807
+ ),
808
+ norm_f(
809
+ Conv2d(
810
+ 512,
811
+ 1024,
812
+ (kernel_size, 1),
813
+ (stride, 1),
814
+ padding=(get_padding(kernel_size, 1), 0),
815
+ )
816
+ ),
817
+ norm_f(
818
+ Conv2d(
819
+ 1024,
820
+ 1024,
821
+ (kernel_size, 1),
822
+ 1,
823
+ padding=(get_padding(kernel_size, 1), 0),
824
+ )
825
+ ),
826
+ ]
827
+ )
828
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
829
+
830
+ def forward(self, x):
831
+ fmap = []
832
+
833
+ # 1d to 2d
834
+ b, c, t = x.shape
835
+ if t % self.period != 0: # pad first
836
+ n_pad = self.period - (t % self.period)
837
+ x = F.pad(x, (0, n_pad), "reflect")
838
+ t = t + n_pad
839
+ x = x.view(b, c, t // self.period, self.period)
840
+
841
+ for l in self.convs:
842
+ x = l(x)
843
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
844
+ fmap.append(x)
845
+ x = self.conv_post(x)
846
+ fmap.append(x)
847
+ x = torch.flatten(x, 1, -1)
848
+
849
+ return x, fmap
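Note: because this variant of the synthesizer takes the flow noise rnd as an explicit input, it can be traced for ONNX export. A hedged sketch, assuming an RVC-style checkpoint whose "config" entry holds the constructor arguments (the file names and the 256-dim/192-channel sizes below are placeholders, not taken from this diff):

    import torch

    cpt = torch.load("model.pth", map_location="cpu")          # hypothetical checkpoint
    model = SynthesizerTrnMs256NSFsidM(*cpt["config"], is_half=False)
    model.load_state_dict(cpt["weight"], strict=False)
    model.remove_weight_norm()
    model.eval()

    T = 200                                                    # number of content frames
    dummy = (
        torch.randn(1, T, 256),                                # phone: 256-dim content features
        torch.LongTensor([T]),                                 # phone_lengths
        torch.randint(1, 255, (1, T)),                         # pitch (coarse embedding indices)
        torch.rand(1, T) * 400.0,                              # nsff0 (F0 in Hz)
        torch.LongTensor([0]),                                 # sid (speaker id)
        torch.randn(1, 192, T),                                # rnd: prior noise, inter_channels=192 assumed
    )
    torch.onnx.export(
        model, dummy, "model.onnx", opset_version=17,
        input_names=["phone", "phone_lengths", "pitch", "nsff0", "sid", "rnd"],
        output_names=["audio"],
    )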
src/infer_pack/modules.py ADDED
@@ -0,0 +1,522 @@
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ from infer_pack import commons
13
+ from infer_pack.commons import init_weights, get_padding
14
+ from infer_pack.transforms import piecewise_rational_quadratic_transform
15
+
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
+ class LayerNorm(nn.Module):
21
+ def __init__(self, channels, eps=1e-5):
22
+ super().__init__()
23
+ self.channels = channels
24
+ self.eps = eps
25
+
26
+ self.gamma = nn.Parameter(torch.ones(channels))
27
+ self.beta = nn.Parameter(torch.zeros(channels))
28
+
29
+ def forward(self, x):
30
+ x = x.transpose(1, -1)
31
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
+ return x.transpose(1, -1)
33
+
34
+
35
+ class ConvReluNorm(nn.Module):
36
+ def __init__(
37
+ self,
38
+ in_channels,
39
+ hidden_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ n_layers,
43
+ p_dropout,
44
+ ):
45
+ super().__init__()
46
+ self.in_channels = in_channels
47
+ self.hidden_channels = hidden_channels
48
+ self.out_channels = out_channels
49
+ self.kernel_size = kernel_size
50
+ self.n_layers = n_layers
51
+ self.p_dropout = p_dropout
52
+ assert n_layers > 1, "Number of layers should be larger than 1."
53
+
54
+ self.conv_layers = nn.ModuleList()
55
+ self.norm_layers = nn.ModuleList()
56
+ self.conv_layers.append(
57
+ nn.Conv1d(
58
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
+ )
60
+ )
61
+ self.norm_layers.append(LayerNorm(hidden_channels))
62
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
+ for _ in range(n_layers - 1):
64
+ self.conv_layers.append(
65
+ nn.Conv1d(
66
+ hidden_channels,
67
+ hidden_channels,
68
+ kernel_size,
69
+ padding=kernel_size // 2,
70
+ )
71
+ )
72
+ self.norm_layers.append(LayerNorm(hidden_channels))
73
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
+ self.proj.weight.data.zero_()
75
+ self.proj.bias.data.zero_()
76
+
77
+ def forward(self, x, x_mask):
78
+ x_org = x
79
+ for i in range(self.n_layers):
80
+ x = self.conv_layers[i](x * x_mask)
81
+ x = self.norm_layers[i](x)
82
+ x = self.relu_drop(x)
83
+ x = x_org + self.proj(x)
84
+ return x * x_mask
85
+
86
+
87
+ class DDSConv(nn.Module):
88
+ """
89
+ Dilated and Depth-Separable Convolution
90
+ """
91
+
92
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
+ super().__init__()
94
+ self.channels = channels
95
+ self.kernel_size = kernel_size
96
+ self.n_layers = n_layers
97
+ self.p_dropout = p_dropout
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.convs_sep = nn.ModuleList()
101
+ self.convs_1x1 = nn.ModuleList()
102
+ self.norms_1 = nn.ModuleList()
103
+ self.norms_2 = nn.ModuleList()
104
+ for i in range(n_layers):
105
+ dilation = kernel_size**i
106
+ padding = (kernel_size * dilation - dilation) // 2
107
+ self.convs_sep.append(
108
+ nn.Conv1d(
109
+ channels,
110
+ channels,
111
+ kernel_size,
112
+ groups=channels,
113
+ dilation=dilation,
114
+ padding=padding,
115
+ )
116
+ )
117
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
+ self.norms_1.append(LayerNorm(channels))
119
+ self.norms_2.append(LayerNorm(channels))
120
+
121
+ def forward(self, x, x_mask, g=None):
122
+ if g is not None:
123
+ x = x + g
124
+ for i in range(self.n_layers):
125
+ y = self.convs_sep[i](x * x_mask)
126
+ y = self.norms_1[i](y)
127
+ y = F.gelu(y)
128
+ y = self.convs_1x1[i](y)
129
+ y = self.norms_2[i](y)
130
+ y = F.gelu(y)
131
+ y = self.drop(y)
132
+ x = x + y
133
+ return x * x_mask
134
+
135
+
136
+ class WN(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ hidden_channels,
140
+ kernel_size,
141
+ dilation_rate,
142
+ n_layers,
143
+ gin_channels=0,
144
+ p_dropout=0,
145
+ ):
146
+ super(WN, self).__init__()
147
+ assert kernel_size % 2 == 1
148
+ self.hidden_channels = hidden_channels
149
+ self.kernel_size = (kernel_size,)
150
+ self.dilation_rate = dilation_rate
151
+ self.n_layers = n_layers
152
+ self.gin_channels = gin_channels
153
+ self.p_dropout = p_dropout
154
+
155
+ self.in_layers = torch.nn.ModuleList()
156
+ self.res_skip_layers = torch.nn.ModuleList()
157
+ self.drop = nn.Dropout(p_dropout)
158
+
159
+ if gin_channels != 0:
160
+ cond_layer = torch.nn.Conv1d(
161
+ gin_channels, 2 * hidden_channels * n_layers, 1
162
+ )
163
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164
+
165
+ for i in range(n_layers):
166
+ dilation = dilation_rate**i
167
+ padding = int((kernel_size * dilation - dilation) / 2)
168
+ in_layer = torch.nn.Conv1d(
169
+ hidden_channels,
170
+ 2 * hidden_channels,
171
+ kernel_size,
172
+ dilation=dilation,
173
+ padding=padding,
174
+ )
175
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176
+ self.in_layers.append(in_layer)
177
+
178
+ # last one is not necessary
179
+ if i < n_layers - 1:
180
+ res_skip_channels = 2 * hidden_channels
181
+ else:
182
+ res_skip_channels = hidden_channels
183
+
184
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
185
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
186
+ self.res_skip_layers.append(res_skip_layer)
187
+
188
+ def forward(self, x, x_mask, g=None, **kwargs):
189
+ output = torch.zeros_like(x)
190
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
191
+
192
+ if g is not None:
193
+ g = self.cond_layer(g)
194
+
195
+ for i in range(self.n_layers):
196
+ x_in = self.in_layers[i](x)
197
+ if g is not None:
198
+ cond_offset = i * 2 * self.hidden_channels
199
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
200
+ else:
201
+ g_l = torch.zeros_like(x_in)
202
+
203
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
204
+ acts = self.drop(acts)
205
+
206
+ res_skip_acts = self.res_skip_layers[i](acts)
207
+ if i < self.n_layers - 1:
208
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
209
+ x = (x + res_acts) * x_mask
210
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
211
+ else:
212
+ output = output + res_skip_acts
213
+ return output * x_mask
214
+
215
+ def remove_weight_norm(self):
216
+ if self.gin_channels != 0:
217
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
218
+ for l in self.in_layers:
219
+ torch.nn.utils.remove_weight_norm(l)
220
+ for l in self.res_skip_layers:
221
+ torch.nn.utils.remove_weight_norm(l)
222
+
223
+
224
+ class ResBlock1(torch.nn.Module):
225
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
226
+ super(ResBlock1, self).__init__()
227
+ self.convs1 = nn.ModuleList(
228
+ [
229
+ weight_norm(
230
+ Conv1d(
231
+ channels,
232
+ channels,
233
+ kernel_size,
234
+ 1,
235
+ dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]),
237
+ )
238
+ ),
239
+ weight_norm(
240
+ Conv1d(
241
+ channels,
242
+ channels,
243
+ kernel_size,
244
+ 1,
245
+ dilation=dilation[1],
246
+ padding=get_padding(kernel_size, dilation[1]),
247
+ )
248
+ ),
249
+ weight_norm(
250
+ Conv1d(
251
+ channels,
252
+ channels,
253
+ kernel_size,
254
+ 1,
255
+ dilation=dilation[2],
256
+ padding=get_padding(kernel_size, dilation[2]),
257
+ )
258
+ ),
259
+ ]
260
+ )
261
+ self.convs1.apply(init_weights)
262
+
263
+ self.convs2 = nn.ModuleList(
264
+ [
265
+ weight_norm(
266
+ Conv1d(
267
+ channels,
268
+ channels,
269
+ kernel_size,
270
+ 1,
271
+ dilation=1,
272
+ padding=get_padding(kernel_size, 1),
273
+ )
274
+ ),
275
+ weight_norm(
276
+ Conv1d(
277
+ channels,
278
+ channels,
279
+ kernel_size,
280
+ 1,
281
+ dilation=1,
282
+ padding=get_padding(kernel_size, 1),
283
+ )
284
+ ),
285
+ weight_norm(
286
+ Conv1d(
287
+ channels,
288
+ channels,
289
+ kernel_size,
290
+ 1,
291
+ dilation=1,
292
+ padding=get_padding(kernel_size, 1),
293
+ )
294
+ ),
295
+ ]
296
+ )
297
+ self.convs2.apply(init_weights)
298
+
299
+ def forward(self, x, x_mask=None):
300
+ for c1, c2 in zip(self.convs1, self.convs2):
301
+ xt = F.leaky_relu(x, LRELU_SLOPE)
302
+ if x_mask is not None:
303
+ xt = xt * x_mask
304
+ xt = c1(xt)
305
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
306
+ if x_mask is not None:
307
+ xt = xt * x_mask
308
+ xt = c2(xt)
309
+ x = xt + x
310
+ if x_mask is not None:
311
+ x = x * x_mask
312
+ return x
313
+
314
+ def remove_weight_norm(self):
315
+ for l in self.convs1:
316
+ remove_weight_norm(l)
317
+ for l in self.convs2:
318
+ remove_weight_norm(l)
319
+
320
+
321
+ class ResBlock2(torch.nn.Module):
322
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
323
+ super(ResBlock2, self).__init__()
324
+ self.convs = nn.ModuleList(
325
+ [
326
+ weight_norm(
327
+ Conv1d(
328
+ channels,
329
+ channels,
330
+ kernel_size,
331
+ 1,
332
+ dilation=dilation[0],
333
+ padding=get_padding(kernel_size, dilation[0]),
334
+ )
335
+ ),
336
+ weight_norm(
337
+ Conv1d(
338
+ channels,
339
+ channels,
340
+ kernel_size,
341
+ 1,
342
+ dilation=dilation[1],
343
+ padding=get_padding(kernel_size, dilation[1]),
344
+ )
345
+ ),
346
+ ]
347
+ )
348
+ self.convs.apply(init_weights)
349
+
350
+ def forward(self, x, x_mask=None):
351
+ for c in self.convs:
352
+ xt = F.leaky_relu(x, LRELU_SLOPE)
353
+ if x_mask is not None:
354
+ xt = xt * x_mask
355
+ xt = c(xt)
356
+ x = xt + x
357
+ if x_mask is not None:
358
+ x = x * x_mask
359
+ return x
360
+
361
+ def remove_weight_norm(self):
362
+ for l in self.convs:
363
+ remove_weight_norm(l)
364
+
365
+
366
+ class Log(nn.Module):
367
+ def forward(self, x, x_mask, reverse=False, **kwargs):
368
+ if not reverse:
369
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
370
+ logdet = torch.sum(-y, [1, 2])
371
+ return y, logdet
372
+ else:
373
+ x = torch.exp(x) * x_mask
374
+ return x
375
+
376
+
377
+ class Flip(nn.Module):
378
+ def forward(self, x, *args, reverse=False, **kwargs):
379
+ x = torch.flip(x, [1])
380
+ if not reverse:
381
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
382
+ return x, logdet
383
+ else:
384
+ return x
385
+
386
+
387
+ class ElementwiseAffine(nn.Module):
388
+ def __init__(self, channels):
389
+ super().__init__()
390
+ self.channels = channels
391
+ self.m = nn.Parameter(torch.zeros(channels, 1))
392
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
393
+
394
+ def forward(self, x, x_mask, reverse=False, **kwargs):
395
+ if not reverse:
396
+ y = self.m + torch.exp(self.logs) * x
397
+ y = y * x_mask
398
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
399
+ return y, logdet
400
+ else:
401
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
402
+ return x
403
+
404
+
405
+ class ResidualCouplingLayer(nn.Module):
406
+ def __init__(
407
+ self,
408
+ channels,
409
+ hidden_channels,
410
+ kernel_size,
411
+ dilation_rate,
412
+ n_layers,
413
+ p_dropout=0,
414
+ gin_channels=0,
415
+ mean_only=False,
416
+ ):
417
+ assert channels % 2 == 0, "channels should be divisible by 2"
418
+ super().__init__()
419
+ self.channels = channels
420
+ self.hidden_channels = hidden_channels
421
+ self.kernel_size = kernel_size
422
+ self.dilation_rate = dilation_rate
423
+ self.n_layers = n_layers
424
+ self.half_channels = channels // 2
425
+ self.mean_only = mean_only
426
+
427
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
428
+ self.enc = WN(
429
+ hidden_channels,
430
+ kernel_size,
431
+ dilation_rate,
432
+ n_layers,
433
+ p_dropout=p_dropout,
434
+ gin_channels=gin_channels,
435
+ )
436
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
437
+ self.post.weight.data.zero_()
438
+ self.post.bias.data.zero_()
439
+
440
+ def forward(self, x, x_mask, g=None, reverse=False):
441
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
442
+ h = self.pre(x0) * x_mask
443
+ h = self.enc(h, x_mask, g=g)
444
+ stats = self.post(h) * x_mask
445
+ if not self.mean_only:
446
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
447
+ else:
448
+ m = stats
449
+ logs = torch.zeros_like(m)
450
+
451
+ if not reverse:
452
+ x1 = m + x1 * torch.exp(logs) * x_mask
453
+ x = torch.cat([x0, x1], 1)
454
+ logdet = torch.sum(logs, [1, 2])
455
+ return x, logdet
456
+ else:
457
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
458
+ x = torch.cat([x0, x1], 1)
459
+ return x
460
+
461
+ def remove_weight_norm(self):
462
+ self.enc.remove_weight_norm()
463
+
464
+
465
+ class ConvFlow(nn.Module):
466
+ def __init__(
467
+ self,
468
+ in_channels,
469
+ filter_channels,
470
+ kernel_size,
471
+ n_layers,
472
+ num_bins=10,
473
+ tail_bound=5.0,
474
+ ):
475
+ super().__init__()
476
+ self.in_channels = in_channels
477
+ self.filter_channels = filter_channels
478
+ self.kernel_size = kernel_size
479
+ self.n_layers = n_layers
480
+ self.num_bins = num_bins
481
+ self.tail_bound = tail_bound
482
+ self.half_channels = in_channels // 2
483
+
484
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
485
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
486
+ self.proj = nn.Conv1d(
487
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
488
+ )
489
+ self.proj.weight.data.zero_()
490
+ self.proj.bias.data.zero_()
491
+
492
+ def forward(self, x, x_mask, g=None, reverse=False):
493
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
494
+ h = self.pre(x0)
495
+ h = self.convs(h, x_mask, g=g)
496
+ h = self.proj(h) * x_mask
497
+
498
+ b, c, t = x0.shape
499
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
500
+
501
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
502
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
503
+ self.filter_channels
504
+ )
505
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
506
+
507
+ x1, logabsdet = piecewise_rational_quadratic_transform(
508
+ x1,
509
+ unnormalized_widths,
510
+ unnormalized_heights,
511
+ unnormalized_derivatives,
512
+ inverse=reverse,
513
+ tails="linear",
514
+ tail_bound=self.tail_bound,
515
+ )
516
+
517
+ x = torch.cat([x0, x1], 1) * x_mask
518
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
519
+ if not reverse:
520
+ return x, logdet
521
+ else:
522
+ return x
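Note: ResidualCouplingLayer (like ConvFlow) is a normalizing-flow step, so running it forward and then with reverse=True should give the original tensor back. A quick sanity-check sketch with made-up sizes:

    import torch

    layer = ResidualCouplingLayer(
        channels=4, hidden_channels=8, kernel_size=5,
        dilation_rate=1, n_layers=2, mean_only=True,
    )
    x = torch.randn(1, 4, 20)
    x_mask = torch.ones(1, 1, 20)

    y, logdet = layer(x, x_mask)                 # forward pass returns (y, logdet)
    x_rec = layer(y, x_mask, reverse=True)       # reverse pass inverts it
    print(torch.allclose(x, x_rec, atol=1e-5))   # True up to numerical error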
src/infer_pack/predictor/FCPE.py ADDED
@@ -0,0 +1,1036 @@
1
+ from typing import Union
2
+
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+ from torchaudio.transforms import Resample
9
+ import os
10
+ import librosa
11
+ import soundfile as sf
12
+ import torch.utils.data
13
+ from librosa.filters import mel as librosa_mel_fn
14
+ import math
15
+ from functools import partial
16
+
17
+ from einops import rearrange, repeat
18
+ from local_attention import LocalAttention
19
+ from torch import nn
20
+
21
+ os.environ["LRU_CACHE_CAPACITY"] = "3"
22
+
23
+
24
+ def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
25
+ sampling_rate = None
26
+ try:
27
+ data, sampling_rate = sf.read(full_path, always_2d=True) # read audio with soundfile
28
+ except Exception as error:
29
+ print(f"'{full_path}' failed to load with {error}")
30
+ if return_empty_on_exception:
31
+ return [], sampling_rate or target_sr or 48000
32
+ else:
33
+ raise Exception(error)
34
+
35
+ if len(data.shape) > 1:
36
+ data = data[:, 0]
37
+ assert (
38
+ len(data) > 2
39
+ ) # check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
40
+
41
+ if np.issubdtype(data.dtype, np.integer): # if audio data is type int
42
+ max_mag = -np.iinfo(
43
+ data.dtype
44
+ ).min # maximum magnitude = min possible value of intXX
45
+ else: # if audio data is type fp32
46
+ max_mag = max(np.amax(data), -np.amin(data))
47
+ max_mag = (
48
+ (2**31) + 1
49
+ if max_mag > (2**15)
50
+ else ((2**15) + 1 if max_mag > 1.01 else 1.0)
51
+ ) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
52
+
53
+ data = torch.FloatTensor(data.astype(np.float32)) / max_mag
54
+
55
+ if (
56
+ torch.isinf(data) | torch.isnan(data)
57
+ ).any() and return_empty_on_exception: # resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
58
+ return [], sampling_rate or target_sr or 48000
59
+ if target_sr is not None and sampling_rate != target_sr:
60
+ data = torch.from_numpy(
61
+ librosa.core.resample(
62
+ data.numpy(), orig_sr=sampling_rate, target_sr=target_sr
63
+ )
64
+ )
65
+ sampling_rate = target_sr
66
+
67
+ return data, sampling_rate
68
+
69
+
70
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
71
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
72
+
73
+
74
+ def dynamic_range_decompression(x, C=1):
75
+ return np.exp(x) / C
76
+
77
+
78
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
79
+ return torch.log(torch.clamp(x, min=clip_val) * C)
80
+
81
+
82
+ def dynamic_range_decompression_torch(x, C=1):
83
+ return torch.exp(x) / C
84
+
85
+
86
+ class STFT:
87
+ def __init__(
88
+ self,
89
+ sr=22050,
90
+ n_mels=80,
91
+ n_fft=1024,
92
+ win_size=1024,
93
+ hop_length=256,
94
+ fmin=20,
95
+ fmax=11025,
96
+ clip_val=1e-5,
97
+ ):
98
+ self.target_sr = sr
99
+
100
+ self.n_mels = n_mels
101
+ self.n_fft = n_fft
102
+ self.win_size = win_size
103
+ self.hop_length = hop_length
104
+ self.fmin = fmin
105
+ self.fmax = fmax
106
+ self.clip_val = clip_val
107
+ self.mel_basis = {}
108
+ self.hann_window = {}
109
+
110
+ def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
111
+ sampling_rate = self.target_sr
112
+ n_mels = self.n_mels
113
+ n_fft = self.n_fft
114
+ win_size = self.win_size
115
+ hop_length = self.hop_length
116
+ fmin = self.fmin
117
+ fmax = self.fmax
118
+ clip_val = self.clip_val
119
+
120
+ factor = 2 ** (keyshift / 12)
121
+ n_fft_new = int(np.round(n_fft * factor))
122
+ win_size_new = int(np.round(win_size * factor))
123
+ hop_length_new = int(np.round(hop_length * speed))
124
+ if not train:
125
+ mel_basis = self.mel_basis
126
+ hann_window = self.hann_window
127
+ else:
128
+ mel_basis = {}
129
+ hann_window = {}
130
+
131
+ mel_basis_key = str(fmax) + "_" + str(y.device)
132
+ if mel_basis_key not in mel_basis:
133
+ mel = librosa_mel_fn(
134
+ sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
135
+ )
136
+ mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
137
+
138
+ keyshift_key = str(keyshift) + "_" + str(y.device)
139
+ if keyshift_key not in hann_window:
140
+ hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
141
+
142
+ pad_left = (win_size_new - hop_length_new) // 2
143
+ pad_right = max(
144
+ (win_size_new - hop_length_new + 1) // 2,
145
+ win_size_new - y.size(-1) - pad_left,
146
+ )
147
+ if pad_right < y.size(-1):
148
+ mode = "reflect"
149
+ else:
150
+ mode = "constant"
151
+ y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
152
+ y = y.squeeze(1)
153
+
154
+ spec = torch.stft(
155
+ y,
156
+ n_fft_new,
157
+ hop_length=hop_length_new,
158
+ win_length=win_size_new,
159
+ window=hann_window[keyshift_key],
160
+ center=center,
161
+ pad_mode="reflect",
162
+ normalized=False,
163
+ onesided=True,
164
+ return_complex=True,
165
+ )
166
+ spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
167
+ if keyshift != 0:
168
+ size = n_fft // 2 + 1
169
+ resize = spec.size(1)
170
+ if resize < size:
171
+ spec = F.pad(spec, (0, 0, 0, size - resize))
172
+ spec = spec[:, :size, :] * win_size / win_size_new
173
+ spec = torch.matmul(mel_basis[mel_basis_key], spec)
174
+ spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
175
+ return spec
176
+
177
+ def __call__(self, audiopath):
178
+ audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
179
+ spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
180
+ return spect
181
+
182
+
183
+ stft = STFT()
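Note: the module-level stft instance above keeps the 22.05 kHz defaults; pitch front-ends usually rebuild it with their own mel settings. A hedged usage sketch (the parameter values are illustrative, not taken from this file):

    import torch

    mel_extractor = STFT(sr=16000, n_mels=128, n_fft=1024, win_size=1024,
                         hop_length=160, fmin=0, fmax=8000)  # assumed settings
    audio = torch.randn(1, 16000)                            # (batch, samples): 1 s of audio
    mel = mel_extractor.get_mel(audio)                       # -> (batch, n_mels, frames)
    print(mel.shape)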
184
+
185
+ # import fast_transformers.causal_product.causal_product_cuda
186
+
187
+
188
+ def softmax_kernel(
189
+ data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None
190
+ ):
191
+ b, h, *_ = data.shape
192
+ # (batch size, head, length, model_dim)
193
+
194
+ # normalize model dim
195
+ data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
196
+
197
+ # what is ratio? projection_matrix.shape[0] --> 266
198
+
199
+ ratio = projection_matrix.shape[0] ** -0.5
200
+
201
+ projection = repeat(projection_matrix, "j d -> b h j d", b=b, h=h)
202
+ projection = projection.type_as(data)
203
+
204
+ # data_dash = w^T x
205
+ data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), projection)
206
+
207
+ # diag_data = D**2
208
+ diag_data = data**2
209
+ diag_data = torch.sum(diag_data, dim=-1)
210
+ diag_data = (diag_data / 2.0) * (data_normalizer**2)
211
+ diag_data = diag_data.unsqueeze(dim=-1)
212
+
213
+ if is_query:
214
+ data_dash = ratio * (
215
+ torch.exp(
216
+ data_dash
217
+ - diag_data
218
+ - torch.max(data_dash, dim=-1, keepdim=True).values
219
+ )
220
+ + eps
221
+ )
222
+ else:
223
+ data_dash = ratio * (
224
+ torch.exp(data_dash - diag_data + eps)
225
+ ) # - torch.max(data_dash)) + eps)
226
+
227
+ return data_dash.type_as(data)
228
+
229
+
230
+ def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
231
+ unstructured_block = torch.randn((cols, cols), device=device)
232
+ q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
233
+ q, r = map(lambda t: t.to(device), (q, r))
234
+
235
+ # proposed by @Parskatt
236
+ # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf
237
+ if qr_uniform_q:
238
+ d = torch.diag(r, 0)
239
+ q *= d.sign()
240
+ return q.t()
241
+
242
+
243
+ def exists(val):
244
+ return val is not None
245
+
246
+
247
+ def empty(tensor):
248
+ return tensor.numel() == 0
249
+
250
+
251
+ def default(val, d):
252
+ return val if exists(val) else d
253
+
254
+
255
+ def cast_tuple(val):
256
+ return (val,) if not isinstance(val, tuple) else val
257
+
258
+
259
+ class PCmer(nn.Module):
260
+ """The encoder that is used in the Transformer model."""
261
+
262
+ def __init__(
263
+ self,
264
+ num_layers,
265
+ num_heads,
266
+ dim_model,
267
+ dim_keys,
268
+ dim_values,
269
+ residual_dropout,
270
+ attention_dropout,
271
+ ):
272
+ super().__init__()
273
+ self.num_layers = num_layers
274
+ self.num_heads = num_heads
275
+ self.dim_model = dim_model
276
+ self.dim_values = dim_values
277
+ self.dim_keys = dim_keys
278
+ self.residual_dropout = residual_dropout
279
+ self.attention_dropout = attention_dropout
280
+
281
+ self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
282
+
283
+ # METHODS ########################################################################################################
284
+
285
+ def forward(self, phone, mask=None):
286
+
287
+ # apply all layers to the input
288
+ for i, layer in enumerate(self._layers):
289
+ phone = layer(phone, mask)
290
+ # provide the final sequence
291
+ return phone
292
+
293
+
294
+ # ==================================================================================================================== #
295
+ # CLASS _ E N C O D E R L A Y E R #
296
+ # ==================================================================================================================== #
297
+
298
+
299
+ class _EncoderLayer(nn.Module):
300
+ """One layer of the encoder.
301
+
302
+ Attributes:
303
+ attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence.
304
+ feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism.
305
+ """
306
+
307
+ def __init__(self, parent: PCmer):
308
+ """Creates a new instance of ``_EncoderLayer``.
309
+
310
+ Args:
311
+ parent (Encoder): The encoder that the layer is created for.
312
+ """
313
+ super().__init__()
314
+
315
+ self.conformer = ConformerConvModule(parent.dim_model)
316
+ self.norm = nn.LayerNorm(parent.dim_model)
317
+ self.dropout = nn.Dropout(parent.residual_dropout)
318
+
319
+ # selfatt -> fastatt: performer!
320
+ self.attn = SelfAttention(
321
+ dim=parent.dim_model, heads=parent.num_heads, causal=False
322
+ )
323
+
324
+ # METHODS ########################################################################################################
325
+
326
+ def forward(self, phone, mask=None):
327
+
328
+ # compute attention sub-layer
329
+ phone = phone + (self.attn(self.norm(phone), mask=mask))
330
+
331
+ phone = phone + (self.conformer(phone))
332
+
333
+ return phone
334
+
335
+
336
+ def calc_same_padding(kernel_size):
337
+ pad = kernel_size // 2
338
+ return (pad, pad - (kernel_size + 1) % 2)
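+ # e.g. kernel_size=31 -> (15, 15); kernel_size=4 -> (2, 1), i.e. "same" padding for odd and even kernels.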
339
+
340
+
341
+ # helper classes
342
+
343
+
344
+ class Swish(nn.Module):
345
+ def forward(self, x):
346
+ return x * x.sigmoid()
347
+
348
+
349
+ class Transpose(nn.Module):
350
+ def __init__(self, dims):
351
+ super().__init__()
352
+ assert len(dims) == 2, "dims must be a tuple of two dimensions"
353
+ self.dims = dims
354
+
355
+ def forward(self, x):
356
+ return x.transpose(*self.dims)
357
+
358
+
359
+ class GLU(nn.Module):
360
+ def __init__(self, dim):
361
+ super().__init__()
362
+ self.dim = dim
363
+
364
+ def forward(self, x):
365
+ out, gate = x.chunk(2, dim=self.dim)
366
+ return out * gate.sigmoid()
367
+
368
+
369
+ class DepthWiseConv1d(nn.Module):
370
+ def __init__(self, chan_in, chan_out, kernel_size, padding):
371
+ super().__init__()
372
+ self.padding = padding
373
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
374
+
375
+ def forward(self, x):
376
+ x = F.pad(x, self.padding)
377
+ return self.conv(x)
378
+
379
+
380
+ class ConformerConvModule(nn.Module):
381
+ def __init__(
382
+ self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
383
+ ):
384
+ super().__init__()
385
+
386
+ inner_dim = dim * expansion_factor
387
+ padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
388
+
389
+ self.net = nn.Sequential(
390
+ nn.LayerNorm(dim),
391
+ Transpose((1, 2)),
392
+ nn.Conv1d(dim, inner_dim * 2, 1),
393
+ GLU(dim=1),
394
+ DepthWiseConv1d(
395
+ inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
396
+ ),
397
+ # nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
398
+ Swish(),
399
+ nn.Conv1d(inner_dim, dim, 1),
400
+ Transpose((1, 2)),
401
+ nn.Dropout(dropout),
402
+ )
403
+
404
+ def forward(self, x):
405
+ return self.net(x)
406
+
407
+
408
+ def linear_attention(q, k, v):
409
+ if v is None:
410
+ out = torch.einsum("...ed,...nd->...ne", k, q)
411
+ return out
412
+
413
+ else:
414
+ k_cumsum = k.sum(dim=-2)
415
+ # k_cumsum = k.sum(dim = -2)
416
+ D_inv = 1.0 / (torch.einsum("...nd,...d->...n", q, k_cumsum.type_as(q)) + 1e-8)
417
+
418
+ context = torch.einsum("...nd,...ne->...de", k, v)
419
+ out = torch.einsum("...de,...nd,...n->...ne", context, q, D_inv)
420
+ return out
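+ # Linear (softmax-free) attention: instead of softmax(q k^T) v, which is O(n^2) in sequence length,
+ # the kernelized features allow computing context = k^T v first and then q @ context, which is O(n);
+ # D_inv normalizes each query row by sum_j q_i . k_j.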
421
+
422
+
423
+ def gaussian_orthogonal_random_matrix(
424
+ nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None
425
+ ):
426
+ nb_full_blocks = int(nb_rows / nb_columns)
427
+ block_list = []
428
+
429
+ for _ in range(nb_full_blocks):
430
+ q = orthogonal_matrix_chunk(
431
+ nb_columns, qr_uniform_q=qr_uniform_q, device=device
432
+ )
433
+ block_list.append(q)
434
+
435
+ remaining_rows = nb_rows - nb_full_blocks * nb_columns
436
+ if remaining_rows > 0:
437
+ q = orthogonal_matrix_chunk(
438
+ nb_columns, qr_uniform_q=qr_uniform_q, device=device
439
+ )
440
+
441
+ block_list.append(q[:remaining_rows])
442
+
443
+ final_matrix = torch.cat(block_list)
444
+
445
+ if scaling == 0:
446
+ multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
447
+ elif scaling == 1:
448
+ multiplier = math.sqrt((float(nb_columns))) * torch.ones(
449
+ (nb_rows,), device=device
450
+ )
451
+ else:
452
+ raise ValueError(f"Invalid scaling {scaling}")
453
+
454
+ return torch.diag(multiplier) @ final_matrix
455
+
456
+
457
+ class FastAttention(nn.Module):
458
+ def __init__(
459
+ self,
460
+ dim_heads,
461
+ nb_features=None,
462
+ ortho_scaling=0,
463
+ causal=False,
464
+ generalized_attention=False,
465
+ kernel_fn=nn.ReLU(),
466
+ qr_uniform_q=False,
467
+ no_projection=False,
468
+ ):
469
+ super().__init__()
470
+ nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
471
+
472
+ self.dim_heads = dim_heads
473
+ self.nb_features = nb_features
474
+ self.ortho_scaling = ortho_scaling
475
+
476
+ self.create_projection = partial(
477
+ gaussian_orthogonal_random_matrix,
478
+ nb_rows=self.nb_features,
479
+ nb_columns=dim_heads,
480
+ scaling=ortho_scaling,
481
+ qr_uniform_q=qr_uniform_q,
482
+ )
483
+ projection_matrix = self.create_projection()
484
+ self.register_buffer("projection_matrix", projection_matrix)
485
+
486
+ self.generalized_attention = generalized_attention
487
+ self.kernel_fn = kernel_fn
488
+
489
+ # if this is turned on, no projection will be used
490
+ # queries and keys will be softmax-ed as in the original efficient attention paper
491
+ self.no_projection = no_projection
492
+
493
+ self.causal = causal
494
+
495
+ @torch.no_grad()
496
+ def redraw_projection_matrix(self):
497
+ projections = self.create_projection()
498
+ self.projection_matrix.copy_(projections)
499
+ del projections
500
+
501
+ def forward(self, q, k, v):
502
+ device = q.device
503
+
504
+ if self.no_projection:
505
+ q = q.softmax(dim=-1)
506
+ k = torch.exp(k) if self.causal else k.softmax(dim=-2)
507
+ else:
508
+ create_kernel = partial(
509
+ softmax_kernel, projection_matrix=self.projection_matrix, device=device
510
+ )
511
+
512
+ q = create_kernel(q, is_query=True)
513
+ k = create_kernel(k, is_query=False)
514
+
515
+ attn_fn = linear_attention if not self.causal else self.causal_linear_fn
516
+ if v is None:
517
+ out = attn_fn(q, k, None)
518
+ return out
519
+ else:
520
+ out = attn_fn(q, k, v)
521
+ return out
522
+
523
+
524
+ class SelfAttention(nn.Module):
525
+ def __init__(
526
+ self,
527
+ dim,
528
+ causal=False,
529
+ heads=8,
530
+ dim_head=64,
531
+ local_heads=0,
532
+ local_window_size=256,
533
+ nb_features=None,
534
+ feature_redraw_interval=1000,
535
+ generalized_attention=False,
536
+ kernel_fn=nn.ReLU(),
537
+ qr_uniform_q=False,
538
+ dropout=0.0,
539
+ no_projection=False,
540
+ ):
541
+ super().__init__()
542
+ assert dim % heads == 0, "dimension must be divisible by number of heads"
543
+ dim_head = default(dim_head, dim // heads)
544
+ inner_dim = dim_head * heads
545
+ self.fast_attention = FastAttention(
546
+ dim_head,
547
+ nb_features,
548
+ causal=causal,
549
+ generalized_attention=generalized_attention,
550
+ kernel_fn=kernel_fn,
551
+ qr_uniform_q=qr_uniform_q,
552
+ no_projection=no_projection,
553
+ )
554
+
555
+ self.heads = heads
556
+ self.global_heads = heads - local_heads
557
+ self.local_attn = (
558
+ LocalAttention(
559
+ window_size=local_window_size,
560
+ causal=causal,
561
+ autopad=True,
562
+ dropout=dropout,
563
+ look_forward=int(not causal),
564
+ rel_pos_emb_config=(dim_head, local_heads),
565
+ )
566
+ if local_heads > 0
567
+ else None
568
+ )
569
+
570
+ self.to_q = nn.Linear(dim, inner_dim)
571
+ self.to_k = nn.Linear(dim, inner_dim)
572
+ self.to_v = nn.Linear(dim, inner_dim)
573
+ self.to_out = nn.Linear(inner_dim, dim)
574
+ self.dropout = nn.Dropout(dropout)
575
+
576
+ @torch.no_grad()
577
+ def redraw_projection_matrix(self):
578
+ self.fast_attention.redraw_projection_matrix()
579
+
580
+ def forward(
581
+ self,
582
+ x,
583
+ context=None,
584
+ mask=None,
585
+ context_mask=None,
586
+ name=None,
587
+ inference=False,
588
+ **kwargs,
589
+ ):
590
+ _, _, _, h, gh = *x.shape, self.heads, self.global_heads
591
+
592
+ cross_attend = exists(context)
593
+
594
+ context = default(context, x)
595
+ context_mask = default(context_mask, mask) if not cross_attend else context_mask
596
+ q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
597
+
598
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
599
+ (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
600
+
601
+ attn_outs = []
602
+ if not empty(q):
603
+ if exists(context_mask):
604
+ global_mask = context_mask[:, None, :, None]
605
+ v.masked_fill_(~global_mask, 0.0)
606
+ if cross_attend:
607
+ pass
608
+ else:
609
+ out = self.fast_attention(q, k, v)
610
+ attn_outs.append(out)
611
+
612
+ if not empty(lq):
613
+ assert (
614
+ not cross_attend
615
+ ), "local attention is not compatible with cross attention"
616
+ out = self.local_attn(lq, lk, lv, input_mask=mask)
617
+ attn_outs.append(out)
618
+
619
+ out = torch.cat(attn_outs, dim=1)
620
+ out = rearrange(out, "b h n d -> b n (h d)")
621
+ out = self.to_out(out)
622
+ return self.dropout(out)
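+ # Usage sketch (hypothetical sizes): SelfAttention(dim=512, heads=8, dim_head=64) applied to a
+ # (batch, n_frames, 512) tensor returns a tensor of the same shape; all heads are global
+ # (Performer) heads unless local_heads > 0.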
623
+
624
+
625
+ def l2_regularization(model, l2_alpha):
626
+ l2_loss = []
627
+ for module in model.modules():
628
+ if type(module) is nn.Conv2d:
629
+ l2_loss.append((module.weight**2).sum() / 2.0)
630
+ return l2_alpha * sum(l2_loss)
631
+
632
+
633
+ class FCPE(nn.Module):
634
+ def __init__(
635
+ self,
636
+ input_channel=128,
637
+ out_dims=360,
638
+ n_layers=12,
639
+ n_chans=512,
640
+ use_siren=False,
641
+ use_full=False,
642
+ loss_mse_scale=10,
643
+ loss_l2_regularization=False,
644
+ loss_l2_regularization_scale=1,
645
+ loss_grad1_mse=False,
646
+ loss_grad1_mse_scale=1,
647
+ f0_max=1975.5,
648
+ f0_min=32.70,
649
+ confidence=False,
650
+ threshold=0.05,
651
+ use_input_conv=True,
652
+ ):
653
+ super().__init__()
654
+ if use_siren is True:
655
+ raise ValueError("Siren is not supported yet.")
656
+ if use_full is True:
657
+ raise ValueError("Full model is not supported yet.")
658
+
659
+ self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
660
+ self.loss_l2_regularization = (
661
+ loss_l2_regularization if (loss_l2_regularization is not None) else False
662
+ )
663
+ self.loss_l2_regularization_scale = (
664
+ loss_l2_regularization_scale
665
+ if (loss_l2_regularization_scale is not None)
666
+ else 1
667
+ )
668
+ self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
669
+ self.loss_grad1_mse_scale = (
670
+ loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1
671
+ )
672
+ self.f0_max = f0_max if (f0_max is not None) else 1975.5
673
+ self.f0_min = f0_min if (f0_min is not None) else 32.70
674
+ self.confidence = confidence if (confidence is not None) else False
675
+ self.threshold = threshold if (threshold is not None) else 0.05
676
+ self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
677
+
678
+ self.cent_table_b = torch.Tensor(
679
+ np.linspace(
680
+ self.f0_to_cent(torch.Tensor([f0_min]))[0],
681
+ self.f0_to_cent(torch.Tensor([f0_max]))[0],
682
+ out_dims,
683
+ )
684
+ )
685
+ self.register_buffer("cent_table", self.cent_table_b)
686
+
687
+ # conv in stack
688
+ _leaky = nn.LeakyReLU()
689
+ self.stack = nn.Sequential(
690
+ nn.Conv1d(input_channel, n_chans, 3, 1, 1),
691
+ nn.GroupNorm(4, n_chans),
692
+ _leaky,
693
+ nn.Conv1d(n_chans, n_chans, 3, 1, 1),
694
+ )
695
+
696
+ # transformer
697
+ self.decoder = PCmer(
698
+ num_layers=n_layers,
699
+ num_heads=8,
700
+ dim_model=n_chans,
701
+ dim_keys=n_chans,
702
+ dim_values=n_chans,
703
+ residual_dropout=0.1,
704
+ attention_dropout=0.1,
705
+ )
706
+ self.norm = nn.LayerNorm(n_chans)
707
+
708
+ # out
709
+ self.n_out = out_dims
710
+ self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
711
+
712
+ def forward(
713
+ self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax"
714
+ ):
715
+ """
716
+ input:
717
+ B x n_frames x n_unit
718
+ return:
719
+ f0 tensor of shape B x n_frames x 1 when infer=True (Hz if return_hz_f0, otherwise a log-compressed scale); the scalar training loss when infer=False
720
+ """
721
+ if cdecoder == "argmax":
722
+ self.cdecoder = self.cents_decoder
723
+ elif cdecoder == "local_argmax":
724
+ self.cdecoder = self.cents_local_decoder
725
+ if self.use_input_conv:
726
+ x = self.stack(mel.transpose(1, 2)).transpose(1, 2)
727
+ else:
728
+ x = mel
729
+ x = self.decoder(x)
730
+ x = self.norm(x)
731
+ x = self.dense_out(x) # [B,N,D]
732
+ x = torch.sigmoid(x)
733
+ if not infer:
734
+ gt_cent_f0 = self.f0_to_cent(gt_f0) # mel f0 #[B,N,1]
735
+ gt_cent_f0 = self.gaussian_blurred_cent(gt_cent_f0) # #[B,N,out_dim]
736
+ loss_all = self.loss_mse_scale * F.binary_cross_entropy(
737
+ x, gt_cent_f0
738
+ ) # bce loss
739
+ # l2 regularization
740
+ if self.loss_l2_regularization:
741
+ loss_all = loss_all + l2_regularization(
742
+ model=self, l2_alpha=self.loss_l2_regularization_scale
743
+ )
744
+ x = loss_all
745
+ if infer:
746
+ x = self.cdecoder(x)
747
+ x = self.cent_to_f0(x)
748
+ if not return_hz_f0:
749
+ x = (1 + x / 700).log()
750
+ return x
751
+
752
+ def cents_decoder(self, y, mask=True):
753
+ B, N, _ = y.size()
754
+ ci = self.cent_table[None, None, :].expand(B, N, -1)
755
+ rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(
756
+ y, dim=-1, keepdim=True
757
+ ) # cents: [B,N,1]
758
+ if mask:
759
+ confident = torch.max(y, dim=-1, keepdim=True)[0]
760
+ confident_mask = torch.ones_like(confident)
761
+ confident_mask[confident <= self.threshold] = float("-INF")
762
+ rtn = rtn * confident_mask
763
+ if self.confidence:
764
+ return rtn, confident
765
+ else:
766
+ return rtn
767
+
768
+ def cents_local_decoder(self, y, mask=True):
769
+ B, N, _ = y.size()
770
+ ci = self.cent_table[None, None, :].expand(B, N, -1)
771
+ confident, max_index = torch.max(y, dim=-1, keepdim=True)
772
+ local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
773
+ local_argmax_index[local_argmax_index < 0] = 0
774
+ local_argmax_index[local_argmax_index >= self.n_out] = self.n_out - 1
775
+ ci_l = torch.gather(ci, -1, local_argmax_index)
776
+ y_l = torch.gather(y, -1, local_argmax_index)
777
+ rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(
778
+ y_l, dim=-1, keepdim=True
779
+ ) # cents: [B,N,1]
780
+ if mask:
781
+ confident_mask = torch.ones_like(confident)
782
+ confident_mask[confident <= self.threshold] = float("-INF")
783
+ rtn = rtn * confident_mask
784
+ if self.confidence:
785
+ return rtn, confident
786
+ else:
787
+ return rtn
788
+
789
+ def cent_to_f0(self, cent):
790
+ return 10.0 * 2 ** (cent / 1200.0)
791
+
792
+ def f0_to_cent(self, f0):
793
+ return 1200.0 * torch.log2(f0 / 10.0)
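+ # The cent scale here is defined relative to 10 Hz: cent = 1200 * log2(f0 / 10), so e.g.
+ # f0 = 440 Hz maps to 1200 * log2(44) ≈ 6551 cents, and cent_to_f0 is the exact inverse.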
794
+
795
+ def gaussian_blurred_cent(self, cents): # cents: [B,N,1]
796
+ mask = (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))
797
+ B, N, _ = cents.size()
798
+ ci = self.cent_table[None, None, :].expand(B, N, -1)
799
+ return torch.exp(-torch.square(ci - cents) / 1250) * mask.float()
800
+
801
+
802
+ class FCPEInfer:
803
+ def __init__(self, model_path, device=None, dtype=torch.float32):
804
+ if device is None:
805
+ device = "cuda" if torch.cuda.is_available() else "cpu"
806
+ self.device = device
807
+ ckpt = torch.load(model_path, map_location=torch.device(self.device))
808
+ self.args = DotDict(ckpt["config"])
809
+ self.dtype = dtype
810
+ model = FCPE(
811
+ input_channel=self.args.model.input_channel,
812
+ out_dims=self.args.model.out_dims,
813
+ n_layers=self.args.model.n_layers,
814
+ n_chans=self.args.model.n_chans,
815
+ use_siren=self.args.model.use_siren,
816
+ use_full=self.args.model.use_full,
817
+ loss_mse_scale=self.args.loss.loss_mse_scale,
818
+ loss_l2_regularization=self.args.loss.loss_l2_regularization,
819
+ loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale,
820
+ loss_grad1_mse=self.args.loss.loss_grad1_mse,
821
+ loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale,
822
+ f0_max=self.args.model.f0_max,
823
+ f0_min=self.args.model.f0_min,
824
+ confidence=self.args.model.confidence,
825
+ )
826
+ model.to(self.device).to(self.dtype)
827
+ model.load_state_dict(ckpt["model"])
828
+ model.eval()
829
+ self.model = model
830
+ self.wav2mel = Wav2Mel(self.args, dtype=self.dtype, device=self.device)
831
+
832
+ @torch.no_grad()
833
+ def __call__(self, audio, sr, threshold=0.05):
834
+ self.model.threshold = threshold
835
+ audio = audio[None, :]
836
+ mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype)
837
+ f0 = self.model(mel=mel, infer=True, return_hz_f0=True)
838
+ return f0
839
+
840
+
841
+ class Wav2Mel:
842
+
843
+ def __init__(self, args, device=None, dtype=torch.float32):
844
+ # self.args = args
845
+ self.sampling_rate = args.mel.sampling_rate
846
+ self.hop_size = args.mel.hop_size
847
+ if device is None:
848
+ device = "cuda" if torch.cuda.is_available() else "cpu"
849
+ self.device = device
850
+ self.dtype = dtype
851
+ self.stft = STFT(
852
+ args.mel.sampling_rate,
853
+ args.mel.num_mels,
854
+ args.mel.n_fft,
855
+ args.mel.win_size,
856
+ args.mel.hop_size,
857
+ args.mel.fmin,
858
+ args.mel.fmax,
859
+ )
860
+ self.resample_kernel = {}
861
+
862
+ def extract_nvstft(self, audio, keyshift=0, train=False):
863
+ mel = self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(
864
+ 1, 2
865
+ ) # B, n_frames, bins
866
+ return mel
867
+
868
+ def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
869
+ audio = audio.to(self.dtype).to(self.device)
870
+ # resample
871
+ if sample_rate == self.sampling_rate:
872
+ audio_res = audio
873
+ else:
874
+ key_str = str(sample_rate)
875
+ if key_str not in self.resample_kernel:
876
+ self.resample_kernel[key_str] = Resample(
877
+ sample_rate, self.sampling_rate, lowpass_filter_width=128
878
+ )
879
+ self.resample_kernel[key_str] = (
880
+ self.resample_kernel[key_str].to(self.dtype).to(self.device)
881
+ )
882
+ audio_res = self.resample_kernel[key_str](audio)
883
+
884
+ # extract
885
+ mel = self.extract_nvstft(
886
+ audio_res, keyshift=keyshift, train=train
887
+ ) # B, n_frames, bins
888
+ n_frames = int(audio.shape[1] // self.hop_size) + 1
889
+ if n_frames > int(mel.shape[1]):
890
+ mel = torch.cat((mel, mel[:, -1:, :]), 1)
891
+ if n_frames < int(mel.shape[1]):
892
+ mel = mel[:, :n_frames, :]
893
+ return mel
894
+
895
+ def __call__(self, audio, sample_rate, keyshift=0, train=False):
896
+ return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
897
+
898
+
899
+ class DotDict(dict):
900
+ def __getattr__(*args):
901
+ val = dict.get(*args)
902
+ return DotDict(val) if type(val) is dict else val
903
+
904
+ __setattr__ = dict.__setitem__
905
+ __delattr__ = dict.__delitem__
906
+
907
+
908
+ class F0Predictor(object):
909
+ def compute_f0(self, wav, p_len):
910
+ """
911
+ input: wav:[signal_length]
912
+ p_len:int
913
+ output: f0:[signal_length//hop_length]
914
+ """
915
+ pass
916
+
917
+ def compute_f0_uv(self, wav, p_len):
918
+ """
919
+ input: wav:[signal_length]
920
+ p_len:int
921
+ output: f0:[signal_length//hop_length],uv:[signal_length//hop_length]
922
+ """
923
+ pass
924
+
925
+
926
+ class FCPEF0Predictor(F0Predictor):
927
+ def __init__(
928
+ self,
929
+ model_path,
930
+ hop_length=512,
931
+ f0_min=50,
932
+ f0_max=1100,
933
+ dtype=torch.float32,
934
+ device=None,
935
+ sampling_rate=44100,
936
+ threshold=0.05,
937
+ ):
938
+ self.fcpe = FCPEInfer(model_path, device=device, dtype=dtype)
939
+ self.hop_length = hop_length
940
+ self.f0_min = f0_min
941
+ self.f0_max = f0_max
942
+ if device is None:
943
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
944
+ else:
945
+ self.device = device
946
+ self.threshold = threshold
947
+ self.sampling_rate = sampling_rate
948
+ self.dtype = dtype
949
+ self.name = "fcpe"
950
+
951
+ def repeat_expand(
952
+ self,
953
+ content: Union[torch.Tensor, np.ndarray],
954
+ target_len: int,
955
+ mode: str = "nearest",
956
+ ):
957
+ ndim = content.ndim
958
+
959
+ if content.ndim == 1:
960
+ content = content[None, None]
961
+ elif content.ndim == 2:
962
+ content = content[None]
963
+
964
+ assert content.ndim == 3
965
+
966
+ is_np = isinstance(content, np.ndarray)
967
+ if is_np:
968
+ content = torch.from_numpy(content)
969
+
970
+ results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
971
+
972
+ if is_np:
973
+ results = results.numpy()
974
+
975
+ if ndim == 1:
976
+ return results[0, 0]
977
+ elif ndim == 2:
978
+ return results[0]
979
+
980
+ def post_process(self, x, sampling_rate, f0, pad_to):
981
+ if isinstance(f0, np.ndarray):
982
+ f0 = torch.from_numpy(f0).float().to(x.device)
983
+
984
+ if pad_to is None:
985
+ return f0
986
+
987
+ f0 = self.repeat_expand(f0, pad_to)
988
+
989
+ vuv_vector = torch.zeros_like(f0)
990
+ vuv_vector[f0 > 0.0] = 1.0
991
+ vuv_vector[f0 <= 0.0] = 0.0
992
+
993
+ # drop zero-frequency frames, then interpolate linearly over the gaps
994
+ nzindex = torch.nonzero(f0).squeeze()
995
+ f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
996
+ time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
997
+ time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
998
+
999
+ vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]
1000
+
1001
+ if f0.shape[0] <= 0:
1002
+ return (
1003
+ torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(),
1004
+ vuv_vector.cpu().numpy(),
1005
+ )
1006
+ if f0.shape[0] == 1:
1007
+ return (
1008
+ torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0]
1009
+ ).cpu().numpy(), vuv_vector.cpu().numpy()
1010
+
1011
+ # this could probably be rewritten with torch
1012
+ f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
1013
+ # vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
1014
+
1015
+ return f0, vuv_vector.cpu().numpy()
1016
+
1017
+ def compute_f0(self, wav, p_len=None):
1018
+ x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
1019
+ if p_len is None:
1020
+ print("fcpe p_len is None")
1021
+ p_len = x.shape[0] // self.hop_length
1022
+ f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0, :, 0]
1023
+ if torch.all(f0 == 0):
1024
+ rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
1025
+ return rtn, rtn
1026
+ return self.post_process(x, self.sampling_rate, f0, p_len)[0]
1027
+
1028
+ def compute_f0_uv(self, wav, p_len=None):
1029
+ x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
1030
+ if p_len is None:
1031
+ p_len = x.shape[0] // self.hop_length
1032
+ f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0, :, 0]
1033
+ if torch.all(f0 == 0):
1034
+ rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
1035
+ return rtn, rtn
1036
+ return self.post_process(x, self.sampling_rate, f0, p_len)
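+ # Usage sketch (hypothetical path; the fcpe.pt checkpoint is downloaded separately):
+ #   predictor = FCPEF0Predictor("rvc_models/fcpe.pt", hop_length=160, sampling_rate=16000, device="cuda")
+ #   f0, uv = predictor.compute_f0_uv(wav_16k, p_len=wav_16k.shape[0] // 160)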
src/infer_pack/predictor/RMVPE.py ADDED
@@ -0,0 +1,399 @@
1
+ import torch.nn as nn
2
+ import torch, numpy as np
3
+ import torch.nn.functional as F
4
+ from librosa.filters import mel
5
+
6
+
7
+ class BiGRU(nn.Module):
8
+ def __init__(self, input_features, hidden_features, num_layers):
9
+ super(BiGRU, self).__init__()
10
+ self.gru = nn.GRU(
11
+ input_features,
12
+ hidden_features,
13
+ num_layers=num_layers,
14
+ batch_first=True,
15
+ bidirectional=True,
16
+ )
17
+
18
+ def forward(self, x):
19
+ return self.gru(x)[0]
20
+
21
+
22
+ class ConvBlockRes(nn.Module):
23
+ def __init__(self, in_channels, out_channels, momentum=0.01):
24
+ super(ConvBlockRes, self).__init__()
25
+ self.conv = nn.Sequential(
26
+ nn.Conv2d(
27
+ in_channels=in_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=(3, 3),
30
+ stride=(1, 1),
31
+ padding=(1, 1),
32
+ bias=False,
33
+ ),
34
+ nn.BatchNorm2d(out_channels, momentum=momentum),
35
+ nn.ReLU(),
36
+ nn.Conv2d(
37
+ in_channels=out_channels,
38
+ out_channels=out_channels,
39
+ kernel_size=(3, 3),
40
+ stride=(1, 1),
41
+ padding=(1, 1),
42
+ bias=False,
43
+ ),
44
+ nn.BatchNorm2d(out_channels, momentum=momentum),
45
+ nn.ReLU(),
46
+ )
47
+ if in_channels != out_channels:
48
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
49
+ self.is_shortcut = True
50
+ else:
51
+ self.is_shortcut = False
52
+
53
+ def forward(self, x):
54
+ if self.is_shortcut:
55
+ return self.conv(x) + self.shortcut(x)
56
+ else:
57
+ return self.conv(x) + x
58
+
59
+
60
+ class Encoder(nn.Module):
61
+ def __init__(
62
+ self,
63
+ in_channels,
64
+ in_size,
65
+ n_encoders,
66
+ kernel_size,
67
+ n_blocks,
68
+ out_channels=16,
69
+ momentum=0.01,
70
+ ):
71
+ super(Encoder, self).__init__()
72
+ self.n_encoders = n_encoders
73
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
74
+ self.layers = nn.ModuleList()
75
+ self.latent_channels = []
76
+ for i in range(self.n_encoders):
77
+ self.layers.append(
78
+ ResEncoderBlock(
79
+ in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
80
+ )
81
+ )
82
+ self.latent_channels.append([out_channels, in_size])
83
+ in_channels = out_channels
84
+ out_channels *= 2
85
+ in_size //= 2
86
+ self.out_size = in_size
87
+ self.out_channel = out_channels
88
+
89
+ def forward(self, x):
90
+ concat_tensors = []
91
+ x = self.bn(x)
92
+ for i in range(self.n_encoders):
93
+ _, x = self.layers[i](x)
94
+ concat_tensors.append(_)
95
+ return x, concat_tensors
96
+
97
+
98
+ class ResEncoderBlock(nn.Module):
99
+ def __init__(
100
+ self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
101
+ ):
102
+ super(ResEncoderBlock, self).__init__()
103
+ self.n_blocks = n_blocks
104
+ self.conv = nn.ModuleList()
105
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
106
+ for i in range(n_blocks - 1):
107
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
108
+ self.kernel_size = kernel_size
109
+ if self.kernel_size is not None:
110
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
111
+
112
+ def forward(self, x):
113
+ for i in range(self.n_blocks):
114
+ x = self.conv[i](x)
115
+ if self.kernel_size is not None:
116
+ return x, self.pool(x)
117
+ else:
118
+ return x
119
+
120
+
121
+ class Intermediate(nn.Module): #
122
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
123
+ super(Intermediate, self).__init__()
124
+ self.n_inters = n_inters
125
+ self.layers = nn.ModuleList()
126
+ self.layers.append(
127
+ ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
128
+ )
129
+ for i in range(self.n_inters - 1):
130
+ self.layers.append(
131
+ ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
132
+ )
133
+
134
+ def forward(self, x):
135
+ for i in range(self.n_inters):
136
+ x = self.layers[i](x)
137
+ return x
138
+
139
+
140
+ class ResDecoderBlock(nn.Module):
141
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
142
+ super(ResDecoderBlock, self).__init__()
143
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
144
+ self.n_blocks = n_blocks
145
+ self.conv1 = nn.Sequential(
146
+ nn.ConvTranspose2d(
147
+ in_channels=in_channels,
148
+ out_channels=out_channels,
149
+ kernel_size=(3, 3),
150
+ stride=stride,
151
+ padding=(1, 1),
152
+ output_padding=out_padding,
153
+ bias=False,
154
+ ),
155
+ nn.BatchNorm2d(out_channels, momentum=momentum),
156
+ nn.ReLU(),
157
+ )
158
+ self.conv2 = nn.ModuleList()
159
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
160
+ for i in range(n_blocks - 1):
161
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
162
+
163
+ def forward(self, x, concat_tensor):
164
+ x = self.conv1(x)
165
+ x = torch.cat((x, concat_tensor), dim=1)
166
+ for i in range(self.n_blocks):
167
+ x = self.conv2[i](x)
168
+ return x
169
+
170
+
171
+ class Decoder(nn.Module):
172
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
173
+ super(Decoder, self).__init__()
174
+ self.layers = nn.ModuleList()
175
+ self.n_decoders = n_decoders
176
+ for i in range(self.n_decoders):
177
+ out_channels = in_channels // 2
178
+ self.layers.append(
179
+ ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
180
+ )
181
+ in_channels = out_channels
182
+
183
+ def forward(self, x, concat_tensors):
184
+ for i in range(self.n_decoders):
185
+ x = self.layers[i](x, concat_tensors[-1 - i])
186
+ return x
187
+
188
+
189
+ class DeepUnet(nn.Module):
190
+ def __init__(
191
+ self,
192
+ kernel_size,
193
+ n_blocks,
194
+ en_de_layers=5,
195
+ inter_layers=4,
196
+ in_channels=1,
197
+ en_out_channels=16,
198
+ ):
199
+ super(DeepUnet, self).__init__()
200
+ self.encoder = Encoder(
201
+ in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
202
+ )
203
+ self.intermediate = Intermediate(
204
+ self.encoder.out_channel // 2,
205
+ self.encoder.out_channel,
206
+ inter_layers,
207
+ n_blocks,
208
+ )
209
+ self.decoder = Decoder(
210
+ self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
211
+ )
212
+
213
+ def forward(self, x):
214
+ x, concat_tensors = self.encoder(x)
215
+ x = self.intermediate(x)
216
+ x = self.decoder(x, concat_tensors)
217
+ return x
218
+
219
+
220
+ class E2E(nn.Module):
221
+ def __init__(
222
+ self,
223
+ n_blocks,
224
+ n_gru,
225
+ kernel_size,
226
+ en_de_layers=5,
227
+ inter_layers=4,
228
+ in_channels=1,
229
+ en_out_channels=16,
230
+ ):
231
+ super(E2E, self).__init__()
232
+ self.unet = DeepUnet(
233
+ kernel_size,
234
+ n_blocks,
235
+ en_de_layers,
236
+ inter_layers,
237
+ in_channels,
238
+ en_out_channels,
239
+ )
240
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
241
+ if n_gru:
242
+ self.fc = nn.Sequential(
243
+ BiGRU(3 * 128, 256, n_gru),
244
+ nn.Linear(512, 360),
245
+ nn.Dropout(0.25),
246
+ nn.Sigmoid(),
247
+ )
248
+
249
+ def forward(self, mel):
250
+ mel = mel.transpose(-1, -2).unsqueeze(1)
251
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
252
+ x = self.fc(x)
253
+ return x
254
+
255
+
256
+ class MelSpectrogram(torch.nn.Module):
257
+ def __init__(
258
+ self,
259
+ is_half,
260
+ n_mel_channels,
261
+ sampling_rate,
262
+ win_length,
263
+ hop_length,
264
+ n_fft=None,
265
+ mel_fmin=0,
266
+ mel_fmax=None,
267
+ clamp=1e-5,
268
+ ):
269
+ super().__init__()
270
+ n_fft = win_length if n_fft is None else n_fft
271
+ self.hann_window = {}
272
+ mel_basis = mel(
273
+ sr=sampling_rate,
274
+ n_fft=n_fft,
275
+ n_mels=n_mel_channels,
276
+ fmin=mel_fmin,
277
+ fmax=mel_fmax,
278
+ htk=True,
279
+ )
280
+ mel_basis = torch.from_numpy(mel_basis).float()
281
+ self.register_buffer("mel_basis", mel_basis)
282
+ self.n_fft = win_length if n_fft is None else n_fft
283
+ self.hop_length = hop_length
284
+ self.win_length = win_length
285
+ self.sampling_rate = sampling_rate
286
+ self.n_mel_channels = n_mel_channels
287
+ self.clamp = clamp
288
+ self.is_half = is_half
289
+
290
+ def forward(self, audio, keyshift=0, speed=1, center=True):
291
+ factor = 2 ** (keyshift / 12)
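+ # keyshift is in semitones: a shift of k scales the analysis window by 2**(k/12),
+ # so +12 doubles n_fft / win_length (one octave) and 0 leaves them unchanged.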
292
+ n_fft_new = int(np.round(self.n_fft * factor))
293
+ win_length_new = int(np.round(self.win_length * factor))
294
+ hop_length_new = int(np.round(self.hop_length * speed))
295
+ keyshift_key = str(keyshift) + "_" + str(audio.device)
296
+ if keyshift_key not in self.hann_window:
297
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
298
+ audio.device
299
+ )
300
+ fft = torch.stft(
301
+ audio,
302
+ n_fft=n_fft_new,
303
+ hop_length=hop_length_new,
304
+ win_length=win_length_new,
305
+ window=self.hann_window[keyshift_key],
306
+ center=center,
307
+ return_complex=True,
308
+ )
309
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
310
+ if keyshift != 0:
311
+ size = self.n_fft // 2 + 1
312
+ resize = magnitude.size(1)
313
+ if resize < size:
314
+ magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
315
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
316
+ mel_output = torch.matmul(self.mel_basis, magnitude)
317
+ if self.is_half:
318
+ mel_output = mel_output.half()
319
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
320
+ return log_mel_spec
321
+
322
+
323
+ class RMVPE:
324
+ def __init__(self, model_path, is_half, device=None):
325
+ self.resample_kernel = {}
326
+ model = E2E(4, 1, (2, 2))
327
+ ckpt = torch.load(model_path, map_location="cpu")
328
+ model.load_state_dict(ckpt)
329
+ model.eval()
330
+ if is_half:
331
+ model = model.half()
332
+ self.model = model
333
+ self.resample_kernel = {}
334
+ self.is_half = is_half
335
+ if device is None:
336
+ device = "cuda" if torch.cuda.is_available() else "cpu"
337
+ self.device = device
338
+ self.mel_extractor = MelSpectrogram(
339
+ is_half, 128, 16000, 1024, 160, None, 30, 8000
340
+ ).to(device)
341
+ self.model = self.model.to(device)
342
+ cents_mapping = 20 * np.arange(360) + 1997.3794084376191
343
+ self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
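+ # 360 pitch bins spaced 20 cents apart (relative to 10 Hz), covering roughly 32 Hz to 2 kHz;
+ # the (4, 4) padding allows the 9-bin local weighted average around the argmax in to_local_average_cents.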
344
+
345
+ def mel2hidden(self, mel):
346
+ with torch.no_grad():
347
+ n_frames = mel.shape[-1]
348
+ mel = F.pad(
349
+ mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
350
+ )
351
+ hidden = self.model(mel)
352
+ return hidden[:, :n_frames]
353
+
354
+ def decode(self, hidden, thred=0.03):
355
+ cents_pred = self.to_local_average_cents(hidden, thred=thred)
356
+ f0 = 10 * (2 ** (cents_pred / 1200))
357
+ f0[f0 == 10] = 0
358
+ return f0
359
+
360
+ def infer_from_audio(self, audio, thred=0.03):
361
+ audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
362
+ mel = self.mel_extractor(audio, center=True)
363
+ hidden = self.mel2hidden(mel)
364
+ hidden = hidden.squeeze(0).cpu().numpy()
365
+ if self.is_half:
366
+ hidden = hidden.astype("float32")
367
+ f0 = self.decode(hidden, thred=thred)
368
+ return f0
369
+
370
+ def to_local_average_cents(self, salience, thred=0.05):
371
+ center = np.argmax(salience, axis=1)
372
+ salience = np.pad(salience, ((0, 0), (4, 4)))
373
+ center += 4
374
+ todo_salience = []
375
+ todo_cents_mapping = []
376
+ starts = center - 4
377
+ ends = center + 5
378
+ for idx in range(salience.shape[0]):
379
+ todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
380
+ todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
381
+ todo_salience = np.array(todo_salience)
382
+ todo_cents_mapping = np.array(todo_cents_mapping)
383
+ product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
384
+ weight_sum = np.sum(todo_salience, 1)
385
+ divided = product_sum / weight_sum
386
+ maxx = np.max(salience, axis=1)
387
+ divided[maxx <= thred] = 0
388
+ return divided
389
+
390
+ def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
391
+ audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
392
+ mel = self.mel_extractor(audio, center=True)
393
+ hidden = self.mel2hidden(mel)
394
+ hidden = hidden.squeeze(0).cpu().numpy()
395
+ if self.is_half:
396
+ hidden = hidden.astype("float32")
397
+ f0 = self.decode(hidden, thred=thred)
398
+ f0[(f0 < f0_min) | (f0 > f0_max)] = 0
399
+ return f0
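+ # Usage sketch (hypothetical path; the rmvpe.pt checkpoint is downloaded separately):
+ #   rmvpe = RMVPE("rvc_models/rmvpe.pt", is_half=False)
+ #   f0 = rmvpe.infer_from_audio(audio_16k, thred=0.03)  # audio_16k: mono float32 numpy array at 16 kHz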
src/infer_pack/transforms.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
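+ # The transform above is the monotonic rational-quadratic spline flow of Durkan et al.,
+ # "Neural Spline Flows": with tails="linear" the spline acts inside [-tail_bound, tail_bound]
+ # and is the identity outside it, and logabsdet is the log |d output / d input| term used in the flow likelihood.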
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
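+ # e.g. bin_locations = [0.0, 0.3, 0.7, 1.0] and input 0.5 -> index 1 (the bin [0.3, 0.7) that contains it).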
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
src/main.py ADDED
@@ -0,0 +1,86 @@
1
+ import gc
2
+ import hashlib
3
+ import os
4
+ import shlex
5
+ import subprocess
6
+ import librosa
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import gradio as gr
10
+ from rvc import Config, load_hubert, get_vc, rvc_infer
11
+
12
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
+ rvc_models_dir = os.path.join(BASE_DIR, 'rvc_models')
14
+ output_dir = os.path.join(BASE_DIR, 'song_output')
15
+
16
+ def get_rvc_model(voice_model):
17
+ model_dir = os.path.join(rvc_models_dir, voice_model)
18
+ rvc_model_path = next((os.path.join(model_dir, f) for f in os.listdir(model_dir) if f.endswith('.pth')), None)
19
+ rvc_index_path = next((os.path.join(model_dir, f) for f in os.listdir(model_dir) if f.endswith('.index')), None)
20
+
21
+ if rvc_model_path is None:
22
+ error_msg = f'No model file was found in the {model_dir} directory.'
23
+ raise Exception(error_msg)
24
+
25
+ return rvc_model_path, rvc_index_path
26
+
27
+ def convert_to_stereo(audio_path):
28
+ wave, sr = librosa.load(audio_path, mono=False, sr=44100)
29
+ if not isinstance(wave[0], np.ndarray):
30
+ stereo_path = 'Voice_stereo.wav'
31
+ command = shlex.split(f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 2 -f wav "{stereo_path}"')
32
+ subprocess.run(command)
33
+ return stereo_path
34
+ else:
35
+ return audio_path
36
+
37
+ def get_hash(filepath):
38
+ with open(filepath, 'rb') as f:
39
+ file_hash = hashlib.blake2b()
40
+ while chunk := f.read(8192):
41
+ file_hash.update(chunk)
42
+
43
+ return file_hash.hexdigest()[:11]
44
+
45
+ def display_progress(percent, message, progress=gr.Progress()):
46
+ progress(percent, desc=message)
47
+
48
+ def voice_change(voice_model, vocals_path, output_path, pitch_change, f0_method, index_rate, filter_radius, rms_mix_rate, protect, crepe_hop_length):
49
+ rvc_model_path, rvc_index_path = get_rvc_model(voice_model)
50
+ device = 'cuda:0'
51
+ config = Config(device, True)
52
+ hubert_model = load_hubert(device, config.is_half, os.path.join(rvc_models_dir, 'hubert_base.pt'))
53
+ cpt, version, net_g, tgt_sr, vc = get_vc(device, config.is_half, config, rvc_model_path)
54
+
55
+ rvc_infer(rvc_index_path, index_rate, vocals_path, output_path, pitch_change, f0_method, cpt, version, net_g,
56
+ filter_radius, tgt_sr, rms_mix_rate, protect, crepe_hop_length, vc, hubert_model)
57
+ del hubert_model, cpt
58
+ gc.collect()
59
+
60
+ def song_cover_pipeline(uploaded_file, voice_model, pitch_change, index_rate=0.5, filter_radius=3, rms_mix_rate=0.25, f0_method='rmvpe',
61
+ crepe_hop_length=128, protect=0.33, output_format='mp3', progress=gr.Progress()):
62
+
63
+ if not uploaded_file or not voice_model:
64
+ raise Exception('Make sure the song input field and the voice model field are both filled in.')
65
+
66
+ display_progress(0, '[~] Starting the AI cover generation pipeline...', progress)
67
+
68
+ if not os.path.exists(uploaded_file):
69
+ error_msg = f'{uploaded_file} does not exist.'
70
+ raise Exception(error_msg)
71
+
72
+ song_id = get_hash(uploaded_file)
73
+ song_dir = os.path.join(output_dir, song_id)
74
+ os.makedirs(song_dir, exist_ok=True)
75
+
76
+ orig_song_path = convert_to_stereo(uploaded_file)
77
+ ai_cover_path = os.path.join(song_dir, f'Converted_Voice.{output_format}')
78
+
79
+ if os.path.exists(ai_cover_path):
80
+ os.remove(ai_cover_path)
81
+
82
+ display_progress(0.5, '[~] Converting vocals...', progress)
83
+ voice_change(voice_model, orig_song_path, ai_cover_path, pitch_change, f0_method, index_rate,
84
+ filter_radius, rms_mix_rate, protect, crepe_hop_length)
85
+
86
+ return ai_cover_path
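+ # Usage sketch (hypothetical file names; a voice model folder must exist under rvc_models/):
+ #   cover_path = song_cover_pipeline('vocals.wav', 'MyVoice', pitch_change=0, output_format='mp3')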
src/modules/file_processing.py ADDED
@@ -0,0 +1,4 @@
1
+ import gradio as gr
2
+
3
+ def process_file_upload(file):
4
+ return gr.update(value=file)
src/modules/model_management.py ADDED
@@ -0,0 +1,89 @@
1
+ import os
2
+ import shutil
3
+ import urllib.request
4
+ import zipfile
5
+ import gdown
6
+ import gradio as gr
7
+
8
+
9
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+ rvc_models_dir = os.path.join(BASE_DIR, 'rvc_models')
11
+
12
+
13
+ def ignore_files(models_dir):
14
+ models_list = os.listdir(models_dir)
15
+ items_to_remove = ['hubert_base.pt', 'MODELS.txt', 'rmvpe.pt', 'fcpe.pt']
16
+ return [item for item in models_list if item not in items_to_remove]
17
+
18
+
19
+ def update_models_list():
20
+ models_l = ignore_files(rvc_models_dir)
21
+ return gr.update(choices=models_l)
22
+
23
+
24
+ def extract_zip(extraction_folder, zip_name):
25
+ os.makedirs(extraction_folder)
26
+ with zipfile.ZipFile(zip_name, 'r') as zip_ref:
27
+ zip_ref.extractall(extraction_folder)
28
+ os.remove(zip_name)
29
+
30
+ index_filepath, model_filepath = None, None
31
+ for root, dirs, files in os.walk(extraction_folder):
32
+ for name in files:
33
+ if name.endswith('.index') and os.stat(os.path.join(root, name)).st_size > 1024 * 100:
34
+ index_filepath = os.path.join(root, name)
35
+ if name.endswith('.pth') and os.stat(os.path.join(root, name)).st_size > 1024 * 1024 * 40:
36
+ model_filepath = os.path.join(root, name)
37
+
38
+ if not model_filepath:
39
+ raise gr.Error(f'No .pth model file was found in the extracted zip. Please check {extraction_folder}.')
40
+
41
+ os.rename(model_filepath, os.path.join(extraction_folder, os.path.basename(model_filepath)))
42
+ if index_filepath:
43
+ os.rename(index_filepath, os.path.join(extraction_folder, os.path.basename(index_filepath)))
44
+
45
+ for filepath in os.listdir(extraction_folder):
46
+ if os.path.isdir(os.path.join(extraction_folder, filepath)):
47
+ shutil.rmtree(os.path.join(extraction_folder, filepath))
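+ # extract_zip expects the zip to contain one .pth voice model (> ~40 MB) and optionally one .index file
+ # (> ~100 KB); both are moved to the top level of extraction_folder and leftover subfolders are removed.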
48
+
49
+
50
+ def download_from_url(url, dir_name, progress=gr.Progress()):
51
+ try:
52
+ progress(0, desc=f'[~] Downloading the voice model named {dir_name}...')
53
+ zip_name = url.split('/')[-1]
54
+ extraction_folder = os.path.join(rvc_models_dir, dir_name)
55
+ if os.path.exists(extraction_folder):
56
+ raise gr.Error(f'The voice model directory {dir_name} already exists! Choose a different name for your voice model.')
57
+
58
+ if 'huggingface.co' in url:
59
+ urllib.request.urlretrieve(url, zip_name)
60
+ elif 'pixeldrain.com' in url:
61
+ zip_name = dir_name + '.zip'
62
+ url = f'https://pixeldrain.com/api/file/{zip_name}'
63
+ urllib.request.urlretrieve(url, zip_name)
64
+ elif 'drive.google.com' in url:
65
+ zip_name = dir_name + '.zip'
66
+ file_id = url.split('/')[-2]
67
+ output = os.path.join('.', f'{dir_name}.zip')
68
+ gdown.download(id=file_id, output=output, quiet=False)
69
+
70
+ progress(0.5, desc='[~] Extracting zip file...')
71
+ extract_zip(extraction_folder, zip_name)
72
+ return f'[+] Model {dir_name} downloaded successfully!'
73
+ except Exception as e:
74
+ raise gr.Error(str(e))
75
+
76
+
77
+ def upload_zip_model(zip_path, dir_name, progress=gr.Progress()):
78
+ try:
79
+ extraction_folder = os.path.join(rvc_models_dir, dir_name)
80
+ if os.path.exists(extraction_folder):
81
+ raise gr.Error(f'Директория голосовой модели {dir_name} уже существует! Выберите другое имя для вашей голосовой модели.')
82
+
83
+ zip_name = zip_path.name
84
+ progress(0.5, desc='[~] Extracting zip file...')
85
+ extract_zip(extraction_folder, zip_name)
86
+ return f'[+] Model {dir_name} uploaded successfully!'
87
+
88
+ except Exception as e:
89
+ raise gr.Error(str(e))
src/modules/ui_updates.py ADDED
@@ -0,0 +1,25 @@
1
+ import gradio as gr
2
+
3
+
4
+ def show_hop_slider(pitch_detection_algo):
5
+ if pitch_detection_algo in ['mangio-crepe']:
6
+ return gr.update(visible=True)
7
+ else:
8
+ return gr.update(visible=False)
9
+
10
+
11
+ def update_f0_method(use_hybrid_methods):
12
+ if use_hybrid_methods:
13
+ return gr.update(choices=['hybrid[rmvpe+fcpe]', 'hybrid[rmvpe+crepe]', 'hybrid[crepe+rmvpe]', 'hybrid[crepe+fcpe]', 'hybrid[crepe+rmvpe+fcpe]'], value='hybrid[rmvpe+fcpe]')
14
+ else:
15
+ return gr.update(choices=['rmvpe+', 'fcpe', 'rmvpe', 'mangio-crepe', 'crepe'], value='rmvpe+')
16
+
17
+
18
+ def update_button_text():
19
+ return gr.update(label="Загрузить другой аудио-файл")
20
+
21
+ def update_button_text_voc():
22
+ return gr.update(label="Загрузить другой вокал")
23
+
24
+ def update_button_text_inst():
25
+ return gr.update(label="Загрузить другой инструментал")
src/my_utils.py ADDED
@@ -0,0 +1,18 @@
1
+ import ffmpeg
2
+ import numpy as np
3
+
4
+
5
+ def load_audio(file, sr):
6
+ try:
7
+ file = (
8
+ file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
9
+ )
10
+ out, _ = (
11
+ ffmpeg.input(file, threads=0)
12
+ .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
13
+ .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
14
+ )
15
+ except Exception as e:
16
+ raise RuntimeError(f"Failed to load audio: {e}")
17
+
18
+ return np.frombuffer(out, np.float32).flatten()
src/rvc.py ADDED
@@ -0,0 +1,187 @@
1
+ from multiprocessing import cpu_count
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from fairseq import checkpoint_utils
6
+ from scipy.io import wavfile
7
+
8
+ from infer_pack.models import (
9
+ SynthesizerTrnMs256NSFsid,
10
+ SynthesizerTrnMs256NSFsid_nono,
11
+ SynthesizerTrnMs768NSFsid,
12
+ SynthesizerTrnMs768NSFsid_nono,
13
+ )
14
+ from my_utils import load_audio
15
+ from vc_infer_pipeline import VC
16
+
17
+ BASE_DIR = Path(__file__).resolve().parent.parent
18
+
19
+
20
+ class Config:
21
+ def __init__(self, device, is_half):
22
+ self.device = device
23
+ self.is_half = is_half
24
+ self.n_cpu = 0
25
+ self.gpu_name = None
26
+ self.gpu_mem = None
27
+ self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
28
+
29
+ def device_config(self) -> tuple:
30
+ if torch.cuda.is_available():
31
+ i_device = int(self.device.split(":")[-1])
32
+ self.gpu_name = torch.cuda.get_device_name(i_device)
33
+ if (
34
+ ("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
35
+ or "P40" in self.gpu_name.upper()
36
+ or "1060" in self.gpu_name
37
+ or "1070" in self.gpu_name
38
+ or "1080" in self.gpu_name
39
+ ):
40
+ print("16 series/10 series P40 forced single precision")
41
+ self.is_half = False
42
+ for config_file in ["32k.json", "40k.json", "48k.json"]:
43
+ with open(BASE_DIR / "src" / "configs" / config_file, "r") as f:
44
+ strr = f.read().replace("true", "false")
45
+ with open(BASE_DIR / "src" / "configs" / config_file, "w") as f:
46
+ f.write(strr)
47
+ with open(BASE_DIR / "src" / "trainset_preprocess_pipeline_print.py", "r") as f:
48
+ strr = f.read().replace("3.7", "3.0")
49
+ with open(BASE_DIR / "src" / "trainset_preprocess_pipeline_print.py", "w") as f:
50
+ f.write(strr)
51
+ else:
52
+ self.gpu_name = None
53
+ self.gpu_mem = int(
54
+ torch.cuda.get_device_properties(i_device).total_memory
55
+ / 1024
56
+ / 1024
57
+ / 1024
58
+ + 0.4
59
+ )
60
+ if self.gpu_mem <= 4:
61
+ with open(BASE_DIR / "src" / "trainset_preprocess_pipeline_print.py", "r") as f:
62
+ strr = f.read().replace("3.7", "3.0")
63
+ with open(BASE_DIR / "src" / "trainset_preprocess_pipeline_print.py", "w") as f:
64
+ f.write(strr)
65
+ elif torch.backends.mps.is_available():
66
+ print("No supported N-card found, use MPS for inference")
67
+ self.device = "mps"
68
+ else:
69
+ print("No supported N-card found, use CPU for inference")
70
+ self.device = "cpu"
71
+ self.is_half = True
72
+
73
+ if self.n_cpu == 0:
74
+ self.n_cpu = cpu_count()
75
+
76
+ if self.is_half:
77
+ # 6G memory config
78
+ x_pad = 3
79
+ x_query = 10
80
+ x_center = 60
81
+ x_max = 65
82
+ else:
83
+ # 5G memory config
84
+ x_pad = 1
85
+ x_query = 6
86
+ x_center = 38
87
+ x_max = 41
88
+
89
+ if self.gpu_mem is not None and self.gpu_mem <= 4:
90
+ x_pad = 1
91
+ x_query = 5
92
+ x_center = 30
93
+ x_max = 32
94
+
95
+ return x_pad, x_query, x_center, x_max
96
+
97
+
98
+ def load_hubert(device, is_half, model_path):
99
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([model_path], suffix='', )
100
+ hubert = models[0]
101
+ hubert = hubert.to(device)
102
+
103
+ if is_half:
104
+ hubert = hubert.half()
105
+ else:
106
+ hubert = hubert.float()
107
+
108
+ hubert.eval()
109
+ return hubert
110
+
111
+
112
+ def get_vc(device, is_half, config, model_path):
113
+ cpt = torch.load(model_path, map_location='cpu')
114
+ if "config" not in cpt or "weight" not in cpt:
115
+ raise ValueError(f'Incorrect format for {model_path}. Use a voice model trained using RVC v2 instead.')
116
+
117
+ tgt_sr = cpt["config"][-1]
118
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
119
+ if_f0 = cpt.get("f0", 1)
120
+ version = cpt.get("version", "v1")
121
+
122
+ if version == "v1":
123
+ if if_f0 == 1:
124
+ net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
125
+ else:
126
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
127
+ elif version == "v2":
128
+ if if_f0 == 1:
129
+ net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half)
130
+ else:
131
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
132
+
133
+ del net_g.enc_q
134
+ print(net_g.load_state_dict(cpt["weight"], strict=False))
135
+ net_g.eval().to(device)
136
+
137
+ if is_half:
138
+ net_g = net_g.half()
139
+ else:
140
+ net_g = net_g.float()
141
+
142
+ vc = VC(tgt_sr, config)
143
+ return cpt, version, net_g, tgt_sr, vc
144
+
145
+
146
+ def rvc_infer(
147
+ index_path,
148
+ index_rate,
149
+ input_path,
150
+ output_path,
151
+ pitch_change,
152
+ f0_method,
153
+ cpt,
154
+ version,
155
+ net_g,
156
+ filter_radius,
157
+ tgt_sr,
158
+ rms_mix_rate,
159
+ protect,
160
+ crepe_hop_length,
161
+ vc,
162
+ hubert_model
163
+ ):
164
+ audio = load_audio(input_path, 16000)
165
+ times = [0, 0, 0]
166
+ if_f0 = cpt.get('f0', 1)
167
+ audio_opt = vc.pipeline(
168
+ hubert_model,
169
+ net_g,
170
+ 0,
171
+ audio,
172
+ input_path,
173
+ times,
174
+ pitch_change,
175
+ f0_method,
176
+ index_path,
177
+ index_rate,
178
+ if_f0,
179
+ filter_radius,
180
+ tgt_sr,
181
+ 0,
182
+ rms_mix_rate,
183
+ version,
184
+ protect,
185
+ crepe_hop_length
186
+ )
187
+ wavfile.write(output_path, tgt_sr, audio_opt)
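For orientation, the helpers above (Config, load_hubert, get_vc, rvc_infer) are meant to be composed by a caller. The following is a minimal usage sketch, not part of this commit; the module name `rvc`, the model/audio paths, and the numeric values are assumed placeholders.

import torch
from rvc import Config, load_hubert, get_vc, rvc_infer  # assumed module name for the file above

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
config = Config(device, is_half=torch.cuda.is_available())

# Load the HuBERT content encoder and a trained RVC v2 voice model (paths are placeholders).
hubert_model = load_hubert(device, config.is_half, 'rvc_models/hubert_base.pt')
cpt, version, net_g, tgt_sr, vc = get_vc(device, config.is_half, config, 'rvc_models/MyVoice.pth')

# Run one conversion; the numeric values here are illustrative only.
rvc_infer(
    index_path='rvc_models/MyVoice.index',
    index_rate=0.5,
    input_path='vocals.wav',
    output_path='vocals_converted.wav',
    pitch_change=0,
    f0_method='rmvpe',
    cpt=cpt,
    version=version,
    net_g=net_g,
    filter_radius=3,
    tgt_sr=tgt_sr,
    rms_mix_rate=0.25,
    protect=0.33,
    crepe_hop_length=128,
    vc=vc,
    hubert_model=hubert_model,
)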
src/trainset_preprocess_pipeline_print.py ADDED
@@ -0,0 +1,146 @@
1
+ import sys, os, multiprocessing
2
+ from scipy import signal
3
+
4
+ now_dir = os.getcwd()
5
+ sys.path.append(now_dir)
6
+
7
+ inp_root = sys.argv[1]
8
+ sr = int(sys.argv[2])
9
+ n_p = int(sys.argv[3])
10
+ exp_dir = sys.argv[4]
11
+ noparallel = sys.argv[5] == "True"
12
+ import numpy as np, os, traceback
13
+ from slicer2 import Slicer
14
+ import librosa, traceback
15
+ from scipy.io import wavfile
16
+ import multiprocessing
17
+ from my_utils import load_audio
18
+ import tqdm
19
+
20
+ DoFormant = False
21
+ Quefrency = 1.0
22
+ Timbre = 1.0
23
+
24
+ mutex = multiprocessing.Lock()
25
+ f = open("%s/preprocess.log" % exp_dir, "a+")
26
+
27
+
28
+ def println(strr):
29
+ mutex.acquire()
30
+ print(strr)
31
+ f.write("%s\n" % strr)
32
+ f.flush()
33
+ mutex.release()
34
+
35
+
36
+ class PreProcess:
37
+ def __init__(self, sr, exp_dir):
38
+ self.slicer = Slicer(
39
+ sr=sr,
40
+ threshold=-42,
41
+ min_length=1500,
42
+ min_interval=400,
43
+ hop_size=15,
44
+ max_sil_kept=500,
45
+ )
46
+ self.sr = sr
47
+ self.bh, self.ah = signal.butter(N=5, Wn=48, btype="high", fs=self.sr)
48
+ self.per = 3.0
49
+ self.overlap = 0.3
50
+ self.tail = self.per + self.overlap
51
+ self.max = 0.9
52
+ self.alpha = 0.75
53
+ self.exp_dir = exp_dir
54
+ self.gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
55
+ self.wavs16k_dir = "%s/1_16k_wavs" % exp_dir
56
+ os.makedirs(self.exp_dir, exist_ok=True)
57
+ os.makedirs(self.gt_wavs_dir, exist_ok=True)
58
+ os.makedirs(self.wavs16k_dir, exist_ok=True)
59
+
60
+ def norm_write(self, tmp_audio, idx0, idx1):
61
+ tmp_max = np.abs(tmp_audio).max()
62
+ if tmp_max > 2.5:
63
+ print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max))
64
+ return
65
+ tmp_audio = (tmp_audio / tmp_max * (self.max * self.alpha)) + (
66
+ 1 - self.alpha
67
+ ) * tmp_audio
68
+ wavfile.write(
69
+ "%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1),
70
+ self.sr,
71
+ tmp_audio.astype(np.float32),
72
+ )
73
+ tmp_audio = librosa.resample(
74
+ tmp_audio, orig_sr=self.sr, target_sr=16000
75
+ ) # , res_type="soxr_vhq"
76
+ wavfile.write(
77
+ "%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1),
78
+ 16000,
79
+ tmp_audio.astype(np.float32),
80
+ )
81
+
82
+ def pipeline(self, path, idx0):
83
+ try:
84
+ audio = load_audio(path, self.sr, DoFormant, Quefrency, Timbre)
85
+ # zero-phase filtering (filtfilt) causes pre-ringing noise, so the causal lfilter below is used instead
86
+ # audio = signal.filtfilt(self.bh, self.ah, audio)
87
+ audio = signal.lfilter(self.bh, self.ah, audio)
88
+
89
+ idx1 = 0
90
+ for audio in self.slicer.slice(audio):
91
+ i = 0
92
+ while 1:
93
+ start = int(self.sr * (self.per - self.overlap) * i)
94
+ i += 1
95
+ if len(audio[start:]) > self.tail * self.sr:
96
+ tmp_audio = audio[start : start + int(self.per * self.sr)]
97
+ self.norm_write(tmp_audio, idx0, idx1)
98
+ idx1 += 1
99
+ else:
100
+ tmp_audio = audio[start:]
101
+ idx1 += 1
102
+ break
103
+ self.norm_write(tmp_audio, idx0, idx1)
104
+ # println("%s->Suc." % path)
105
+ except:
106
+ println("%s->%s" % (path, traceback.format_exc()))
107
+
108
+ def pipeline_mp(self, infos, thread_n):
109
+ for path, idx0 in tqdm.tqdm(
110
+ infos, position=thread_n, leave=True, desc="thread:%s" % thread_n
111
+ ):
112
+ self.pipeline(path, idx0)
113
+
114
+ def pipeline_mp_inp_dir(self, inp_root, n_p):
115
+ try:
116
+ infos = [
117
+ ("%s/%s" % (inp_root, name), idx)
118
+ for idx, name in enumerate(sorted(list(os.listdir(inp_root))))
119
+ ]
120
+ if noparallel:
121
+ for i in range(n_p):
122
+ self.pipeline_mp(infos[i::n_p], i)  # pass the thread index expected by pipeline_mp
123
+ else:
124
+ ps = []
125
+ for i in range(n_p):
126
+ p = multiprocessing.Process(
127
+ target=self.pipeline_mp, args=(infos[i::n_p], i)
128
+ )
129
+ ps.append(p)
130
+ p.start()
131
+ for i in range(n_p):
132
+ ps[i].join()
133
+ except:
134
+ println("Fail. %s" % traceback.format_exc())
135
+
136
+
137
+ def preprocess_trainset(inp_root, sr, n_p, exp_dir):
138
+ pp = PreProcess(sr, exp_dir)
139
+ println("start preprocess")
140
+ println(sys.argv)
141
+ pp.pipeline_mp_inp_dir(inp_root, n_p)
142
+ println("end preprocess")
143
+
144
+
145
+ if __name__ == "__main__":
146
+ preprocess_trainset(inp_root, sr, n_p, exp_dir)
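Given the sys.argv parsing at the top of this script, it is meant to be run as a standalone command. A minimal sketch of an invocation follows (written in Python to keep one language throughout; the dataset and experiment paths are placeholders):

import subprocess

# Arguments map to: inp_root, sr, n_p, exp_dir, noparallel (in that order).
subprocess.run(
    [
        'python', 'src/trainset_preprocess_pipeline_print.py',
        'datasets/my_voice',  # inp_root: directory of raw training audio
        '40000',              # sr: target sample rate for the 0_gt_wavs output
        '2',                  # n_p: number of worker processes
        'logs/my_voice',      # exp_dir: receives 0_gt_wavs/, 1_16k_wavs/ and preprocess.log
        'False',              # noparallel: the string "True" disables multiprocessing
    ],
    check=True,
)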
src/vc_infer_pipeline.py ADDED
@@ -0,0 +1,606 @@
1
+ from functools import lru_cache
2
+ import numpy as np, parselmouth, torch, pdb, sys, os
3
+ from time import time as ttime
4
+ import torch.nn.functional as F
5
+ import torchcrepe
6
+ from scipy import signal
7
+ from torch import Tensor
8
+ import pyworld, os, faiss, librosa, torchcrepe
9
+ import random
10
+ import gc
11
+ import re
12
+
13
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
+ now_dir = os.path.join(BASE_DIR, 'src')
15
+ sys.path.append(now_dir)
16
+
17
+ from infer_pack.predictor.FCPE import FCPEF0Predictor
18
+ from infer_pack.predictor.RMVPE import RMVPE
19
+
20
+ bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
21
+
22
+ input_audio_path2wav = {}
23
+
24
+
25
+ @lru_cache
26
+ def cache_harvest_f0(input_audio_path, fs, f0max, f0min, frame_period):
27
+ audio = input_audio_path2wav[input_audio_path]
28
+ f0, t = pyworld.harvest(
29
+ audio,
30
+ fs=fs,
31
+ f0_ceil=f0max,
32
+ f0_floor=f0min,
33
+ frame_period=frame_period,
34
+ )
35
+ f0 = pyworld.stonemask(audio, f0, t, fs)
36
+ return f0
37
+
38
+
39
+ def change_rms(data1, sr1, data2, sr2, rate):
40
+ rms1 = librosa.feature.rms(y=data1, frame_length=sr1 // 2 * 2, hop_length=sr1 // 2)
41
+ rms2 = librosa.feature.rms(y=data2, frame_length=sr2 // 2 * 2, hop_length=sr2 // 2)
42
+ rms1 = torch.from_numpy(rms1)
43
+ rms1 = F.interpolate(rms1.unsqueeze(0), size=data2.shape[0], mode="linear").squeeze()
44
+ rms2 = torch.from_numpy(rms2)
45
+ rms2 = F.interpolate(rms2.unsqueeze(0), size=data2.shape[0], mode="linear").squeeze()
46
+ rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6)
47
+ data2 *= (torch.pow(rms1, torch.tensor(1 - rate))* torch.pow(rms2, torch.tensor(rate - 1))).numpy()
48
+ return data2
49
+
50
+
51
+ class VC(object):
52
+ def __init__(self, tgt_sr, config):
53
+ self.x_pad, self.x_query, self.x_center, self.x_max, self.is_half = (
54
+ config.x_pad,
55
+ config.x_query,
56
+ config.x_center,
57
+ config.x_max,
58
+ config.is_half,
59
+ )
60
+ self.sr = 16000
61
+ self.window = 160
62
+ self.t_pad = self.sr * self.x_pad
63
+ self.t_pad_tgt = tgt_sr * self.x_pad
64
+ self.t_pad2 = self.t_pad * 2
65
+ self.t_query = self.sr * self.x_query
66
+ self.t_center = self.sr * self.x_center
67
+ self.t_max = self.sr * self.x_max
68
+ self.device = config.device
69
+
70
+
71
+ def get_optimal_torch_device(self, index: int = 0) -> torch.device:
72
+ if torch.cuda.is_available():
73
+ return torch.device(f"cuda:{index % torch.cuda.device_count()}")
74
+ elif torch.backends.mps.is_available():
75
+ return torch.device("mps")
76
+ return torch.device("cpu")
77
+
78
+ def get_f0_crepe_computation(
79
+ self,
80
+ x,
81
+ f0_min,
82
+ f0_max,
83
+ p_len,
84
+ hop_length=160,
85
+ model="full",
86
+ ):
87
+ x = x.astype(np.float32)
88
+ x /= np.quantile(np.abs(x), 0.999)
89
+ torch_device = self.get_optimal_torch_device()
90
+ audio = torch.from_numpy(x).to(torch_device, copy=True)
91
+ audio = torch.unsqueeze(audio, dim=0)
92
+ if audio.ndim == 2 and audio.shape[0] > 1:
93
+ audio = torch.mean(audio, dim=0, keepdim=True).detach()
94
+ audio = audio.detach()
95
+ pitch: Tensor = torchcrepe.predict(
96
+ audio,
97
+ self.sr,
98
+ hop_length,
99
+ f0_min,
100
+ f0_max,
101
+ model,
102
+ batch_size=hop_length * 2,
103
+ device=torch_device,
104
+ pad=True,
105
+ )
106
+ p_len = p_len or x.shape[0] // hop_length
107
+ source = np.array(pitch.squeeze(0).cpu().float().numpy())
108
+ source[source < 0.001] = np.nan
109
+ target = np.interp(
110
+ np.arange(0, len(source) * p_len, len(source)) / p_len,
111
+ np.arange(0, len(source)),
112
+ source,
113
+ )
114
+ f0 = np.nan_to_num(target)
115
+ return f0
116
+
117
+ def get_f0_official_crepe_computation(
118
+ self,
119
+ x,
120
+ f0_min,
121
+ f0_max,
122
+ model="full",
123
+ ):
124
+ batch_size = 512
125
+ audio = torch.tensor(np.copy(x))[None].float()
126
+ f0, pd = torchcrepe.predict(
127
+ audio,
128
+ self.sr,
129
+ self.window,
130
+ f0_min,
131
+ f0_max,
132
+ model,
133
+ batch_size=batch_size,
134
+ device=self.device,
135
+ return_periodicity=True,
136
+ )
137
+ pd = torchcrepe.filter.median(pd, 3)
138
+ f0 = torchcrepe.filter.mean(f0, 3)
139
+ f0[pd < 0.1] = 0
140
+ f0 = f0[0].cpu().numpy()
141
+ return f0
142
+
143
+ def get_f0_hybrid_computation(
144
+ self,
145
+ methods_str,
146
+ input_audio_path,
147
+ x,
148
+ f0_min,
149
+ f0_max,
150
+ p_len,
151
+ filter_radius,
152
+ crepe_hop_length,
153
+ time_step,
154
+ ):
155
+ methods_str = re.search("hybrid\[(.+)\]", methods_str)
156
+ if methods_str:
157
+ methods = [method.strip() for method in methods_str.group(1).split("+")]
158
+ f0_computation_stack = []
159
+ print(f"Calculating f0 pitch estimations for methods {str(methods)}")
160
+ x = x.astype(np.float32)
161
+ x /= np.quantile(np.abs(x), 0.999)
162
+ for method in methods:
163
+ f0 = None
164
+ if method == "mangio-crepe":
165
+ f0 = self.get_f0_crepe_computation(
166
+ x, f0_min, f0_max, p_len, crepe_hop_length
167
+ )
168
+ elif method == "rmvpe":
169
+ if hasattr(self, "model_rmvpe") == False:
170
+
171
+ self.model_rmvpe = RMVPE(
172
+ os.path.join(BASE_DIR, 'rvc_models', 'rmvpe.pt'), is_half=self.is_half, device=self.device
173
+ )
174
+ f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
175
+ f0 = f0[1:]
176
+ elif method == "fcpe":
177
+ self.model_fcpe = FCPEF0Predictor(
178
+ os.path.join(BASE_DIR, 'rvc_models', 'fcpe.pt'),
179
+ f0_min=int(f0_min),
180
+ f0_max=int(f0_max),
181
+ dtype=torch.float32,
182
+ device=self.device,
183
+ sampling_rate=self.sr,
184
+ threshold=0.03,
185
+ )
186
+ f0 = self.model_fcpe.compute_f0(x, p_len=p_len)
187
+ del self.model_fcpe
188
+ gc.collect()
189
+ f0_computation_stack.append(f0)
190
+
191
+ print(f"Calculating hybrid median f0 from the stack of {str(methods)}")
192
+ f0_computation_stack = [fc for fc in f0_computation_stack if fc is not None]
193
+ f0_median_hybrid = None
194
+ if len(f0_computation_stack) == 1:
195
+ f0_median_hybrid = f0_computation_stack[0]
196
+ else:
197
+ f0_median_hybrid = np.nanmedian(f0_computation_stack, axis=0)
198
+ return f0_median_hybrid
199
+
200
+ def get_f0(
201
+ self,
202
+ input_audio_path,
203
+ x,
204
+ p_len,
205
+ f0_up_key,
206
+ f0_method,
207
+ filter_radius,
208
+ crepe_hop_length,
209
+ inp_f0=None,
210
+ ):
211
+ global input_audio_path2wav
212
+ time_step = self.window / self.sr * 1000
213
+ f0_min = 50
214
+ f0_max = 1100
215
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
216
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
217
+ if f0_method == "pm":
218
+ f0 = (
219
+ parselmouth.Sound(x, self.sr)
220
+ .to_pitch_ac(
221
+ time_step=time_step / 1000,
222
+ voicing_threshold=0.6,
223
+ pitch_floor=f0_min,
224
+ pitch_ceiling=f0_max,
225
+ )
226
+ .selected_array["frequency"]
227
+ )
228
+ pad_size = (p_len - len(f0) + 1) // 2
229
+ if pad_size > 0 or p_len - len(f0) - pad_size > 0:
230
+ f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
231
+
232
+ elif f0_method == "harvest":
233
+ input_audio_path2wav[input_audio_path] = x.astype(np.double)
234
+ f0 = cache_harvest_f0(input_audio_path, self.sr, f0_max, f0_min, 10)
235
+ if int(filter_radius) > 2:
236
+ f0 = signal.medfilt(f0, 3)
237
+
238
+ elif f0_method == "dio":
239
+ f0, t = pyworld.dio(
240
+ x.astype(np.double),
241
+ fs=self.sr,
242
+ f0_ceil=f0_max,
243
+ f0_floor=f0_min,
244
+ frame_period=10,
245
+ )
246
+ f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sr)
247
+ f0 = signal.medfilt(f0, 3)
248
+
249
+ elif f0_method == "crepe":
250
+ f0 = self.get_f0_official_crepe_computation(x, f0_min, f0_max)
251
+
252
+ elif f0_method == "mangio-crepe":
253
+ f0 = self.get_f0_crepe_computation(x, f0_min, f0_max, p_len, crepe_hop_length)
254
+
255
+ elif f0_method == "rmvpe":
256
+ if hasattr(self, "model_rmvpe") == False:
257
+
258
+ self.model_rmvpe = RMVPE(
259
+ os.path.join(BASE_DIR, 'rvc_models', 'rmvpe.pt'), is_half=self.is_half, device=self.device
260
+ )
261
+ f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
262
+
263
+ elif f0_method == "rmvpe+":
264
+ params = {'x': x, 'p_len': p_len, 'f0_up_key': f0_up_key, 'f0_min': f0_min,
265
+ 'f0_max': f0_max, 'time_step': time_step, 'filter_radius': filter_radius,
266
+ 'crepe_hop_length': crepe_hop_length, 'model': "full"
267
+ }
268
+ f0 = self.get_pitch_dependant_rmvpe(**params)
269
+
270
+ elif f0_method == "fcpe":
271
+ self.model_fcpe = FCPEF0Predictor(
272
+ os.path.join(BASE_DIR, 'rvc_models', 'fcpe.pt'),
273
+ f0_min=int(f0_min),
274
+ f0_max=int(f0_max),
275
+ dtype=torch.float32,
276
+ device=self.device,
277
+ sampling_rate=self.sr,
278
+ threshold=0.03,
279
+ )
280
+ f0 = self.model_fcpe.compute_f0(x, p_len=p_len)
281
+ del self.model_fcpe
282
+ gc.collect()
283
+
284
+ elif "hybrid" in f0_method:
285
+ input_audio_path2wav[input_audio_path] = x.astype(np.double)
286
+ f0 = self.get_f0_hybrid_computation(
287
+ f0_method,
288
+ input_audio_path,
289
+ x,
290
+ f0_min,
291
+ f0_max,
292
+ p_len,
293
+ filter_radius,
294
+ crepe_hop_length,
295
+ time_step,
296
+ )
297
+
298
+ f0 *= pow(2, f0_up_key / 12)
299
+ tf0 = self.sr // self.window
300
+ if inp_f0 is not None:
301
+ delta_t = np.round(
302
+ (inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1
303
+ ).astype("int16")
304
+ replace_f0 = np.interp(
305
+ list(range(delta_t)), inp_f0[:, 0] * 100, inp_f0[:, 1]
306
+ )
307
+ shape = f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0]
308
+ f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[
309
+ :shape
310
+ ]
311
+ f0bak = f0.copy()
312
+ f0_mel = 1127 * np.log(1 + f0 / 700)
313
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (
314
+ f0_mel_max - f0_mel_min
315
+ ) + 1
316
+ f0_mel[f0_mel <= 1] = 1
317
+ f0_mel[f0_mel > 255] = 255
318
+ f0_coarse = np.rint(f0_mel).astype(int)
319
+
320
+ return f0_coarse, f0bak
321
+
322
+ def get_pitch_dependant_rmvpe(self, x, f0_min=1, f0_max=40000, *args, **kwargs):
323
+ if not hasattr(self, "model_rmvpe"):
324
+
325
+ self.model_rmvpe = RMVPE(
326
+ os.path.join(BASE_DIR, 'rvc_models', 'rmvpe.pt'),
327
+ is_half=self.is_half,
328
+ device=self.device,
329
+ )
330
+
331
+ f0 = self.model_rmvpe.infer_from_audio_with_pitch(x, thred=0.03, f0_min=f0_min, f0_max=f0_max)
332
+
333
+ return f0
334
+
335
+
336
+ def vc(
337
+ self,
338
+ model,
339
+ net_g,
340
+ sid,
341
+ audio0,
342
+ pitch,
343
+ pitchf,
344
+ times,
345
+ index,
346
+ big_npy,
347
+ index_rate,
348
+ version,
349
+ protect,
350
+ ):
351
+ feats = torch.from_numpy(audio0)
352
+ if self.is_half:
353
+ feats = feats.half()
354
+ else:
355
+ feats = feats.float()
356
+ if feats.dim() == 2:
357
+ feats = feats.mean(-1)
358
+ assert feats.dim() == 1, feats.dim()
359
+ feats = feats.view(1, -1)
360
+ padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
361
+
362
+ inputs = {
363
+ "source": feats.to(self.device),
364
+ "padding_mask": padding_mask,
365
+ "output_layer": 9 if version == "v1" else 12,
366
+ }
367
+ t0 = ttime()
368
+ with torch.no_grad():
369
+ logits = model.extract_features(**inputs)
370
+ feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
371
+ if protect < 0.5 and pitch is not None and pitchf is not None:
372
+ feats0 = feats.clone()
373
+ if (
374
+ index is not None
375
+ and big_npy is not None
376
+ and index_rate != 0
377
+ ):
378
+ npy = feats[0].cpu().numpy()
379
+ if self.is_half:
380
+ npy = npy.astype("float32")
381
+
382
+ score, ix = index.search(npy, k=8)
383
+ weight = np.square(1 / score)
384
+ weight /= weight.sum(axis=1, keepdims=True)
385
+ npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
386
+
387
+ if self.is_half:
388
+ npy = npy.astype("float16")
389
+ feats = (
390
+ torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
391
+ + (1 - index_rate) * feats
392
+ )
393
+
394
+ feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
395
+ if protect < 0.5 and pitch is not None and pitchf is not None:
396
+ feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
397
+ 0, 2, 1
398
+ )
399
+ t1 = ttime()
400
+ p_len = audio0.shape[0] // self.window
401
+ if feats.shape[1] < p_len:
402
+ p_len = feats.shape[1]
403
+ if pitch is not None and pitchf is not None:
404
+ pitch = pitch[:, :p_len]
405
+ pitchf = pitchf[:, :p_len]
406
+
407
+ if protect < 0.5 and pitch is not None and pitchf is not None:
408
+ pitchff = pitchf.clone()
409
+ pitchff[pitchf > 0] = 1
410
+ pitchff[pitchf < 1] = protect
411
+ pitchff = pitchff.unsqueeze(-1)
412
+ feats = feats * pitchff + feats0 * (1 - pitchff)
413
+ feats = feats.to(feats0.dtype)
414
+ p_len = torch.tensor([p_len], device=self.device).long()
415
+ with torch.no_grad():
416
+ if pitch is not None and pitchf is not None:
417
+ audio1 = (
418
+ (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
419
+ .data.cpu()
420
+ .float()
421
+ .numpy()
422
+ )
423
+ else:
424
+ audio1 = (
425
+ (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
426
+ )
427
+ del feats, p_len, padding_mask
428
+ if torch.cuda.is_available():
429
+ torch.cuda.empty_cache()
430
+ t2 = ttime()
431
+ times[0] += t1 - t0
432
+ times[2] += t2 - t1
433
+ return audio1
434
+
435
+ def pipeline(
436
+ self,
437
+ model,
438
+ net_g,
439
+ sid,
440
+ audio,
441
+ input_audio_path,
442
+ times,
443
+ f0_up_key,
444
+ f0_method,
445
+ file_index,
446
+ index_rate,
447
+ if_f0,
448
+ filter_radius,
449
+ tgt_sr,
450
+ resample_sr,
451
+ rms_mix_rate,
452
+ version,
453
+ protect,
454
+ crepe_hop_length,
455
+ f0_file=None,
456
+ ):
457
+ if file_index != "" and os.path.exists(file_index) == True and index_rate != 0:
458
+ try:
459
+ index = faiss.read_index(file_index)
460
+ big_npy = index.reconstruct_n(0, index.ntotal)
461
+ except Exception as error:
462
+ print(error)
463
+ index = big_npy = None
464
+ else:
465
+ index = big_npy = None
466
+ audio = signal.filtfilt(bh, ah, audio)
467
+ audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
468
+ opt_ts = []
469
+ if audio_pad.shape[0] > self.t_max:
470
+ audio_sum = np.zeros_like(audio)
471
+ for i in range(self.window):
472
+ audio_sum += audio_pad[i : i - self.window]
473
+ for t in range(self.t_center, audio.shape[0], self.t_center):
474
+ opt_ts.append(
475
+ t
476
+ - self.t_query
477
+ + np.where(
478
+ np.abs(audio_sum[t - self.t_query : t + self.t_query])
479
+ == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min()
480
+ )[0][0]
481
+ )
482
+ s = 0
483
+ audio_opt = []
484
+ t = None
485
+ t1 = ttime()
486
+ audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
487
+ p_len = audio_pad.shape[0] // self.window
488
+ inp_f0 = None
489
+ if hasattr(f0_file, "name") == True:
490
+ try:
491
+ with open(f0_file.name, "r") as f:
492
+ lines = f.read().strip("\n").split("\n")
493
+ inp_f0 = []
494
+ for line in lines:
495
+ inp_f0.append([float(i) for i in line.split(",")])
496
+ inp_f0 = np.array(inp_f0, dtype="float32")
497
+ except Exception as error:
498
+ print(error)
499
+ sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
500
+ pitch, pitchf = None, None
501
+ if if_f0 == 1:
502
+ pitch, pitchf = self.get_f0(
503
+ input_audio_path,
504
+ audio_pad,
505
+ p_len,
506
+ f0_up_key,
507
+ f0_method,
508
+ filter_radius,
509
+ crepe_hop_length,
510
+ inp_f0,
511
+ )
512
+ pitch = pitch[:p_len]
513
+ pitchf = pitchf[:p_len]
514
+ if self.device == "mps":
515
+ pitchf = pitchf.astype(np.float32)
516
+ pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
517
+ pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
518
+ t2 = ttime()
519
+ times[1] += t2 - t1
520
+ for t in opt_ts:
521
+ t = t // self.window * self.window
522
+ if if_f0 == 1:
523
+ audio_opt.append(
524
+ self.vc(
525
+ model,
526
+ net_g,
527
+ sid,
528
+ audio_pad[s : t + self.t_pad2 + self.window],
529
+ pitch[:, s // self.window : (t + self.t_pad2) // self.window],
530
+ pitchf[:, s // self.window : (t + self.t_pad2) // self.window],
531
+ times,
532
+ index,
533
+ big_npy,
534
+ index_rate,
535
+ version,
536
+ protect,
537
+ )[self.t_pad_tgt : -self.t_pad_tgt]
538
+ )
539
+ else:
540
+ audio_opt.append(
541
+ self.vc(
542
+ model,
543
+ net_g,
544
+ sid,
545
+ audio_pad[s : t + self.t_pad2 + self.window],
546
+ None,
547
+ None,
548
+ times,
549
+ index,
550
+ big_npy,
551
+ index_rate,
552
+ version,
553
+ protect,
554
+ )[self.t_pad_tgt : -self.t_pad_tgt]
555
+ )
556
+ s = t
557
+ if if_f0 == 1:
558
+ audio_opt.append(
559
+ self.vc(
560
+ model,
561
+ net_g,
562
+ sid,
563
+ audio_pad[t:],
564
+ pitch[:, t // self.window :] if t is not None else pitch,
565
+ pitchf[:, t // self.window :] if t is not None else pitchf,
566
+ times,
567
+ index,
568
+ big_npy,
569
+ index_rate,
570
+ version,
571
+ protect,
572
+ )[self.t_pad_tgt : -self.t_pad_tgt]
573
+ )
574
+ else:
575
+ audio_opt.append(
576
+ self.vc(
577
+ model,
578
+ net_g,
579
+ sid,
580
+ audio_pad[t:],
581
+ None,
582
+ None,
583
+ times,
584
+ index,
585
+ big_npy,
586
+ index_rate,
587
+ version,
588
+ protect,
589
+ )[self.t_pad_tgt : -self.t_pad_tgt]
590
+ )
591
+ audio_opt = np.concatenate(audio_opt)
592
+ if rms_mix_rate != 1:
593
+ audio_opt = change_rms(audio, 16000, audio_opt, tgt_sr, rms_mix_rate)
594
+ if resample_sr >= 16000 and tgt_sr != resample_sr:
595
+ audio_opt = librosa.resample(
596
+ audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
597
+ )
598
+ audio_max = np.abs(audio_opt).max() / 0.99
599
+ max_int16 = 32768
600
+ if audio_max > 1:
601
+ max_int16 /= audio_max
602
+ audio_opt = (audio_opt * max_int16).astype(np.int16)
603
+ del pitch, pitchf, sid
604
+ if torch.cuda.is_available():
605
+ torch.cuda.empty_cache()
606
+ return audio_opt
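The get_f0 method above quantizes the continuous f0 curve onto integer bins 1–255 on a mel scale (f0_coarse) before synthesis. Below is a small standalone reproduction of that mapping with illustrative input values, useful for sanity-checking the bounds; it is a sketch, not part of this commit.

import numpy as np

f0_min, f0_max = 50, 1100
f0_mel_min = 1127 * np.log(1 + f0_min / 700)  # ~= 77.8
f0_mel_max = 1127 * np.log(1 + f0_max / 700)  # ~= 1064.4

def coarse_pitch(f0):
    # Same mapping as in VC.get_f0: mel-scale the pitch, rescale voiced frames to 1..255, round.
    f0_mel = 1127 * np.log(1 + f0 / 700)
    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
    f0_mel[f0_mel <= 1] = 1      # unvoiced frames (f0 == 0) collapse to bin 1
    f0_mel[f0_mel > 255] = 255   # anything above f0_max saturates at bin 255
    return np.rint(f0_mel).astype(int)

print(coarse_pitch(np.array([0.0, 110.0, 220.0, 440.0])))  # increasing bins for rising pitch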