add ddsp-svc
- DDSP-SVC/.gitignore +10 -0
- DDSP-SVC/LICENSE +21 -0
- DDSP-SVC/data/train/.gitignore +3 -0
- DDSP-SVC/data/train/audio/.gitignore +2 -0
- DDSP-SVC/data/val/.gitignore +3 -0
- DDSP-SVC/data/val/audio/.gitignore +2 -0
- DDSP-SVC/data_loaders.py +244 -0
- DDSP-SVC/ddsp/__init__.py +0 -0
- DDSP-SVC/ddsp/core.py +281 -0
- DDSP-SVC/ddsp/loss.py +57 -0
- DDSP-SVC/ddsp/pcmer.py +380 -0
- DDSP-SVC/ddsp/unit2control.py +86 -0
- DDSP-SVC/ddsp/vocoder.py +652 -0
- DDSP-SVC/diffusion/data_loaders.py +271 -0
- DDSP-SVC/diffusion/diffusion.py +317 -0
- DDSP-SVC/diffusion/dpm_solver_pytorch.py +1201 -0
- DDSP-SVC/diffusion/infer_gt_mel.py +78 -0
- DDSP-SVC/diffusion/solver.py +171 -0
- DDSP-SVC/diffusion/unit2mel.py +96 -0
- DDSP-SVC/diffusion/vocoder.py +87 -0
- DDSP-SVC/diffusion/wavenet.py +108 -0
- DDSP-SVC/draw.py +102 -0
- DDSP-SVC/encoder/hubert/model.py +293 -0
- DDSP-SVC/enhancer.py +115 -0
- DDSP-SVC/exp/.gitignore +2 -0
- DDSP-SVC/flask_api.py +178 -0
- DDSP-SVC/gui.py +483 -0
- DDSP-SVC/gui_diff.py +576 -0
- DDSP-SVC/gui_diff_locale.py +154 -0
- DDSP-SVC/gui_locale.py +130 -0
- DDSP-SVC/logger/__init__.py +0 -0
- DDSP-SVC/logger/saver.py +145 -0
- DDSP-SVC/logger/utils.py +122 -0
- DDSP-SVC/main.py +282 -0
- DDSP-SVC/main_diff.py +372 -0
- DDSP-SVC/nsf_hifigan/env.py +15 -0
- DDSP-SVC/nsf_hifigan/models.py +430 -0
- DDSP-SVC/nsf_hifigan/nvSTFT.py +129 -0
- DDSP-SVC/nsf_hifigan/utils.py +68 -0
- DDSP-SVC/preprocess.py +197 -0
- DDSP-SVC/pretrain/hubert/.gitignore +2 -0
- DDSP-SVC/pretrain/nsf_hifigan/.gitignore +2 -0
- DDSP-SVC/requirements.txt +25 -0
- DDSP-SVC/slicer.py +146 -0
- DDSP-SVC/solver.py +151 -0
- DDSP-SVC/train.py +93 -0
- DDSP-SVC/train_diff.py +70 -0
- DDSP-SVC/webui.py +267 -0
DDSP-SVC/.gitignore
ADDED
@@ -0,0 +1,10 @@
+__pycache__/
+
+venv/
+results/
+configs/
+!configs/combsub.yaml
+!configs/combsub-old.yaml
+!configs/sins.yaml
+!configs/diffusion.yaml
+cache/
DDSP-SVC/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 yxlllc
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
DDSP-SVC/data/train/.gitignore
ADDED
@@ -0,0 +1,3 @@
+*
+!.gitignore
+!audio
DDSP-SVC/data/train/audio/.gitignore
ADDED
@@ -0,0 +1,2 @@
+*
+!.gitignore
DDSP-SVC/data/val/.gitignore
ADDED
@@ -0,0 +1,3 @@
+*
+!.gitignore
+!audio
DDSP-SVC/data/val/audio/.gitignore
ADDED
@@ -0,0 +1,2 @@
+*
+!.gitignore
DDSP-SVC/data_loaders.py
ADDED
@@ -0,0 +1,244 @@
+import os
+import random
+import numpy as np
+import librosa
+import torch
+import random
+from tqdm import tqdm
+from torch.utils.data import Dataset
+
+def traverse_dir(
+        root_dir,
+        extension,
+        amount=None,
+        str_include=None,
+        str_exclude=None,
+        is_pure=False,
+        is_sort=False,
+        is_ext=True):
+
+    file_list = []
+    cnt = 0
+    for root, _, files in os.walk(root_dir):
+        for file in files:
+            if file.endswith(extension):
+                # path
+                mix_path = os.path.join(root, file)
+                pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
+
+                # amount
+                if (amount is not None) and (cnt == amount):
+                    if is_sort:
+                        file_list.sort()
+                    return file_list
+
+                # check string
+                if (str_include is not None) and (str_include not in pure_path):
+                    continue
+                if (str_exclude is not None) and (str_exclude in pure_path):
+                    continue
+
+                if not is_ext:
+                    ext = pure_path.split('.')[-1]
+                    pure_path = pure_path[:-(len(ext)+1)]
+                file_list.append(pure_path)
+                cnt += 1
+    if is_sort:
+        file_list.sort()
+    return file_list
+
+
+def get_data_loaders(args, whole_audio=False):
+    data_train = AudioDataset(
+        args.data.train_path,
+        waveform_sec=args.data.duration,
+        hop_size=args.data.block_size,
+        sample_rate=args.data.sampling_rate,
+        load_all_data=args.train.cache_all_data,
+        whole_audio=whole_audio,
+        n_spk=args.model.n_spk,
+        device=args.train.cache_device,
+        fp16=args.train.cache_fp16,
+        use_aug=True)
+    loader_train = torch.utils.data.DataLoader(
+        data_train,
+        batch_size=args.train.batch_size if not whole_audio else 1,
+        shuffle=True,
+        num_workers=args.train.num_workers if args.train.cache_device=='cpu' else 0,
+        persistent_workers=(args.train.num_workers > 0) if args.train.cache_device=='cpu' else False,
+        pin_memory=True if args.train.cache_device=='cpu' else False
+    )
+    data_valid = AudioDataset(
+        args.data.valid_path,
+        waveform_sec=args.data.duration,
+        hop_size=args.data.block_size,
+        sample_rate=args.data.sampling_rate,
+        load_all_data=args.train.cache_all_data,
+        whole_audio=True,
+        n_spk=args.model.n_spk)
+    loader_valid = torch.utils.data.DataLoader(
+        data_valid,
+        batch_size=1,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=True
+    )
+    return loader_train, loader_valid
+
+
+class AudioDataset(Dataset):
+    def __init__(
+        self,
+        path_root,
+        waveform_sec,
+        hop_size,
+        sample_rate,
+        load_all_data=True,
+        whole_audio=False,
+        n_spk=1,
+        device = 'cpu',
+        fp16 = False,
+        use_aug = False
+    ):
+        super().__init__()
+
+        self.waveform_sec = waveform_sec
+        self.sample_rate = sample_rate
+        self.hop_size = hop_size
+        self.path_root = path_root
+        self.paths = traverse_dir(
+            os.path.join(path_root, 'audio'),
+            extension='wav',
+            is_pure=True,
+            is_sort=True,
+            is_ext=False
+        )
+        self.whole_audio = whole_audio
+        self.use_aug = use_aug
+        self.data_buffer = {}
+        if load_all_data:
+            print('Load all the data from :', path_root)
+        else:
+            print('Load the f0, volume data from :', path_root)
+        for name in tqdm(self.paths, total=len(self.paths)):
+            path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
+            duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate)
+
+            path_f0 = os.path.join(self.path_root, 'f0', name) + '.npy'
+            f0 = np.load(path_f0)
+            f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device)
+
+            path_volume = os.path.join(self.path_root, 'volume', name) + '.npy'
+            volume = np.load(path_volume)
+            volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device)
+
+            if n_spk is not None and n_spk > 1:
+                spk_id = int(os.path.dirname(name)) if str.isdigit(os.path.dirname(name)) else 0
+                if spk_id < 1 or spk_id > n_spk:
+                    raise ValueError(' [x] Muiti-speaker traing error : spk_id must be a positive integer from 1 to n_spk ')
+            else:
+                spk_id = 1
+            spk_id = torch.LongTensor(np.array([spk_id])).to(device)
+
+            if load_all_data:
+                audio, sr = librosa.load(path_audio, sr=self.sample_rate)
+                if len(audio.shape) > 1:
+                    audio = librosa.to_mono(audio)
+                audio = torch.from_numpy(audio).to(device)
+
+                path_units = os.path.join(self.path_root, 'units', name) + '.npy'
+                units = np.load(path_units)
+                units = torch.from_numpy(units).to(device)
+
+                if fp16:
+                    audio = audio.half()
+                    units = units.half()
+
+                self.data_buffer[name] = {
+                    'duration': duration,
+                    'audio': audio,
+                    'units': units,
+                    'f0': f0,
+                    'volume': volume,
+                    'spk_id': spk_id
+                }
+            else:
+                self.data_buffer[name] = {
+                    'duration': duration,
+                    'f0': f0,
+                    'volume': volume,
+                    'spk_id': spk_id
+                }
+
+
+    def __getitem__(self, file_idx):
+        name = self.paths[file_idx]
+        data_buffer = self.data_buffer[name]
+        # check duration. if too short, then skip
+        if data_buffer['duration'] < (self.waveform_sec + 0.1):
+            return self.__getitem__( (file_idx + 1) % len(self.paths))
+
+        # get item
+        return self.get_data(name, data_buffer)
+
+    def get_data(self, name, data_buffer):
+        frame_resolution = self.hop_size / self.sample_rate
+        duration = data_buffer['duration']
+        waveform_sec = duration if self.whole_audio else self.waveform_sec
+
+        # load audio
+        idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
+        start_frame = int(idx_from / frame_resolution)
+        units_frame_len = int(waveform_sec / frame_resolution)
+        audio = data_buffer.get('audio')
+        if audio is None:
+            path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
+            audio, sr = librosa.load(
+                path_audio,
+                sr = self.sample_rate,
+                offset = start_frame * frame_resolution,
+                duration = waveform_sec)
+            if len(audio.shape) > 1:
+                audio = librosa.to_mono(audio)
+            # clip audio into N seconds
+            audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size]
+            audio = torch.from_numpy(audio).float()
+        else:
+            audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size]
+
+        # load units
+        units = data_buffer.get('units')
+        if units is None:
+            units = os.path.join(self.path_root, 'units', name) + '.npy'
+            units = np.load(units)
+            units = units[start_frame : start_frame + units_frame_len]
+            units = torch.from_numpy(units).float()
+        else:
+            units = units[start_frame : start_frame + units_frame_len]
+
+        # load f0
+        f0 = data_buffer.get('f0')
+        f0_frames = f0[start_frame : start_frame + units_frame_len]
+
+        # load volume
+        volume = data_buffer.get('volume')
+        volume_frames = volume[start_frame : start_frame + units_frame_len]
+
+        # load spk_id
+        spk_id = data_buffer.get('spk_id')
+
+        # volume augmentation
+        if self.use_aug:
+            max_amp = float(torch.max(torch.abs(audio))) + 1e-5
+            max_shift = min(1, np.log10(1/max_amp))
+            log10_vol_shift = random.uniform(-1, max_shift)
+            audio_aug = audio * (10 ** log10_vol_shift)
+            volume_frames_aug = volume_frames * (10 ** log10_vol_shift)
+        else:
+            audio_aug = audio
+            volume_frames_aug = volume_frames
+
+        return dict(audio=audio_aug, f0=f0_frames, volume=volume_frames_aug, units=units, spk_id=spk_id, name=name)
+
+    def __len__(self):
+        return len(self.paths)
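For orientation, a minimal sketch of how this dataset class can be driven; the dataset directory (with the audio/, f0/, volume/ and units/ subfolders this class reads), the parameter values, and running from the DDSP-SVC directory are assumptions for illustration, not part of the diff:

# Hypothetical usage sketch for AudioDataset; paths and values are assumptions.
import torch
from data_loaders import AudioDataset  # assumes the DDSP-SVC directory is the working directory

dataset = AudioDataset(
    path_root='data/train',   # expects audio/, f0/, volume/, units/ subfolders
    waveform_sec=2.0,         # length of the random training crop in seconds
    hop_size=512,
    sample_rate=44100,
    load_all_data=False,      # cache only f0/volume; read audio and units per item
    n_spk=1)

loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
batch = next(iter(loader))
# batch['audio']: (4, T), batch['units']: (4, n_frames, n_unit),
# batch['f0'] and batch['volume']: (4, n_frames, 1), batch['spk_id']: (4, 1)
print({k: v.shape for k, v in batch.items() if torch.is_tensor(v)})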
DDSP-SVC/ddsp/__init__.py
ADDED
File without changes
DDSP-SVC/ddsp/core.py
ADDED
@@ -0,0 +1,281 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+import math
+import numpy as np
+
+def MaskedAvgPool1d(x, kernel_size):
+    x = x.unsqueeze(1)
+    x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
+    mask = ~torch.isnan(x)
+    masked_x = torch.where(mask, x, torch.zeros_like(x))
+    ones_kernel = torch.ones(x.size(1), 1, kernel_size, device=x.device)
+
+    # Perform sum pooling
+    sum_pooled = F.conv1d(
+        masked_x,
+        ones_kernel,
+        stride=1,
+        padding=0,
+        groups=x.size(1),
+    )
+
+    # Count the non-masked (valid) elements in each pooling window
+    valid_count = F.conv1d(
+        mask.float(),
+        ones_kernel,
+        stride=1,
+        padding=0,
+        groups=x.size(1),
+    )
+    valid_count = valid_count.clamp(min=1)  # Avoid division by zero
+
+    # Perform masked average pooling
+    avg_pooled = sum_pooled / valid_count
+
+    return avg_pooled.squeeze(1)
+
+def MedianPool1d(x, kernel_size):
+    x = x.unsqueeze(1)
+    x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
+    x = x.squeeze(1)
+    x = x.unfold(1, kernel_size, 1)
+    x, _ = torch.sort(x, dim=-1)
+    return x[:, :, (kernel_size - 1) // 2]
+
+def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True):
+    """Calculate final size for efficient FFT.
+    Args:
+        frame_size: Size of the audio frame.
+        ir_size: Size of the convolving impulse response.
+        power_of_2: Constrain to be a power of 2. If False, allow other 5-smooth
+            numbers. TPU requires power of 2, while GPU is more flexible.
+    Returns:
+        fft_size: Size for efficient FFT.
+    """
+    convolved_frame_size = ir_size + frame_size - 1
+    if power_of_2:
+        # Next power of 2.
+        fft_size = int(2**np.ceil(np.log2(convolved_frame_size)))
+    else:
+        fft_size = convolved_frame_size
+    return fft_size
+
+
+def upsample(signal, factor):
+    signal = signal.permute(0, 2, 1)
+    signal = nn.functional.interpolate(torch.cat((signal,signal[:,:,-1:]),2), size=signal.shape[-1] * factor + 1, mode='linear', align_corners=True)
+    signal = signal[:,:,:-1]
+    return signal.permute(0, 2, 1)
+
+
+def remove_above_fmax(amplitudes, pitch, fmax, level_start=1):
+    n_harm = amplitudes.shape[-1]
+    pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch)
+    aa = (pitches < fmax).float() + 1e-7
+    return amplitudes * aa
+
+
+def crop_and_compensate_delay(audio, audio_size, ir_size,
+                              padding = 'same',
+                              delay_compensation = -1):
+    """Crop audio output from convolution to compensate for group delay.
+    Args:
+        audio: Audio after convolution. Tensor of shape [batch, time_steps].
+        audio_size: Initial size of the audio before convolution.
+        ir_size: Size of the convolving impulse response.
+        padding: Either 'valid' or 'same'. For 'same' the final output to be the
+            same size as the input audio (audio_timesteps). For 'valid' the audio is
+            extended to include the tail of the impulse response (audio_timesteps +
+            ir_timesteps - 1).
+        delay_compensation: Samples to crop from start of output audio to compensate
+            for group delay of the impulse response. If delay_compensation < 0 it
+            defaults to automatically calculating a constant group delay of the
+            windowed linear phase filter from frequency_impulse_response().
+    Returns:
+        Tensor of cropped and shifted audio.
+    Raises:
+        ValueError: If padding is not either 'valid' or 'same'.
+    """
+    # Crop the output.
+    if padding == 'valid':
+        crop_size = ir_size + audio_size - 1
+    elif padding == 'same':
+        crop_size = audio_size
+    else:
+        raise ValueError('Padding must be \'valid\' or \'same\', instead '
+                         'of {}.'.format(padding))
+
+    # Compensate for the group delay of the filter by trimming the front.
+    # For an impulse response produced by frequency_impulse_response(),
+    # the group delay is constant because the filter is linear phase.
+    total_size = int(audio.shape[-1])
+    crop = total_size - crop_size
+    start = (ir_size // 2 if delay_compensation < 0 else delay_compensation)
+    end = crop - start
+    return audio[:, start:-end]
+
+
+def fft_convolve(audio,
+                 impulse_response): # B, n_frames, 2*(n_mags-1)
+    """Filter audio with frames of time-varying impulse responses.
+    Time-varying filter. Given audio [batch, n_samples], and a series of impulse
+    responses [batch, n_frames, n_impulse_response], splits the audio into frames,
+    applies filters, and then overlap-and-adds audio back together.
+    Applies non-windowed non-overlapping STFT/ISTFT to efficiently compute
+    convolution for large impulse response sizes.
+    Args:
+        audio: Input audio. Tensor of shape [batch, audio_timesteps].
+        impulse_response: Finite impulse response to convolve. Can either be a 2-D
+            Tensor of shape [batch, ir_size], or a 3-D Tensor of shape [batch,
+            ir_frames, ir_size]. A 2-D tensor will apply a single linear
+            time-invariant filter to the audio. A 3-D Tensor will apply a linear
+            time-varying filter. Automatically chops the audio into equally shaped
+            blocks to match ir_frames.
+    Returns:
+        audio_out: Convolved audio. Tensor of shape
+            [batch, audio_timesteps].
+    """
+    # Add a frame dimension to impulse response if it doesn't have one.
+    ir_shape = impulse_response.size()
+    if len(ir_shape) == 2:
+        impulse_response = impulse_response.unsqueeze(1)
+        ir_shape = impulse_response.size()
+
+    # Get shapes of audio and impulse response.
+    batch_size_ir, n_ir_frames, ir_size = ir_shape
+    batch_size, audio_size = audio.size() # B, T
+
+    # Validate that batch sizes match.
+    if batch_size != batch_size_ir:
+        raise ValueError('Batch size of audio ({}) and impulse response ({}) must '
+                         'be the same.'.format(batch_size, batch_size_ir))
+
+    # Cut audio into 50% overlapped frames (center padding).
+    hop_size = int(audio_size / n_ir_frames)
+    frame_size = 2 * hop_size
+    audio_frames = F.pad(audio, (hop_size, hop_size)).unfold(1, frame_size, hop_size)
+
+    # Apply Bartlett (triangular) window
+    window = torch.bartlett_window(frame_size).to(audio_frames)
+    audio_frames = audio_frames * window
+
+    # Pad and FFT the audio and impulse responses.
+    fft_size = get_fft_size(frame_size, ir_size, power_of_2=False)
+    audio_fft = torch.fft.rfft(audio_frames, fft_size)
+    ir_fft = torch.fft.rfft(torch.cat((impulse_response,impulse_response[:,-1:,:]),1), fft_size)
+
+    # Multiply the FFTs (same as convolution in time).
+    audio_ir_fft = torch.multiply(audio_fft, ir_fft)
+
+    # Take the IFFT to resynthesize audio.
+    audio_frames_out = torch.fft.irfft(audio_ir_fft, fft_size)
+
+    # Overlap Add
+    batch_size, n_audio_frames, frame_size = audio_frames_out.size() # # B, n_frames+1, 2*(hop_size+n_mags-1)-1
+    fold = torch.nn.Fold(output_size=(1, (n_audio_frames - 1) * hop_size + frame_size),kernel_size=(1, frame_size),stride=(1, hop_size))
+    output_signal = fold(audio_frames_out.transpose(1, 2)).squeeze(1).squeeze(1)
+
+    # Crop and shift the output audio.
+    output_signal = crop_and_compensate_delay(output_signal[:,hop_size:], audio_size, ir_size)
+    return output_signal
+
+
+def apply_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1)
+                                     window_size: int = 0,
+                                     causal: bool = False):
+    """Apply a window to an impulse response and put in causal form.
+    Args:
+        impulse_response: A series of impulse responses frames to window, of shape
+            [batch, n_frames, ir_size]. ---------> ir_size means size of filter_bank ??????
+
+        window_size: Size of the window to apply in the time domain. If window_size
+            is less than 1, it defaults to the impulse_response size.
+        causal: Impulse response input is in causal form (peak in the middle).
+    Returns:
+        impulse_response: Windowed impulse response in causal form, with last
+            dimension cropped to window_size if window_size is greater than 0 and less
+            than ir_size.
+    """
+
+    # If IR is in causal form, put it in zero-phase form.
+    if causal:
+        impulse_response = torch.fftshift(impulse_response, axes=-1)
+
+    # Get a window for better time/frequency resolution than rectangular.
+    # Window defaults to IR size, cannot be bigger.
+    ir_size = int(impulse_response.size(-1))
+    if (window_size <= 0) or (window_size > ir_size):
+        window_size = ir_size
+    window = nn.Parameter(torch.hann_window(window_size), requires_grad = False).to(impulse_response)
+
+    # Zero pad the window and put in in zero-phase form.
+    padding = ir_size - window_size
+    if padding > 0:
+        half_idx = (window_size + 1) // 2
+        window = torch.cat([window[half_idx:],
+                            torch.zeros([padding]),
+                            window[:half_idx]], axis=0)
+    else:
+        window = window.roll(window.size(-1)//2, -1)
+
+    # Apply the window, to get new IR (both in zero-phase form).
+    window = window.unsqueeze(0)
+    impulse_response = impulse_response * window
+
+    # Put IR in causal form and trim zero padding.
+    if padding > 0:
+        first_half_start = (ir_size - (half_idx - 1)) + 1
+        second_half_end = half_idx + 1
+        impulse_response = torch.cat([impulse_response[..., first_half_start:],
+                                      impulse_response[..., :second_half_end]],
+                                     dim=-1)
+    else:
+        impulse_response = impulse_response.roll(impulse_response.size(-1)//2, -1)
+
+    return impulse_response
+
+
+def apply_dynamic_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
+                                             half_width_frames): # B, n_frames, 1
+    ir_size = int(impulse_response.size(-1)) # 2*(n_mag -1) or 2*n_mag-1
+
+    window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2).to(impulse_response) / half_width_frames
+    window[window > 1] = 0
+    window = (1 + torch.cos(np.pi * window)) / 2 # B, n_frames, 2*(n_mag -1) or 2*n_mag-1
+
+    impulse_response = impulse_response.roll(ir_size // 2, -1)
+    impulse_response = impulse_response * window
+
+    return impulse_response
+
+
+def frequency_impulse_response(magnitudes,
+                               hann_window = True,
+                               half_width_frames = None):
+
+    # Get the IR
+    impulse_response = torch.fft.irfft(magnitudes) # B, n_frames, 2*(n_mags-1)
+
+    # Window and put in causal form.
+    if hann_window:
+        if half_width_frames is None:
+            impulse_response = apply_window_to_impulse_response(impulse_response)
+        else:
+            impulse_response = apply_dynamic_window_to_impulse_response(impulse_response, half_width_frames)
+    else:
+        impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1)
+
+    return impulse_response
+
+
+def frequency_filter(audio,
+                     magnitudes,
+                     hann_window=True,
+                     half_width_frames=None):
+
+    impulse_response = frequency_impulse_response(magnitudes, hann_window, half_width_frames)
+
+    return fft_convolve(audio, impulse_response)
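As a quick orientation, a minimal shape sketch of the time-varying filtering path above (frequency_filter builds frame-wise impulse responses via frequency_impulse_response and applies them with fft_convolve); the batch size, frame count, and magnitude size below are arbitrary assumptions:

# Shape sketch for frequency_filter and upsample; sizes are illustrative assumptions.
import torch
from ddsp.core import frequency_filter, upsample

batch, n_frames, n_mags = 2, 100, 256
hop_size = 512
audio = torch.randn(batch, n_frames * hop_size)    # [B, T]
magnitudes = torch.rand(batch, n_frames, n_mags)   # frame-wise filter magnitudes

# Each frame's magnitude response becomes a windowed impulse response, applied
# with 50%-overlapped Bartlett-windowed frames and overlap-add.
filtered = frequency_filter(audio, magnitudes)
print(filtered.shape)    # torch.Size([2, 51200]), same length as the input

# upsample stretches frame-rate features (e.g. f0) to sample rate by linear interpolation.
f0_frames = torch.rand(batch, n_frames, 1)
print(upsample(f0_frames, hop_size).shape)   # torch.Size([2, 51200, 1])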
DDSP-SVC/ddsp/loss.py
ADDED
@@ -0,0 +1,57 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torchaudio
+from torch.nn import functional as F
+from .core import upsample
+
+class SSSLoss(nn.Module):
+    """
+    Single-scale Spectral Loss.
+    """
+
+    def __init__(self, n_fft=111, alpha=1.0, overlap=0, eps=1e-7):
+        super().__init__()
+        self.n_fft = n_fft
+        self.alpha = alpha
+        self.eps = eps
+        self.hop_length = int(n_fft * (1 - overlap)) # 25% of the length
+        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=1, normalized=True, center=False)
+
+    def forward(self, x_true, x_pred):
+        S_true = self.spec(x_true) + self.eps
+        S_pred = self.spec(x_pred) + self.eps
+
+        converge_term = torch.mean(torch.linalg.norm(S_true - S_pred, dim = (1, 2)) / torch.linalg.norm(S_true + S_pred, dim = (1, 2)))
+
+        log_term = F.l1_loss(S_true.log(), S_pred.log())
+
+        loss = converge_term + self.alpha * log_term
+        return loss
+
+
+class RSSLoss(nn.Module):
+    '''
+    Random-scale Spectral Loss.
+    '''
+
+    def __init__(self, fft_min, fft_max, n_scale, alpha=1.0, overlap=0, eps=1e-7, device='cuda'):
+        super().__init__()
+        self.fft_min = fft_min
+        self.fft_max = fft_max
+        self.n_scale = n_scale
+        self.lossdict = {}
+        for n_fft in range(fft_min, fft_max):
+            self.lossdict[n_fft] = SSSLoss(n_fft, alpha, overlap, eps).to(device)
+
+    def forward(self, x_pred, x_true):
+        value = 0.
+        n_ffts = torch.randint(self.fft_min, self.fft_max, (self.n_scale,))
+        for n_fft in n_ffts:
+            loss_func = self.lossdict[int(n_fft)]
+            value += loss_func(x_true, x_pred)
+        return value / self.n_scale
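A minimal usage sketch for the random-scale spectral loss above; the FFT range, number of scales, and waveform shapes are assumptions for illustration:

# Usage sketch for RSSLoss; the FFT range and shapes are illustrative assumptions.
import torch
from ddsp.loss import RSSLoss

# Each forward pass samples n_scale FFT sizes from [fft_min, fft_max) and averages
# the corresponding single-scale spectral losses (spectral convergence + log-magnitude L1).
loss_fn = RSSLoss(fft_min=256, fft_max=2048, n_scale=4, device='cpu')

pred = torch.randn(2, 44100)     # predicted waveform, [B, T]
target = torch.randn(2, 44100)   # reference waveform, [B, T]
print(float(loss_fn(pred, target)))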
DDSP-SVC/ddsp/pcmer.py
ADDED
@@ -0,0 +1,380 @@
+import torch
+
+from torch import nn
+import math
+from functools import partial
+from einops import rearrange, repeat
+
+from local_attention import LocalAttention
+import torch.nn.functional as F
+#import fast_transformers.causal_product.causal_product_cuda
+
+def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
+    b, h, *_ = data.shape
+    # (batch size, head, length, model_dim)
+
+    # normalize model dim
+    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
+
+    # what is ration?, projection_matrix.shape[0] --> 266
+
+    ratio = (projection_matrix.shape[0] ** -0.5)
+
+    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
+    projection = projection.type_as(data)
+
+    #data_dash = w^T x
+    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
+
+
+    # diag_data = D**2
+    diag_data = data ** 2
+    diag_data = torch.sum(diag_data, dim=-1)
+    diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
+    diag_data = diag_data.unsqueeze(dim=-1)
+
+    #print ()
+    if is_query:
+        data_dash = ratio * (
+            torch.exp(data_dash - diag_data -
+                      torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
+    else:
+        data_dash = ratio * (
+            torch.exp(data_dash - diag_data + eps))#- torch.max(data_dash)) + eps)
+
+    return data_dash.type_as(data)
+
+def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None):
+    unstructured_block = torch.randn((cols, cols), device = device)
+    q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced')
+    q, r = map(lambda t: t.to(device), (q, r))
+
+    # proposed by @Parskatt
+    # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf
+    if qr_uniform_q:
+        d = torch.diag(r, 0)
+        q *= d.sign()
+    return q.t()
+def exists(val):
+    return val is not None
+
+def empty(tensor):
+    return tensor.numel() == 0
+
+def default(val, d):
+    return val if exists(val) else d
+
+def cast_tuple(val):
+    return (val,) if not isinstance(val, tuple) else val
+
+class PCmer(nn.Module):
+    """The encoder that is used in the Transformer model."""
+
+    def __init__(self,
+                 num_layers,
+                 num_heads,
+                 dim_model,
+                 dim_keys,
+                 dim_values,
+                 residual_dropout,
+                 attention_dropout):
+        super().__init__()
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.dim_model = dim_model
+        self.dim_values = dim_values
+        self.dim_keys = dim_keys
+        self.residual_dropout = residual_dropout
+        self.attention_dropout = attention_dropout
+
+        self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
+
+    # METHODS ########################################################################################################
+
+    def forward(self, phone, mask=None):
+
+        # apply all layers to the input
+        for (i, layer) in enumerate(self._layers):
+            phone = layer(phone, mask)
+        # provide the final sequence
+        return phone
+
+
+# ==================================================================================================================== #
+#  CLASS  _ E N C O D E R  L A Y E R                                                                                   #
+# ==================================================================================================================== #
+
+
+class _EncoderLayer(nn.Module):
+    """One layer of the encoder.
+
+    Attributes:
+        attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence.
+        feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism.
+    """
+
+    def __init__(self, parent: PCmer):
+        """Creates a new instance of ``_EncoderLayer``.
+
+        Args:
+            parent (Encoder): The encoder that the layers is created for.
+        """
+        super().__init__()
+
+
+        self.conformer = ConformerConvModule(parent.dim_model)
+        self.norm = nn.LayerNorm(parent.dim_model)
+        self.dropout = nn.Dropout(parent.residual_dropout)
+
+        # selfatt -> fastatt: performer!
+        self.attn = SelfAttention(dim = parent.dim_model,
+                                  heads = parent.num_heads,
+                                  causal = False)
+
+    # METHODS ########################################################################################################
+
+    def forward(self, phone, mask=None):
+
+        # compute attention sub-layer
+        phone = phone + (self.attn(self.norm(phone), mask=mask))
+
+        phone = phone + (self.conformer(phone))
+
+        return phone
+
+def calc_same_padding(kernel_size):
+    pad = kernel_size // 2
+    return (pad, pad - (kernel_size + 1) % 2)
+
+# helper classes
+
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * x.sigmoid()
+
+class Transpose(nn.Module):
+    def __init__(self, dims):
+        super().__init__()
+        assert len(dims) == 2, 'dims must be a tuple of two dimensions'
+        self.dims = dims
+
+    def forward(self, x):
+        return x.transpose(*self.dims)
+
+class GLU(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        out, gate = x.chunk(2, dim=self.dim)
+        return out * gate.sigmoid()
+
+class DepthWiseConv1d(nn.Module):
+    def __init__(self, chan_in, chan_out, kernel_size, padding):
+        super().__init__()
+        self.padding = padding
+        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
+
+    def forward(self, x):
+        x = F.pad(x, self.padding)
+        return self.conv(x)
+
+class ConformerConvModule(nn.Module):
+    def __init__(
+        self,
+        dim,
+        causal = False,
+        expansion_factor = 2,
+        kernel_size = 31,
+        dropout = 0.):
+        super().__init__()
+
+        inner_dim = dim * expansion_factor
+        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
+
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            Transpose((1, 2)),
+            nn.Conv1d(dim, inner_dim * 2, 1),
+            GLU(dim=1),
+            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
+            #nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
+            Swish(),
+            nn.Conv1d(inner_dim, dim, 1),
+            Transpose((1, 2)),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+def linear_attention(q, k, v):
+    if v is None:
+        #print (k.size(), q.size())
+        out = torch.einsum('...ed,...nd->...ne', k, q)
+        return out
+
+    else:
+        k_cumsum = k.sum(dim = -2)
+        #k_cumsum = k.sum(dim = -2)
+        D_inv = 1. / (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8)
+
+        context = torch.einsum('...nd,...ne->...de', k, v)
+        #print ("TRUEEE: ", context.size(), q.size(), D_inv.size())
+        out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
+        return out
+
+def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None):
+    nb_full_blocks = int(nb_rows / nb_columns)
+    #print (nb_full_blocks)
+    block_list = []
+
+    for _ in range(nb_full_blocks):
+        q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
+        block_list.append(q)
+    # block_list[n] is a orthogonal matrix ... (model_dim * model_dim)
+    #print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1)))
+    #print (nb_rows, nb_full_blocks, nb_columns)
+    remaining_rows = nb_rows - nb_full_blocks * nb_columns
+    #print (remaining_rows)
+    if remaining_rows > 0:
+        q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
+        #print (q[:remaining_rows].size())
+        block_list.append(q[:remaining_rows])
+
+    final_matrix = torch.cat(block_list)
+
+    if scaling == 0:
+        multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
+    elif scaling == 1:
+        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
+    else:
+        raise ValueError(f'Invalid scaling {scaling}')
+
+    return torch.diag(multiplier) @ final_matrix
+
+class FastAttention(nn.Module):
+    def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, no_projection = False):
+        super().__init__()
+        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
+
+        self.dim_heads = dim_heads
+        self.nb_features = nb_features
+        self.ortho_scaling = ortho_scaling
+
+        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q)
+        projection_matrix = self.create_projection()
+        self.register_buffer('projection_matrix', projection_matrix)
+
+        self.generalized_attention = generalized_attention
+        self.kernel_fn = kernel_fn
+
+        # if this is turned on, no projection will be used
+        # queries and keys will be softmax-ed as in the original efficient attention paper
+        self.no_projection = no_projection
+
+        self.causal = causal
+        if causal:
+            try:
+                import fast_transformers.causal_product.causal_product_cuda
+                self.causal_linear_fn = partial(causal_linear_attention)
+            except ImportError:
+                print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
+                self.causal_linear_fn = causal_linear_attention_noncuda
+    @torch.no_grad()
+    def redraw_projection_matrix(self):
+        projections = self.create_projection()
+        self.projection_matrix.copy_(projections)
+        del projections
+
+    def forward(self, q, k, v):
+        device = q.device
+
+        if self.no_projection:
+            q = q.softmax(dim = -1)
+            k = torch.exp(k) if self.causal else k.softmax(dim = -2)
+
+        elif self.generalized_attention:
+            create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
+            q, k = map(create_kernel, (q, k))
+
+        else:
+            create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
+
+            q = create_kernel(q, is_query = True)
+            k = create_kernel(k, is_query = False)
+
+        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
+        if v is None:
+            out = attn_fn(q, k, None)
+            return out
+        else:
+            out = attn_fn(q, k, v)
+            return out
+class SelfAttention(nn.Module):
+    def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False):
+        super().__init__()
+        assert dim % heads == 0, 'dimension must be divisible by number of heads'
+        dim_head = default(dim_head, dim // heads)
+        inner_dim = dim_head * heads
+        self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection)
+
+        self.heads = heads
+        self.global_heads = heads - local_heads
+        self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
+
+        #print (heads, nb_features, dim_head)
+        #name_embedding = torch.zeros(110, heads, dim_head, dim_head)
+        #self.name_embedding = nn.Parameter(name_embedding, requires_grad=True)
+
+
+        self.to_q = nn.Linear(dim, inner_dim)
+        self.to_k = nn.Linear(dim, inner_dim)
+        self.to_v = nn.Linear(dim, inner_dim)
+        self.to_out = nn.Linear(inner_dim, dim)
+        self.dropout = nn.Dropout(dropout)
+
+    @torch.no_grad()
+    def redraw_projection_matrix(self):
+        self.fast_attention.redraw_projection_matrix()
+        #torch.nn.init.zeros_(self.name_embedding)
+        #print (torch.sum(self.name_embedding))
+    def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs):
+        b, n, _, h, gh = *x.shape, self.heads, self.global_heads
+
+        cross_attend = exists(context)
+
+        context = default(context, x)
+        context_mask = default(context_mask, mask) if not cross_attend else context_mask
+        #print (torch.sum(self.name_embedding))
+        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
+        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
+
+        attn_outs = []
+        #print (name)
+        #print (self.name_embedding[name].size())
+        if not empty(q):
+            if exists(context_mask):
+                global_mask = context_mask[:, None, :, None]
+                v.masked_fill_(~global_mask, 0.)
+            if cross_attend:
+                pass
+                #print (torch.sum(self.name_embedding))
+                #out = self.fast_attention(q,self.name_embedding[name],None)
+                #print (torch.sum(self.name_embedding[...,-1:]))
+            else:
+                out = self.fast_attention(q, k, v)
+            attn_outs.append(out)
+
+        if not empty(lq):
+            assert not cross_attend, 'local attention is not compatible with cross attention'
+            out = self.local_attn(lq, lk, lv, input_mask = mask)
+            attn_outs.append(out)
+
+        out = torch.cat(attn_outs, dim = 1)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        out = self.to_out(out)
+        return self.dropout(out)
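A small shape sketch for the PCmer encoder above (per layer: Performer-style linear self-attention followed by a Conformer convolution module, each with a residual connection); the hyperparameters below match the values Unit2Control passes in the next file of this commit, while the batch size and frame count are assumptions:

# Shape sketch for PCmer; batch size and frame count are illustrative assumptions.
import torch
from ddsp.pcmer import PCmer

encoder = PCmer(
    num_layers=3,
    num_heads=8,
    dim_model=256,
    dim_keys=256,
    dim_values=256,
    residual_dropout=0.1,
    attention_dropout=0.1)

frames = torch.randn(2, 100, 256)   # [B, n_frames, dim_model]
out = encoder(frames)               # shape is preserved by every layer
print(out.shape)                    # torch.Size([2, 100, 256])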
DDSP-SVC/ddsp/unit2control.py
ADDED
@@ -0,0 +1,86 @@
+import gin
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+
+from .pcmer import PCmer
+
+
+def split_to_dict(tensor, tensor_splits):
+    """Split a tensor into a dictionary of multiple tensors."""
+    labels = []
+    sizes = []
+
+    for k, v in tensor_splits.items():
+        labels.append(k)
+        sizes.append(v)
+
+    tensors = torch.split(tensor, sizes, dim=-1)
+    return dict(zip(labels, tensors))
+
+
+class Unit2Control(nn.Module):
+    def __init__(
+        self,
+        input_channel,
+        n_spk,
+        output_splits):
+        super().__init__()
+        self.output_splits = output_splits
+        self.f0_embed = nn.Linear(1, 256)
+        self.phase_embed = nn.Linear(1, 256)
+        self.volume_embed = nn.Linear(1, 256)
+        self.n_spk = n_spk
+        if n_spk is not None and n_spk > 1:
+            self.spk_embed = nn.Embedding(n_spk, 256)
+
+        # conv in stack
+        self.stack = nn.Sequential(
+            nn.Conv1d(input_channel, 256, 3, 1, 1),
+            nn.GroupNorm(4, 256),
+            nn.LeakyReLU(),
+            nn.Conv1d(256, 256, 3, 1, 1))
+
+        # transformer
+        self.decoder = PCmer(
+            num_layers=3,
+            num_heads=8,
+            dim_model=256,
+            dim_keys=256,
+            dim_values=256,
+            residual_dropout=0.1,
+            attention_dropout=0.1)
+        self.norm = nn.LayerNorm(256)
+
+        # out
+        self.n_out = sum([v for k, v in output_splits.items()])
+        self.dense_out = weight_norm(
+            nn.Linear(256, self.n_out))
+
+    def forward(self, units, f0, phase, volume, spk_id = None, spk_mix_dict = None):
+
+        '''
+        input:
+            B x n_frames x n_unit
+        return:
+            dict of B x n_frames x feat
+        '''
+
+        x = self.stack(units.transpose(1,2)).transpose(1,2)
+        x = x + self.f0_embed((1+ f0 / 700).log()) + self.phase_embed(phase / np.pi) + self.volume_embed(volume)
+        if self.n_spk is not None and self.n_spk > 1:
+            if spk_mix_dict is not None:
+                for k, v in spk_mix_dict.items():
+                    spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
+                    x = x + v * self.spk_embed(spk_id_torch - 1)
+            else:
+                x = x + self.spk_embed(spk_id - 1)
+        x = self.decoder(x)
+        x = self.norm(x)
+        e = self.dense_out(x)
+        controls = split_to_dict(e, self.output_splits)
+
+        return controls
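A shape sketch for Unit2Control; the output_splits names and sizes here are hypothetical control heads chosen only to show how the final linear layer is split, and the frame count is arbitrary:

# Shape sketch for Unit2Control; output_splits entries are hypothetical.
import math
import torch
from ddsp.unit2control import Unit2Control

output_splits = {'amplitudes': 128, 'noise_magnitude': 256}   # hypothetical control heads
model = Unit2Control(input_channel=256, n_spk=1, output_splits=output_splits)

B, n_frames = 2, 100
units  = torch.randn(B, n_frames, 256)           # unit features from the units encoder
f0     = torch.rand(B, n_frames, 1) * 400 + 100  # Hz, must be positive for the log embedding
phase  = (torch.rand(B, n_frames, 1) * 2 - 1) * math.pi
volume = torch.rand(B, n_frames, 1)

controls = model(units, f0, phase, volume)
print({k: v.shape for k, v in controls.items()})
# {'amplitudes': torch.Size([2, 100, 128]), 'noise_magnitude': torch.Size([2, 100, 256])}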
DDSP-SVC/ddsp/vocoder.py
ADDED
@@ -0,0 +1,652 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import yaml
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import pyworld as pw
|
7 |
+
import parselmouth
|
8 |
+
import torchcrepe
|
9 |
+
import resampy
|
10 |
+
from transformers import HubertModel, Wav2Vec2FeatureExtractor
|
11 |
+
from fairseq import checkpoint_utils
|
12 |
+
from encoder.hubert.model import HubertSoft
|
13 |
+
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
14 |
+
from torchaudio.transforms import Resample
|
15 |
+
from .unit2control import Unit2Control
|
16 |
+
from .core import frequency_filter, upsample, remove_above_fmax, MaskedAvgPool1d, MedianPool1d
|
17 |
+
import time
|
18 |
+
|
19 |
+
CREPE_RESAMPLE_KERNEL = {}
|
20 |
+
|
21 |
+
class F0_Extractor:
|
22 |
+
def __init__(self, f0_extractor, sample_rate = 44100, hop_size = 512, f0_min = 65, f0_max = 800):
|
23 |
+
self.f0_extractor = f0_extractor
|
24 |
+
self.sample_rate = sample_rate
|
25 |
+
self.hop_size = hop_size
|
26 |
+
self.f0_min = f0_min
|
27 |
+
self.f0_max = f0_max
|
28 |
+
if f0_extractor == 'crepe':
|
29 |
+
key_str = str(sample_rate)
|
30 |
+
if key_str not in CREPE_RESAMPLE_KERNEL:
|
31 |
+
CREPE_RESAMPLE_KERNEL[key_str] = Resample(sample_rate, 16000, lowpass_filter_width = 128)
|
32 |
+
self.resample_kernel = CREPE_RESAMPLE_KERNEL[key_str]
|
33 |
+
|
34 |
+
def extract(self, audio, uv_interp = False, device = None, silence_front = 0): # audio: 1d numpy array
|
35 |
+
# extractor start time
|
36 |
+
n_frames = int(len(audio) // self.hop_size) + 1
|
37 |
+
|
38 |
+
start_frame = int(silence_front * self.sample_rate / self.hop_size)
|
39 |
+
real_silence_front = start_frame * self.hop_size / self.sample_rate
|
40 |
+
audio = audio[int(np.round(real_silence_front * self.sample_rate)) : ]
|
41 |
+
|
42 |
+
# extract f0 using parselmouth
|
43 |
+
if self.f0_extractor == 'parselmouth':
|
44 |
+
f0 = parselmouth.Sound(audio, self.sample_rate).to_pitch_ac(
|
45 |
+
time_step = self.hop_size / self.sample_rate,
|
46 |
+
voicing_threshold = 0.6,
|
47 |
+
pitch_floor = self.f0_min,
|
48 |
+
pitch_ceiling = self.f0_max).selected_array['frequency']
|
49 |
+
pad_size = start_frame + (int(len(audio) // self.hop_size) - len(f0) + 1) // 2
|
50 |
+
f0 = np.pad(f0,(pad_size, n_frames - len(f0) - pad_size))
|
51 |
+
|
52 |
+
# extract f0 using dio
|
53 |
+
elif self.f0_extractor == 'dio':
|
54 |
+
_f0, t = pw.dio(
|
55 |
+
audio.astype('double'),
|
56 |
+
self.sample_rate,
|
57 |
+
f0_floor = self.f0_min,
|
58 |
+
f0_ceil = self.f0_max,
|
59 |
+
channels_in_octave=2,
|
60 |
+
frame_period = (1000 * self.hop_size / self.sample_rate))
|
61 |
+
f0 = pw.stonemask(audio.astype('double'), _f0, t, self.sample_rate)
|
62 |
+
f0 = np.pad(f0.astype('float'), (start_frame, n_frames - len(f0) - start_frame))
|
63 |
+
|
64 |
+
# extract f0 using harvest
|
65 |
+
elif self.f0_extractor == 'harvest':
|
66 |
+
f0, _ = pw.harvest(
|
67 |
+
audio.astype('double'),
|
68 |
+
self.sample_rate,
|
69 |
+
f0_floor = self.f0_min,
|
70 |
+
f0_ceil = self.f0_max,
|
71 |
+
frame_period = (1000 * self.hop_size / self.sample_rate))
|
72 |
+
f0 = np.pad(f0.astype('float'), (start_frame, n_frames - len(f0) - start_frame))
|
73 |
+
|
74 |
+
# extract f0 using crepe
|
75 |
+
elif self.f0_extractor == 'crepe':
|
76 |
+
if device is None:
|
77 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
78 |
+
resample_kernel = self.resample_kernel.to(device)
|
79 |
+
wav16k_torch = resample_kernel(torch.FloatTensor(audio).unsqueeze(0).to(device))
|
80 |
+
|
81 |
+
f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, self.f0_min, self.f0_max, pad=True, model='full', batch_size=512, device=device, return_periodicity=True)
|
82 |
+
pd = MedianPool1d(pd, 4)
|
83 |
+
f0 = torchcrepe.threshold.At(0.05)(f0, pd)
|
84 |
+
f0 = MaskedAvgPool1d(f0, 4)
|
85 |
+
|
86 |
+
f0 = f0.squeeze(0).cpu().numpy()
|
87 |
+
f0 = np.array([f0[int(min(int(np.round(n * self.hop_size / self.sample_rate / 0.005)), len(f0) - 1))] for n in range(n_frames - start_frame)])
|
88 |
+
f0 = np.pad(f0, (start_frame, 0))
|
89 |
+
|
90 |
+
else:
|
91 |
+
raise ValueError(f" [x] Unknown f0 extractor: {f0_extractor}")
|
92 |
+
|
93 |
+
# interpolate the unvoiced f0
|
94 |
+
if uv_interp:
|
95 |
+
uv = f0 == 0
|
96 |
+
if len(f0[~uv]) > 0:
|
97 |
+
f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
|
98 |
+
f0[f0 < self.f0_min] = self.f0_min
|
99 |
+
return f0
|
100 |
+
|
101 |
+
|
102 |
+
class Volume_Extractor:
|
103 |
+
def __init__(self, hop_size = 512):
|
104 |
+
self.hop_size = hop_size
|
105 |
+
|
106 |
+
def extract(self, audio): # audio: 1d numpy array
|
107 |
+
n_frames = int(len(audio) // self.hop_size) + 1
|
108 |
+
audio2 = audio ** 2
|
109 |
+
audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
|
110 |
+
volume = np.array([np.mean(audio2[int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)])
|
111 |
+
volume = np.sqrt(volume)
|
112 |
+
return volume
|
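# --- Illustrative sketch (not part of the original file): how the two extractors above might be
# --- driven together, assuming a mono waveform already loaded as a 1d numpy array at 44100 Hz;
# --- 'input.wav' and the 'dio' extractor choice are placeholder assumptions.
#
#   import librosa
#   audio, _ = librosa.load('input.wav', sr=44100)             # 1d numpy array, mono
#   f0_extractor = F0_Extractor('dio', sample_rate=44100, hop_size=512)
#   f0 = f0_extractor.extract(audio, uv_interp=True)           # (n_frames,) f0 in Hz, unvoiced gaps interpolated
#   volume_extractor = Volume_Extractor(hop_size=512)
#   volume = volume_extractor.extract(audio)                   # (n_frames,) frame-wise RMS
#   # both use n_frames = len(audio) // hop_size + 1, so the two curves line up frame by frame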
113 |
+
|
114 |
+
|
115 |
+
class Units_Encoder:
|
116 |
+
def __init__(self, encoder, encoder_ckpt, encoder_sample_rate = 16000, encoder_hop_size = 320, device = None,
|
117 |
+
cnhubertsoft_gate=10):
|
118 |
+
if device is None:
|
119 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
120 |
+
self.device = device
|
121 |
+
|
122 |
+
is_loaded_encoder = False
|
123 |
+
if encoder == 'hubertsoft':
|
124 |
+
self.model = Audio2HubertSoft(encoder_ckpt).to(device)
|
125 |
+
is_loaded_encoder = True
|
126 |
+
if encoder == 'hubertbase':
|
127 |
+
self.model = Audio2HubertBase(encoder_ckpt, device=device)
|
128 |
+
is_loaded_encoder = True
|
129 |
+
if encoder == 'hubertbase768':
|
130 |
+
self.model = Audio2HubertBase768(encoder_ckpt, device=device)
|
131 |
+
is_loaded_encoder = True
|
132 |
+
if encoder == 'contentvec':
|
133 |
+
self.model = Audio2ContentVec(encoder_ckpt, device=device)
|
134 |
+
is_loaded_encoder = True
|
135 |
+
if encoder == 'contentvec768':
|
136 |
+
self.model = Audio2ContentVec768(encoder_ckpt, device=device)
|
137 |
+
is_loaded_encoder = True
|
138 |
+
if encoder == 'contentvec768l12':
|
139 |
+
self.model = Audio2ContentVec768L12(encoder_ckpt, device=device)
|
140 |
+
is_loaded_encoder = True
|
141 |
+
if encoder == 'cnhubertsoftfish':
|
142 |
+
self.model = CNHubertSoftFish(encoder_ckpt, device=device, gate_size=cnhubertsoft_gate)
|
143 |
+
is_loaded_encoder = True
|
144 |
+
if not is_loaded_encoder:
|
145 |
+
raise ValueError(f" [x] Unknown units encoder: {encoder}")
|
146 |
+
|
147 |
+
self.resample_kernel = {}
|
148 |
+
self.encoder_sample_rate = encoder_sample_rate
|
149 |
+
self.encoder_hop_size = encoder_hop_size
|
150 |
+
|
151 |
+
def encode(self,
|
152 |
+
audio, # B, T
|
153 |
+
sample_rate,
|
154 |
+
hop_size):
|
155 |
+
|
156 |
+
# resample
|
157 |
+
if sample_rate == self.encoder_sample_rate:
|
158 |
+
audio_res = audio
|
159 |
+
else:
|
160 |
+
key_str = str(sample_rate)
|
161 |
+
if key_str not in self.resample_kernel:
|
162 |
+
self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate, lowpass_filter_width = 128).to(self.device)
|
163 |
+
audio_res = self.resample_kernel[key_str](audio)
|
164 |
+
|
165 |
+
# encode
|
166 |
+
if audio_res.size(-1) < self.encoder_hop_size:
|
167 |
+
audio_res = torch.nn.functional.pad(audio_res, (0, self.encoder_hop_size - audio_res.size(-1)))
|
168 |
+
units = self.model(audio_res)
|
169 |
+
|
170 |
+
# alignment
|
171 |
+
n_frames = audio.size(-1) // hop_size + 1
|
172 |
+
ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate)
|
173 |
+
index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(), max = units.size(1) - 1)
|
174 |
+
units_aligned = torch.gather(units, 1, index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)]))
|
175 |
+
return units_aligned
|
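# --- Illustrative sketch (not part of the original file): calling encode() above with a batch of raw
# --- audio; the checkpoint path is a placeholder assumption. encode() resamples to the encoder rate,
# --- runs the chosen HuBERT/ContentVec model, then gathers one unit vector per synthesizer frame
# --- using the hop-size ratio computed in the alignment step.
#
#   encoder = Units_Encoder('hubertsoft', 'pretrain/hubert/hubert-soft-checkpoint.pt')
#   wav = torch.from_numpy(audio).float().unsqueeze(0).to(encoder.device)   # B x T at the source rate
#   units = encoder.encode(wav, sample_rate=44100, hop_size=512)            # B x n_frames x n_units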
176 |
+
|
177 |
+
class Audio2HubertSoft(torch.nn.Module):
|
178 |
+
def __init__(self, path, h_sample_rate = 16000, h_hop_size = 320):
|
179 |
+
super().__init__()
|
180 |
+
print(' [Encoder Model] HuBERT Soft')
|
181 |
+
self.hubert = HubertSoft()
|
182 |
+
print(' [Loading] ' + path)
|
183 |
+
checkpoint = torch.load(path)
|
184 |
+
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
|
185 |
+
self.hubert.load_state_dict(checkpoint)
|
186 |
+
self.hubert.eval()
|
187 |
+
|
188 |
+
def forward(self,
|
189 |
+
audio): # B, T
|
190 |
+
with torch.inference_mode():
|
191 |
+
units = self.hubert.units(audio.unsqueeze(1))
|
192 |
+
return units
|
193 |
+
|
194 |
+
|
195 |
+
class Audio2ContentVec():
|
196 |
+
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
|
197 |
+
self.device = device
|
198 |
+
print(' [Encoder Model] Content Vec')
|
199 |
+
print(' [Loading] ' + path)
|
200 |
+
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
|
201 |
+
self.hubert = self.models[0]
|
202 |
+
self.hubert = self.hubert.to(self.device)
|
203 |
+
self.hubert.eval()
|
204 |
+
|
205 |
+
def __call__(self,
|
206 |
+
audio): # B, T
|
207 |
+
# wav_tensor = torch.from_numpy(audio).to(self.device)
|
208 |
+
wav_tensor = audio
|
209 |
+
feats = wav_tensor.view(1, -1)
|
210 |
+
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
211 |
+
inputs = {
|
212 |
+
"source": feats.to(wav_tensor.device),
|
213 |
+
"padding_mask": padding_mask.to(wav_tensor.device),
|
214 |
+
"output_layer": 9, # layer 9
|
215 |
+
}
|
216 |
+
with torch.no_grad():
|
217 |
+
logits = self.hubert.extract_features(**inputs)
|
218 |
+
feats = self.hubert.final_proj(logits[0])
|
219 |
+
units = feats # .transpose(2, 1)
|
220 |
+
return units
|
221 |
+
|
222 |
+
|
223 |
+
class Audio2ContentVec768():
|
224 |
+
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
|
225 |
+
self.device = device
|
226 |
+
print(' [Encoder Model] Content Vec')
|
227 |
+
print(' [Loading] ' + path)
|
228 |
+
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
|
229 |
+
self.hubert = self.models[0]
|
230 |
+
self.hubert = self.hubert.to(self.device)
|
231 |
+
self.hubert.eval()
|
232 |
+
|
233 |
+
def __call__(self,
|
234 |
+
audio): # B, T
|
235 |
+
# wav_tensor = torch.from_numpy(audio).to(self.device)
|
236 |
+
wav_tensor = audio
|
237 |
+
feats = wav_tensor.view(1, -1)
|
238 |
+
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
239 |
+
inputs = {
|
240 |
+
"source": feats.to(wav_tensor.device),
|
241 |
+
"padding_mask": padding_mask.to(wav_tensor.device),
|
242 |
+
"output_layer": 9, # layer 9
|
243 |
+
}
|
244 |
+
with torch.no_grad():
|
245 |
+
logits = self.hubert.extract_features(**inputs)
|
246 |
+
feats = logits[0]
|
247 |
+
units = feats # .transpose(2, 1)
|
248 |
+
return units
|
249 |
+
|
250 |
+
|
251 |
+
class Audio2ContentVec768L12():
|
252 |
+
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
|
253 |
+
self.device = device
|
254 |
+
print(' [Encoder Model] Content Vec')
|
255 |
+
print(' [Loading] ' + path)
|
256 |
+
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
|
257 |
+
self.hubert = self.models[0]
|
258 |
+
self.hubert = self.hubert.to(self.device)
|
259 |
+
self.hubert.eval()
|
260 |
+
|
261 |
+
def __call__(self,
|
262 |
+
audio): # B, T
|
263 |
+
# wav_tensor = torch.from_numpy(audio).to(self.device)
|
264 |
+
wav_tensor = audio
|
265 |
+
feats = wav_tensor.view(1, -1)
|
266 |
+
padding_mask = torch.BoolTensor(feats.shape).fill_(False)
|
267 |
+
inputs = {
|
268 |
+
"source": feats.to(wav_tensor.device),
|
269 |
+
"padding_mask": padding_mask.to(wav_tensor.device),
|
270 |
+
"output_layer": 12, # layer 12
|
271 |
+
}
|
272 |
+
with torch.no_grad():
|
273 |
+
logits = self.hubert.extract_features(**inputs)
|
274 |
+
feats = logits[0]
|
275 |
+
units = feats # .transpose(2, 1)
|
276 |
+
return units
|
277 |
+
|
278 |
+
|
279 |
+
class CNHubertSoftFish(torch.nn.Module):
|
280 |
+
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu', gate_size=10):
|
281 |
+
super().__init__()
|
282 |
+
self.device = device
|
283 |
+
self.gate_size = gate_size
|
284 |
+
|
285 |
+
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
286 |
+
"./pretrain/TencentGameMate/chinese-hubert-base")
|
287 |
+
self.model = HubertModel.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base")
|
288 |
+
self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256))
|
289 |
+
# self.label_embedding = nn.Embedding(128, 256)
|
290 |
+
|
291 |
+
state_dict = torch.load(path, map_location=device)
|
292 |
+
self.load_state_dict(state_dict)
|
293 |
+
|
294 |
+
@torch.no_grad()
|
295 |
+
def forward(self, audio):
|
296 |
+
input_values = self.feature_extractor(
|
297 |
+
audio, sampling_rate=16000, return_tensors="pt"
|
298 |
+
).input_values
|
299 |
+
input_values = input_values.to(self.model.device)
|
300 |
+
|
301 |
+
return self._forward(input_values[0])
|
302 |
+
|
303 |
+
@torch.no_grad()
|
304 |
+
def _forward(self, input_values):
|
305 |
+
features = self.model(input_values)
|
306 |
+
features = self.proj(features.last_hidden_state)
|
307 |
+
|
308 |
+
# Top-k gating
|
309 |
+
topk, indices = torch.topk(features, self.gate_size, dim=2)
|
310 |
+
features = torch.zeros_like(features).scatter(2, indices, topk)
|
311 |
+
features = features / features.sum(2, keepdim=True)
|
312 |
+
|
313 |
+
return features.to(self.device) # .transpose(1, 2)
|
314 |
+
|
315 |
+
|
316 |
+
class Audio2HubertBase():
|
317 |
+
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
|
318 |
+
self.device = device
|
319 |
+
print(' [Encoder Model] HuBERT Base')
|
320 |
+
print(' [Loading] ' + path)
|
321 |
+
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
|
322 |
+
self.hubert = self.models[0]
|
323 |
+
self.hubert = self.hubert.to(self.device)
|
324 |
+
self.hubert = self.hubert.float()
|
325 |
+
self.hubert.eval()
|
326 |
+
|
327 |
+
def __call__(self,
|
328 |
+
audio): # B, T
|
329 |
+
with torch.no_grad():
|
330 |
+
padding_mask = torch.BoolTensor(audio.shape).fill_(False)
|
331 |
+
inputs = {
|
332 |
+
"source": audio.to(self.device),
|
333 |
+
"padding_mask": padding_mask.to(self.device),
|
334 |
+
"output_layer": 9, # layer 9
|
335 |
+
}
|
336 |
+
logits = self.hubert.extract_features(**inputs)
|
337 |
+
units = self.hubert.final_proj(logits[0])
|
338 |
+
return units
|
339 |
+
|
340 |
+
|
341 |
+
class Audio2HubertBase768():
|
342 |
+
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
|
343 |
+
self.device = device
|
344 |
+
print(' [Encoder Model] HuBERT Base')
|
345 |
+
print(' [Loading] ' + path)
|
346 |
+
self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
|
347 |
+
self.hubert = self.models[0]
|
348 |
+
self.hubert = self.hubert.to(self.device)
|
349 |
+
self.hubert = self.hubert.float()
|
350 |
+
self.hubert.eval()
|
351 |
+
|
352 |
+
def __call__(self,
|
353 |
+
audio): # B, T
|
354 |
+
with torch.no_grad():
|
355 |
+
padding_mask = torch.BoolTensor(audio.shape).fill_(False)
|
356 |
+
inputs = {
|
357 |
+
"source": audio.to(self.device),
|
358 |
+
"padding_mask": padding_mask.to(self.device),
|
359 |
+
"output_layer": 9, # layer 9
|
360 |
+
}
|
361 |
+
logits = self.hubert.extract_features(**inputs)
|
362 |
+
units = logits[0]
|
363 |
+
return units
|
364 |
+
|
365 |
+
|
366 |
+
class DotDict(dict):
|
367 |
+
def __getattr__(*args):
|
368 |
+
val = dict.get(*args)
|
369 |
+
return DotDict(val) if type(val) is dict else val
|
370 |
+
|
371 |
+
__setattr__ = dict.__setitem__
|
372 |
+
__delattr__ = dict.__delitem__
|
373 |
+
|
374 |
+
def load_model(
|
375 |
+
model_path,
|
376 |
+
device='cpu'):
|
377 |
+
config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
|
378 |
+
with open(config_file, "r") as config:
|
379 |
+
args = yaml.safe_load(config)
|
380 |
+
args = DotDict(args)
|
381 |
+
|
382 |
+
# load model
|
383 |
+
model = None
|
384 |
+
|
385 |
+
if args.model.type == 'Sins':
|
386 |
+
model = Sins(
|
387 |
+
sampling_rate=args.data.sampling_rate,
|
388 |
+
block_size=args.data.block_size,
|
389 |
+
n_harmonics=args.model.n_harmonics,
|
390 |
+
n_mag_allpass=args.model.n_mag_allpass,
|
391 |
+
n_mag_noise=args.model.n_mag_noise,
|
392 |
+
n_unit=args.data.encoder_out_channels,
|
393 |
+
n_spk=args.model.n_spk)
|
394 |
+
|
395 |
+
elif args.model.type == 'CombSub':
|
396 |
+
model = CombSub(
|
397 |
+
sampling_rate=args.data.sampling_rate,
|
398 |
+
block_size=args.data.block_size,
|
399 |
+
n_mag_allpass=args.model.n_mag_allpass,
|
400 |
+
n_mag_harmonic=args.model.n_mag_harmonic,
|
401 |
+
n_mag_noise=args.model.n_mag_noise,
|
402 |
+
n_unit=args.data.encoder_out_channels,
|
403 |
+
n_spk=args.model.n_spk)
|
404 |
+
|
405 |
+
elif args.model.type == 'CombSubFast':
|
406 |
+
model = CombSubFast(
|
407 |
+
sampling_rate=args.data.sampling_rate,
|
408 |
+
block_size=args.data.block_size,
|
409 |
+
n_unit=args.data.encoder_out_channels,
|
410 |
+
n_spk=args.model.n_spk)
|
411 |
+
|
412 |
+
else:
|
413 |
+
raise ValueError(f" [x] Unknown Model: {args.model.type}")
|
414 |
+
|
415 |
+
print(' [Loading] ' + model_path)
|
416 |
+
ckpt = torch.load(model_path, map_location=torch.device(device))
|
417 |
+
model.to(device)
|
418 |
+
model.load_state_dict(ckpt['model'])
|
419 |
+
model.eval()
|
420 |
+
return model, args
|
421 |
+
|
422 |
+
|
423 |
+
class Sins(torch.nn.Module):
|
424 |
+
def __init__(self,
|
425 |
+
sampling_rate,
|
426 |
+
block_size,
|
427 |
+
n_harmonics,
|
428 |
+
n_mag_allpass,
|
429 |
+
n_mag_noise,
|
430 |
+
n_unit=256,
|
431 |
+
n_spk=1):
|
432 |
+
super().__init__()
|
433 |
+
|
434 |
+
print(' [DDSP Model] Sinusoids Additive Synthesiser')
|
435 |
+
|
436 |
+
# params
|
437 |
+
self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
|
438 |
+
self.register_buffer("block_size", torch.tensor(block_size))
|
439 |
+
# Unit2Control
|
440 |
+
split_map = {
|
441 |
+
'amplitudes': n_harmonics,
|
442 |
+
'group_delay': n_mag_allpass,
|
443 |
+
'noise_magnitude': n_mag_noise,
|
444 |
+
}
|
445 |
+
self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
|
446 |
+
|
447 |
+
def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, max_upsample_dim=32):
|
448 |
+
'''
|
449 |
+
units_frames: B x n_frames x n_unit
|
450 |
+
f0_frames: B x n_frames x 1
|
451 |
+
volume_frames: B x n_frames x 1
|
452 |
+
spk_id: B x 1
|
453 |
+
'''
|
454 |
+
# exciter phase
|
455 |
+
f0 = upsample(f0_frames, self.block_size)
|
456 |
+
if infer:
|
457 |
+
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
|
458 |
+
else:
|
459 |
+
x = torch.cumsum(f0 / self.sampling_rate, axis=1)
|
460 |
+
if initial_phase is not None:
|
461 |
+
x += initial_phase.to(x) / 2 / np.pi
|
462 |
+
x = x - torch.round(x)
|
463 |
+
x = x.to(f0)
|
464 |
+
|
465 |
+
phase = 2 * np.pi * x
|
466 |
+
phase_frames = phase[:, ::self.block_size, :]
|
467 |
+
|
468 |
+
# parameter prediction
|
469 |
+
ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)
|
470 |
+
|
471 |
+
amplitudes_frames = torch.exp(ctrls['amplitudes'])/ 128
|
472 |
+
group_delay = np.pi * torch.tanh(ctrls['group_delay'])
|
473 |
+
noise_param = torch.exp(ctrls['noise_magnitude']) / 128
|
474 |
+
|
475 |
+
# sinusoids exciter signal
|
476 |
+
amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, self.sampling_rate / 2, level_start = 1)
|
477 |
+
n_harmonic = amplitudes_frames.shape[-1]
|
478 |
+
level_harmonic = torch.arange(1, n_harmonic + 1).to(phase)
|
479 |
+
sinusoids = 0.
|
480 |
+
for n in range(( n_harmonic - 1) // max_upsample_dim + 1):
|
481 |
+
start = n * max_upsample_dim
|
482 |
+
end = (n + 1) * max_upsample_dim
|
483 |
+
phases = phase * level_harmonic[start:end]
|
484 |
+
amplitudes = upsample(amplitudes_frames[:,:,start:end], self.block_size)
|
485 |
+
sinusoids += (torch.sin(phases) * amplitudes).sum(-1)
|
486 |
+
|
487 |
+
# harmonic part filter (apply group-delay)
|
488 |
+
harmonic = frequency_filter(
|
489 |
+
sinusoids,
|
490 |
+
torch.exp(1.j * torch.cumsum(group_delay, axis = -1)),
|
491 |
+
hann_window = False)
|
492 |
+
|
493 |
+
# noise part filter
|
494 |
+
noise = torch.rand_like(harmonic) * 2 - 1
|
495 |
+
noise = frequency_filter(
|
496 |
+
noise,
|
497 |
+
torch.complex(noise_param, torch.zeros_like(noise_param)),
|
498 |
+
hann_window = True)
|
499 |
+
|
500 |
+
signal = harmonic + noise
|
501 |
+
|
502 |
+
return signal, phase, (harmonic, noise) #, (noise_param, noise_param)
|
503 |
+
|
504 |
+
class CombSubFast(torch.nn.Module):
|
505 |
+
def __init__(self,
|
506 |
+
sampling_rate,
|
507 |
+
block_size,
|
508 |
+
n_unit=256,
|
509 |
+
n_spk=1):
|
510 |
+
super().__init__()
|
511 |
+
|
512 |
+
print(' [DDSP Model] Combtooth Subtractive Synthesiser')
|
513 |
+
# params
|
514 |
+
self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
|
515 |
+
self.register_buffer("block_size", torch.tensor(block_size))
|
516 |
+
self.register_buffer("window", torch.sqrt(torch.hann_window(2 * block_size)))
|
517 |
+
#Unit2Control
|
518 |
+
split_map = {
|
519 |
+
'harmonic_magnitude': block_size + 1,
|
520 |
+
'harmonic_phase': block_size + 1,
|
521 |
+
'noise_magnitude': block_size + 1
|
522 |
+
}
|
523 |
+
self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
|
524 |
+
|
525 |
+
def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
|
526 |
+
'''
|
527 |
+
units_frames: B x n_frames x n_unit
|
528 |
+
f0_frames: B x n_frames x 1
|
529 |
+
volume_frames: B x n_frames x 1
|
530 |
+
spk_id: B x 1
|
531 |
+
'''
|
532 |
+
# exciter phase
|
533 |
+
f0 = upsample(f0_frames, self.block_size)
|
534 |
+
if infer:
|
535 |
+
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
|
536 |
+
else:
|
537 |
+
x = torch.cumsum(f0 / self.sampling_rate, axis=1)
|
538 |
+
if initial_phase is not None:
|
539 |
+
x += initial_phase.to(x) / 2 / np.pi
|
540 |
+
x = x - torch.round(x)
|
541 |
+
x = x.to(f0)
|
542 |
+
|
543 |
+
phase_frames = 2 * np.pi * x[:, ::self.block_size, :]
|
544 |
+
|
545 |
+
# parameter prediction
|
546 |
+
ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)
|
547 |
+
|
548 |
+
src_filter = torch.exp(ctrls['harmonic_magnitude'] + 1.j * np.pi * ctrls['harmonic_phase'])
|
549 |
+
src_filter = torch.cat((src_filter, src_filter[:,-1:,:]), 1)
|
550 |
+
noise_filter= torch.exp(ctrls['noise_magnitude']) / 128
|
551 |
+
noise_filter = torch.cat((noise_filter, noise_filter[:,-1:,:]), 1)
|
552 |
+
|
553 |
+
# combtooth exciter signal
|
554 |
+
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
|
555 |
+
combtooth = combtooth.squeeze(-1)
|
556 |
+
combtooth_frames = F.pad(combtooth, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
|
557 |
+
combtooth_frames = combtooth_frames * self.window
|
558 |
+
combtooth_fft = torch.fft.rfft(combtooth_frames, 2 * self.block_size)
|
559 |
+
|
560 |
+
# noise exciter signal
|
561 |
+
noise = torch.rand_like(combtooth) * 2 - 1
|
562 |
+
noise_frames = F.pad(noise, (self.block_size, self.block_size)).unfold(1, 2 * self.block_size, self.block_size)
|
563 |
+
noise_frames = noise_frames * self.window
|
564 |
+
noise_fft = torch.fft.rfft(noise_frames, 2 * self.block_size)
|
565 |
+
|
566 |
+
# apply the filters
|
567 |
+
signal_fft = combtooth_fft * src_filter + noise_fft * noise_filter
|
568 |
+
|
569 |
+
# take the ifft to resynthesize audio.
|
570 |
+
signal_frames_out = torch.fft.irfft(signal_fft, 2 * self.block_size) * self.window
|
571 |
+
|
572 |
+
# overlap add
|
573 |
+
fold = torch.nn.Fold(output_size=(1, (signal_frames_out.size(1) + 1) * self.block_size), kernel_size=(1, 2 * self.block_size), stride=(1, self.block_size))
|
574 |
+
signal = fold(signal_frames_out.transpose(1, 2))[:, 0, 0, self.block_size : -self.block_size]
|
575 |
+
|
576 |
+
return signal, phase_frames, (signal, signal)
|
577 |
+
|
578 |
+
class CombSub(torch.nn.Module):
|
579 |
+
def __init__(self,
|
580 |
+
sampling_rate,
|
581 |
+
block_size,
|
582 |
+
n_mag_allpass,
|
583 |
+
n_mag_harmonic,
|
584 |
+
n_mag_noise,
|
585 |
+
n_unit=256,
|
586 |
+
n_spk=1):
|
587 |
+
super().__init__()
|
588 |
+
|
589 |
+
print(' [DDSP Model] Combtooth Subtractive Synthesiser (Old Version)')
|
590 |
+
# params
|
591 |
+
self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
|
592 |
+
self.register_buffer("block_size", torch.tensor(block_size))
|
593 |
+
#Unit2Control
|
594 |
+
split_map = {
|
595 |
+
'group_delay': n_mag_allpass,
|
596 |
+
'harmonic_magnitude': n_mag_harmonic,
|
597 |
+
'noise_magnitude': n_mag_noise
|
598 |
+
}
|
599 |
+
self.unit2ctrl = Unit2Control(n_unit, n_spk, split_map)
|
600 |
+
|
601 |
+
def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_dict=None, initial_phase=None, infer=True, **kwargs):
|
602 |
+
'''
|
603 |
+
units_frames: B x n_frames x n_unit
|
604 |
+
f0_frames: B x n_frames x 1
|
605 |
+
volume_frames: B x n_frames x 1
|
606 |
+
spk_id: B x 1
|
607 |
+
'''
|
608 |
+
# exciter phase
|
609 |
+
f0 = upsample(f0_frames, self.block_size)
|
610 |
+
if infer:
|
611 |
+
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
|
612 |
+
else:
|
613 |
+
x = torch.cumsum(f0 / self.sampling_rate, axis=1)
|
614 |
+
if initial_phase is not None:
|
615 |
+
x += initial_phase.to(x) / 2 / np.pi
|
616 |
+
x = x - torch.round(x)
|
617 |
+
x = x.to(f0)
|
618 |
+
|
619 |
+
phase_frames = 2 * np.pi * x[:, ::self.block_size, :]
|
620 |
+
|
621 |
+
# parameter prediction
|
622 |
+
ctrls = self.unit2ctrl(units_frames, f0_frames, phase_frames, volume_frames, spk_id=spk_id, spk_mix_dict=spk_mix_dict)
|
623 |
+
|
624 |
+
group_delay = np.pi * torch.tanh(ctrls['group_delay'])
|
625 |
+
src_param = torch.exp(ctrls['harmonic_magnitude'])
|
626 |
+
noise_param = torch.exp(ctrls['noise_magnitude']) / 128
|
627 |
+
|
628 |
+
# combtooth exciter signal
|
629 |
+
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3))
|
630 |
+
combtooth = combtooth.squeeze(-1)
|
631 |
+
|
632 |
+
# harmonic part filter (using dynamic-windowed LTV-FIR, with group-delay prediction)
|
633 |
+
harmonic = frequency_filter(
|
634 |
+
combtooth,
|
635 |
+
torch.exp(1.j * torch.cumsum(group_delay, axis = -1)),
|
636 |
+
hann_window = False)
|
637 |
+
harmonic = frequency_filter(
|
638 |
+
harmonic,
|
639 |
+
torch.complex(src_param, torch.zeros_like(src_param)),
|
640 |
+
hann_window = True,
|
641 |
+
half_width_frames = 1.5 * self.sampling_rate / (f0_frames + 1e-3))
|
642 |
+
|
643 |
+
# noise part filter (using constant-windowed LTV-FIR, without group-delay)
|
644 |
+
noise = torch.rand_like(harmonic) * 2 - 1
|
645 |
+
noise = frequency_filter(
|
646 |
+
noise,
|
647 |
+
torch.complex(noise_param, torch.zeros_like(noise_param)),
|
648 |
+
hann_window = True)
|
649 |
+
|
650 |
+
signal = harmonic + noise
|
651 |
+
|
652 |
+
return signal, phase_frames, (harmonic, noise)
|
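An illustrative sketch (not part of the original source) of how the pieces defined in this file fit together at inference time. The checkpoint path, device, and frame count below are placeholder assumptions; load_model() reads the config.yaml stored next to the checkpoint and returns the matching Sins / CombSub / CombSubFast instance.

import torch

model, args = load_model('exp/my-combsub/model_best.pt', device='cpu')   # hypothetical checkpoint

n_frames = 100
units  = torch.randn(1, n_frames, args.data.encoder_out_channels)  # B x n_frames x n_unit
f0     = 220.0 * torch.ones(1, n_frames, 1)                        # B x n_frames x 1, in Hz
volume = 0.1 * torch.ones(1, n_frames, 1)                          # B x n_frames x 1
spk_id = torch.LongTensor([[1]])                                    # B x 1

with torch.no_grad():
    signal, _, _ = model(units, f0, volume, spk_id=spk_id, infer=True)
# signal: B x (n_frames * block_size) waveform at args.data.sampling_rate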
DDSP-SVC/diffusion/data_loaders.py
ADDED
@@ -0,0 +1,271 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
import librosa
|
5 |
+
import torch
|
6 |
+
import random
|
7 |
+
from tqdm import tqdm
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
def traverse_dir(
|
11 |
+
root_dir,
|
12 |
+
extension,
|
13 |
+
amount=None,
|
14 |
+
str_include=None,
|
15 |
+
str_exclude=None,
|
16 |
+
is_pure=False,
|
17 |
+
is_sort=False,
|
18 |
+
is_ext=True):
|
19 |
+
|
20 |
+
file_list = []
|
21 |
+
cnt = 0
|
22 |
+
for root, _, files in os.walk(root_dir):
|
23 |
+
for file in files:
|
24 |
+
if file.endswith(extension):
|
25 |
+
# path
|
26 |
+
mix_path = os.path.join(root, file)
|
27 |
+
pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
|
28 |
+
|
29 |
+
# amount
|
30 |
+
if (amount is not None) and (cnt == amount):
|
31 |
+
if is_sort:
|
32 |
+
file_list.sort()
|
33 |
+
return file_list
|
34 |
+
|
35 |
+
# check string
|
36 |
+
if (str_include is not None) and (str_include not in pure_path):
|
37 |
+
continue
|
38 |
+
if (str_exclude is not None) and (str_exclude in pure_path):
|
39 |
+
continue
|
40 |
+
|
41 |
+
if not is_ext:
|
42 |
+
ext = pure_path.split('.')[-1]
|
43 |
+
pure_path = pure_path[:-(len(ext)+1)]
|
44 |
+
file_list.append(pure_path)
|
45 |
+
cnt += 1
|
46 |
+
if is_sort:
|
47 |
+
file_list.sort()
|
48 |
+
return file_list
|
49 |
+
|
50 |
+
|
51 |
+
def get_data_loaders(args, whole_audio=False):
|
52 |
+
data_train = AudioDataset(
|
53 |
+
args.data.train_path,
|
54 |
+
waveform_sec=args.data.duration,
|
55 |
+
hop_size=args.data.block_size,
|
56 |
+
sample_rate=args.data.sampling_rate,
|
57 |
+
load_all_data=args.train.cache_all_data,
|
58 |
+
whole_audio=whole_audio,
|
59 |
+
n_spk=args.model.n_spk,
|
60 |
+
device=args.train.cache_device,
|
61 |
+
fp16=args.train.cache_fp16,
|
62 |
+
use_aug=True)
|
63 |
+
loader_train = torch.utils.data.DataLoader(
|
64 |
+
data_train ,
|
65 |
+
batch_size=args.train.batch_size if not whole_audio else 1,
|
66 |
+
shuffle=True,
|
67 |
+
num_workers=args.train.num_workers if args.train.cache_device=='cpu' else 0,
|
68 |
+
persistent_workers=(args.train.num_workers > 0) if args.train.cache_device=='cpu' else False,
|
69 |
+
pin_memory=True if args.train.cache_device=='cpu' else False
|
70 |
+
)
|
71 |
+
data_valid = AudioDataset(
|
72 |
+
args.data.valid_path,
|
73 |
+
waveform_sec=args.data.duration,
|
74 |
+
hop_size=args.data.block_size,
|
75 |
+
sample_rate=args.data.sampling_rate,
|
76 |
+
load_all_data=args.train.cache_all_data,
|
77 |
+
whole_audio=True,
|
78 |
+
n_spk=args.model.n_spk)
|
79 |
+
loader_valid = torch.utils.data.DataLoader(
|
80 |
+
data_valid,
|
81 |
+
batch_size=1,
|
82 |
+
shuffle=False,
|
83 |
+
num_workers=0,
|
84 |
+
pin_memory=True
|
85 |
+
)
|
86 |
+
return loader_train, loader_valid
|
87 |
+
|
88 |
+
|
89 |
+
class AudioDataset(Dataset):
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
path_root,
|
93 |
+
waveform_sec,
|
94 |
+
hop_size,
|
95 |
+
sample_rate,
|
96 |
+
load_all_data=True,
|
97 |
+
whole_audio=False,
|
98 |
+
n_spk=1,
|
99 |
+
device='cpu',
|
100 |
+
fp16=False,
|
101 |
+
use_aug=False,
|
102 |
+
):
|
103 |
+
super().__init__()
|
104 |
+
|
105 |
+
self.waveform_sec = waveform_sec
|
106 |
+
self.sample_rate = sample_rate
|
107 |
+
self.hop_size = hop_size
|
108 |
+
self.path_root = path_root
|
109 |
+
self.paths = traverse_dir(
|
110 |
+
os.path.join(path_root, 'audio'),
|
111 |
+
extension='wav',
|
112 |
+
is_pure=True,
|
113 |
+
is_sort=True,
|
114 |
+
is_ext=False
|
115 |
+
)
|
116 |
+
self.whole_audio = whole_audio
|
117 |
+
self.use_aug = use_aug
|
118 |
+
self.data_buffer={}
|
119 |
+
self.pitch_aug_dict = np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item()
|
120 |
+
if load_all_data:
|
121 |
+
print('Load all the data from :', path_root)
|
122 |
+
else:
|
123 |
+
print('Load the f0, volume data from :', path_root)
|
124 |
+
for name in tqdm(self.paths, total=len(self.paths)):
|
125 |
+
path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
|
126 |
+
duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate)
|
127 |
+
|
128 |
+
path_f0 = os.path.join(self.path_root, 'f0', name) + '.npy'
|
129 |
+
f0 = np.load(path_f0)
|
130 |
+
f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device)
|
131 |
+
|
132 |
+
path_volume = os.path.join(self.path_root, 'volume', name) + '.npy'
|
133 |
+
volume = np.load(path_volume)
|
134 |
+
volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device)
|
135 |
+
|
136 |
+
path_augvol = os.path.join(self.path_root, 'aug_vol', name) + '.npy'
|
137 |
+
aug_vol = np.load(path_augvol)
|
138 |
+
aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device)
|
139 |
+
|
140 |
+
if n_spk is not None and n_spk > 1:
|
141 |
+
spk_id = int(os.path.dirname(name)) if str.isdigit(os.path.dirname(name)) else 0
|
142 |
+
if spk_id < 1 or spk_id > n_spk:
|
143 |
+
raise ValueError(' [x] Multi-speaker training error: spk_id must be a positive integer from 1 to n_spk ')
|
144 |
+
else:
|
145 |
+
spk_id = 1
|
146 |
+
spk_id = torch.LongTensor(np.array([spk_id])).to(device)
|
147 |
+
|
148 |
+
if load_all_data:
|
149 |
+
'''
|
150 |
+
audio, sr = librosa.load(path_audio, sr=self.sample_rate)
|
151 |
+
if len(audio.shape) > 1:
|
152 |
+
audio = librosa.to_mono(audio)
|
153 |
+
audio = torch.from_numpy(audio).to(device)
|
154 |
+
'''
|
155 |
+
path_mel = os.path.join(self.path_root, 'mel', name) + '.npy'
|
156 |
+
mel = np.load(path_mel)
|
157 |
+
mel = torch.from_numpy(mel).to(device)
|
158 |
+
|
159 |
+
path_augmel = os.path.join(self.path_root, 'aug_mel', name) + '.npy'
|
160 |
+
aug_mel = np.load(path_augmel)
|
161 |
+
aug_mel = torch.from_numpy(aug_mel).to(device)
|
162 |
+
|
163 |
+
path_units = os.path.join(self.path_root, 'units', name) + '.npy'
|
164 |
+
units = np.load(path_units)
|
165 |
+
units = torch.from_numpy(units).to(device)
|
166 |
+
|
167 |
+
if fp16:
|
168 |
+
mel = mel.half()
|
169 |
+
aug_mel = aug_mel.half()
|
170 |
+
units = units.half()
|
171 |
+
|
172 |
+
self.data_buffer[name] = {
|
173 |
+
'duration': duration,
|
174 |
+
'mel': mel,
|
175 |
+
'aug_mel': aug_mel,
|
176 |
+
'units': units,
|
177 |
+
'f0': f0,
|
178 |
+
'volume': volume,
|
179 |
+
'aug_vol': aug_vol,
|
180 |
+
'spk_id': spk_id
|
181 |
+
}
|
182 |
+
else:
|
183 |
+
self.data_buffer[name] = {
|
184 |
+
'duration': duration,
|
185 |
+
'f0': f0,
|
186 |
+
'volume': volume,
|
187 |
+
'aug_vol': aug_vol,
|
188 |
+
'spk_id': spk_id
|
189 |
+
}
|
190 |
+
|
191 |
+
|
192 |
+
def __getitem__(self, file_idx):
|
193 |
+
name = self.paths[file_idx]
|
194 |
+
data_buffer = self.data_buffer[name]
|
195 |
+
# check duration. if too short, then skip
|
196 |
+
if data_buffer['duration'] < (self.waveform_sec + 0.1):
|
197 |
+
return self.__getitem__( (file_idx + 1) % len(self.paths))
|
198 |
+
|
199 |
+
# get item
|
200 |
+
return self.get_data(name, data_buffer)
|
201 |
+
|
202 |
+
def get_data(self, name, data_buffer):
|
203 |
+
frame_resolution = self.hop_size / self.sample_rate
|
204 |
+
duration = data_buffer['duration']
|
205 |
+
waveform_sec = duration if self.whole_audio else self.waveform_sec
|
206 |
+
|
207 |
+
# load audio
|
208 |
+
idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
|
209 |
+
start_frame = int(idx_from / frame_resolution)
|
210 |
+
units_frame_len = int(waveform_sec / frame_resolution)
|
211 |
+
aug_flag = random.choice([True, False]) and self.use_aug
|
212 |
+
'''
|
213 |
+
audio = data_buffer.get('audio')
|
214 |
+
if audio is None:
|
215 |
+
path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
|
216 |
+
audio, sr = librosa.load(
|
217 |
+
path_audio,
|
218 |
+
sr = self.sample_rate,
|
219 |
+
offset = start_frame * frame_resolution,
|
220 |
+
duration = waveform_sec)
|
221 |
+
if len(audio.shape) > 1:
|
222 |
+
audio = librosa.to_mono(audio)
|
223 |
+
# clip audio into N seconds
|
224 |
+
audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size]
|
225 |
+
audio = torch.from_numpy(audio).float()
|
226 |
+
else:
|
227 |
+
audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size]
|
228 |
+
'''
|
229 |
+
# load mel
|
230 |
+
mel_key = 'aug_mel' if aug_flag else 'mel'
|
231 |
+
mel = data_buffer.get(mel_key)
|
232 |
+
if mel is None:
|
233 |
+
mel = os.path.join(self.path_root, mel_key, name) + '.npy'
|
234 |
+
mel = np.load(mel)
|
235 |
+
mel = mel[start_frame : start_frame + units_frame_len]
|
236 |
+
mel = torch.from_numpy(mel).float()
|
237 |
+
else:
|
238 |
+
mel = mel[start_frame : start_frame + units_frame_len]
|
239 |
+
|
240 |
+
# load units
|
241 |
+
units = data_buffer.get('units')
|
242 |
+
if units is None:
|
243 |
+
units = os.path.join(self.path_root, 'units', name) + '.npy'
|
244 |
+
units = np.load(units)
|
245 |
+
units = units[start_frame : start_frame + units_frame_len]
|
246 |
+
units = torch.from_numpy(units).float()
|
247 |
+
else:
|
248 |
+
units = units[start_frame : start_frame + units_frame_len]
|
249 |
+
|
250 |
+
# load f0
|
251 |
+
f0 = data_buffer.get('f0')
|
252 |
+
aug_shift = 0
|
253 |
+
if aug_flag:
|
254 |
+
aug_shift = self.pitch_aug_dict[name]
|
255 |
+
f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len]
|
256 |
+
|
257 |
+
# load volume
|
258 |
+
vol_key = 'aug_vol' if aug_flag else 'volume'
|
259 |
+
volume = data_buffer.get(vol_key)
|
260 |
+
volume_frames = volume[start_frame : start_frame + units_frame_len]
|
261 |
+
|
262 |
+
# load spk_id
|
263 |
+
spk_id = data_buffer.get('spk_id')
|
264 |
+
|
265 |
+
# load shift
|
266 |
+
aug_shift = torch.LongTensor(np.array([[aug_shift]]))
|
267 |
+
|
268 |
+
return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name)
|
269 |
+
|
270 |
+
def __len__(self):
|
271 |
+
return len(self.paths)
|
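For orientation only: get_data_loaders() above pulls everything it needs from a nested args object and expects the preprocessed feature folders to sit next to each audio folder. The values below are placeholder assumptions sketching the minimum shape of that config; in the repository the same fields come from a YAML file such as configs/diffusion.yaml.

from types import SimpleNamespace as NS

# directory layout read by AudioDataset (per data root):
#   audio/*.wav, f0/*.npy, volume/*.npy, aug_vol/*.npy, pitch_aug_dict.npy
#   mel/*.npy, aug_mel/*.npy, units/*.npy   (only loaded when train.cache_all_data is true)

args = NS(
    data=NS(train_path='data/train', valid_path='data/val',
            duration=2.0, block_size=512, sampling_rate=44100),
    model=NS(n_spk=1),
    train=NS(cache_all_data=False, cache_device='cpu', cache_fp16=False,
             batch_size=24, num_workers=2),
)
loader_train, loader_valid = get_data_loaders(args, whole_audio=False)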
DDSP-SVC/diffusion/diffusion.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
from collections import deque
|
2 |
+
from functools import partial
|
3 |
+
from inspect import isfunction
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import librosa.sequence
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
def exists(x):
|
13 |
+
return x is not None
|
14 |
+
|
15 |
+
|
16 |
+
def default(val, d):
|
17 |
+
if exists(val):
|
18 |
+
return val
|
19 |
+
return d() if isfunction(d) else d
|
20 |
+
|
21 |
+
|
22 |
+
def extract(a, t, x_shape):
|
23 |
+
b, *_ = t.shape
|
24 |
+
out = a.gather(-1, t)
|
25 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
26 |
+
|
27 |
+
|
28 |
+
def noise_like(shape, device, repeat=False):
|
29 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
30 |
+
noise = lambda: torch.randn(shape, device=device)
|
31 |
+
return repeat_noise() if repeat else noise()
|
32 |
+
|
33 |
+
|
34 |
+
def linear_beta_schedule(timesteps, max_beta=0.02):
|
35 |
+
"""
|
36 |
+
linear schedule
|
37 |
+
"""
|
38 |
+
betas = np.linspace(1e-4, max_beta, timesteps)
|
39 |
+
return betas
|
40 |
+
|
41 |
+
|
42 |
+
def cosine_beta_schedule(timesteps, s=0.008):
|
43 |
+
"""
|
44 |
+
cosine schedule
|
45 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
46 |
+
"""
|
47 |
+
steps = timesteps + 1
|
48 |
+
x = np.linspace(0, steps, steps)
|
49 |
+
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
50 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
51 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
52 |
+
return np.clip(betas, a_min=0, a_max=0.999)
|
53 |
+
|
54 |
+
|
55 |
+
beta_schedule = {
|
56 |
+
"cosine": cosine_beta_schedule,
|
57 |
+
"linear": linear_beta_schedule,
|
58 |
+
}
|
59 |
+
|
60 |
+
|
61 |
+
class GaussianDiffusion(nn.Module):
|
62 |
+
def __init__(self,
|
63 |
+
denoise_fn,
|
64 |
+
out_dims=128,
|
65 |
+
timesteps=1000,
|
66 |
+
k_step=1000,
|
67 |
+
max_beta=0.02,
|
68 |
+
spec_min=-12,
|
69 |
+
spec_max=2):
|
70 |
+
super().__init__()
|
71 |
+
self.denoise_fn = denoise_fn
|
72 |
+
self.out_dims = out_dims
|
73 |
+
betas = beta_schedule['linear'](timesteps, max_beta=max_beta)
|
74 |
+
|
75 |
+
alphas = 1. - betas
|
76 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
77 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
78 |
+
|
79 |
+
timesteps, = betas.shape
|
80 |
+
self.num_timesteps = int(timesteps)
|
81 |
+
self.k_step = k_step
|
82 |
+
|
83 |
+
self.noise_list = deque(maxlen=4)
|
84 |
+
|
85 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
86 |
+
|
87 |
+
self.register_buffer('betas', to_torch(betas))
|
88 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
89 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
90 |
+
|
91 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
92 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
93 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
94 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
95 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
96 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
97 |
+
|
98 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
99 |
+
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
100 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
101 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
102 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
103 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
104 |
+
self.register_buffer('posterior_mean_coef1', to_torch(
|
105 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
106 |
+
self.register_buffer('posterior_mean_coef2', to_torch(
|
107 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
108 |
+
|
109 |
+
self.register_buffer('spec_min', torch.FloatTensor([spec_min])[None, None, :out_dims])
|
110 |
+
self.register_buffer('spec_max', torch.FloatTensor([spec_max])[None, None, :out_dims])
|
111 |
+
|
112 |
+
def q_mean_variance(self, x_start, t):
|
113 |
+
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
114 |
+
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
115 |
+
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
116 |
+
return mean, variance, log_variance
|
117 |
+
|
118 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
119 |
+
return (
|
120 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
121 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
122 |
+
)
|
123 |
+
|
124 |
+
def q_posterior(self, x_start, x_t, t):
|
125 |
+
posterior_mean = (
|
126 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
127 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
128 |
+
)
|
129 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
130 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
131 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
132 |
+
|
133 |
+
def p_mean_variance(self, x, t, cond):
|
134 |
+
noise_pred = self.denoise_fn(x, t, cond=cond)
|
135 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
|
136 |
+
|
137 |
+
x_recon.clamp_(-1., 1.)
|
138 |
+
|
139 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
140 |
+
return model_mean, posterior_variance, posterior_log_variance
|
141 |
+
|
142 |
+
@torch.no_grad()
|
143 |
+
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
144 |
+
b, *_, device = *x.shape, x.device
|
145 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond)
|
146 |
+
noise = noise_like(x.shape, device, repeat_noise)
|
147 |
+
# no noise when t == 0
|
148 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
149 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
150 |
+
|
151 |
+
@torch.no_grad()
|
152 |
+
def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
|
153 |
+
"""
|
154 |
+
Use the PLMS method from
|
155 |
+
[Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
|
156 |
+
"""
|
157 |
+
|
158 |
+
def get_x_pred(x, noise_t, t):
|
159 |
+
a_t = extract(self.alphas_cumprod, t, x.shape)
|
160 |
+
a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)
|
161 |
+
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
162 |
+
|
163 |
+
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (
|
164 |
+
a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
165 |
+
x_pred = x + x_delta
|
166 |
+
|
167 |
+
return x_pred
|
168 |
+
|
169 |
+
noise_list = self.noise_list
|
170 |
+
noise_pred = self.denoise_fn(x, t, cond=cond)
|
171 |
+
|
172 |
+
if len(noise_list) == 0:
|
173 |
+
x_pred = get_x_pred(x, noise_pred, t)
|
174 |
+
noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond)
|
175 |
+
noise_pred_prime = (noise_pred + noise_pred_prev) / 2
|
176 |
+
elif len(noise_list) == 1:
|
177 |
+
noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
|
178 |
+
elif len(noise_list) == 2:
|
179 |
+
noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
|
180 |
+
else:
|
181 |
+
noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24
|
182 |
+
|
183 |
+
x_prev = get_x_pred(x, noise_pred_prime, t)
|
184 |
+
noise_list.append(noise_pred)
|
185 |
+
|
186 |
+
return x_prev
|
187 |
+
|
188 |
+
def q_sample(self, x_start, t, noise=None):
|
189 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
190 |
+
return (
|
191 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
192 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
193 |
+
)
|
194 |
+
|
195 |
+
def p_losses(self, x_start, t, cond, noise=None, loss_type='l2'):
|
196 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
197 |
+
|
198 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
199 |
+
x_recon = self.denoise_fn(x_noisy, t, cond)
|
200 |
+
|
201 |
+
if loss_type == 'l1':
|
202 |
+
loss = (noise - x_recon).abs().mean()
|
203 |
+
elif loss_type == 'l2':
|
204 |
+
loss = F.mse_loss(noise, x_recon)
|
205 |
+
else:
|
206 |
+
raise NotImplementedError()
|
207 |
+
|
208 |
+
return loss
|
209 |
+
|
210 |
+
def forward(self,
|
211 |
+
condition,
|
212 |
+
gt_spec=None,
|
213 |
+
infer=True,
|
214 |
+
infer_speedup=10,
|
215 |
+
method='dpm-solver',
|
216 |
+
k_step=300,
|
217 |
+
use_tqdm=True):
|
218 |
+
"""
|
219 |
+
conditioning diffusion, use fastspeech2 encoder output as the condition
|
220 |
+
"""
|
221 |
+
cond = condition.transpose(1, 2)
|
222 |
+
b, device = condition.shape[0], condition.device
|
223 |
+
|
224 |
+
if not infer:
|
225 |
+
spec = self.norm_spec(gt_spec)
|
226 |
+
t = torch.randint(0, self.k_step, (b,), device=device).long()
|
227 |
+
norm_spec = spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
|
228 |
+
return self.p_losses(norm_spec, t, cond=cond)
|
229 |
+
else:
|
230 |
+
shape = (cond.shape[0], 1, self.out_dims, cond.shape[2])
|
231 |
+
|
232 |
+
if gt_spec is None:
|
233 |
+
t = self.k_step
|
234 |
+
x = torch.randn(shape, device=device)
|
235 |
+
else:
|
236 |
+
t = k_step
|
237 |
+
norm_spec = self.norm_spec(gt_spec)
|
238 |
+
norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]
|
239 |
+
x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())
|
240 |
+
|
241 |
+
if method is not None and infer_speedup > 1:
|
242 |
+
if method == 'dpm-solver':
|
243 |
+
from .dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
|
244 |
+
# 1. Define the noise schedule.
|
245 |
+
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])
|
246 |
+
|
247 |
+
# 2. Convert your discrete-time `model` to the continuous-time
|
248 |
+
# noise prediction model. Here is an example for a diffusion model
|
249 |
+
# `model` with the noise prediction type ("noise") .
|
250 |
+
def my_wrapper(fn):
|
251 |
+
def wrapped(x, t, **kwargs):
|
252 |
+
ret = fn(x, t, **kwargs)
|
253 |
+
if use_tqdm:
|
254 |
+
self.bar.update(1)
|
255 |
+
return ret
|
256 |
+
|
257 |
+
return wrapped
|
258 |
+
|
259 |
+
model_fn = model_wrapper(
|
260 |
+
my_wrapper(self.denoise_fn),
|
261 |
+
noise_schedule,
|
262 |
+
model_type="noise", # or "x_start" or "v" or "score"
|
263 |
+
model_kwargs={"cond": cond}
|
264 |
+
)
|
265 |
+
|
266 |
+
# 3. Define dpm-solver and sample by singlestep DPM-Solver.
|
267 |
+
# (We recommend singlestep DPM-Solver for unconditional sampling)
|
268 |
+
# You can adjust the `steps` to balance the computation
|
269 |
+
# costs and the sample quality.
|
270 |
+
dpm_solver = DPM_Solver(model_fn, noise_schedule)
|
271 |
+
|
272 |
+
steps = t // infer_speedup
|
273 |
+
if use_tqdm:
|
274 |
+
self.bar = tqdm(desc="sample time step", total=steps)
|
275 |
+
x = dpm_solver.sample(
|
276 |
+
x,
|
277 |
+
steps=steps,
|
278 |
+
order=3,
|
279 |
+
skip_type="time_uniform",
|
280 |
+
method="singlestep",
|
281 |
+
)
|
282 |
+
if use_tqdm:
|
283 |
+
self.bar.close()
|
284 |
+
elif method == 'pndm':
|
285 |
+
self.noise_list = deque(maxlen=4)
|
286 |
+
if use_tqdm:
|
287 |
+
for i in tqdm(
|
288 |
+
reversed(range(0, t, infer_speedup)), desc='sample time step',
|
289 |
+
total=t // infer_speedup,
|
290 |
+
):
|
291 |
+
x = self.p_sample_plms(
|
292 |
+
x, torch.full((b,), i, device=device, dtype=torch.long),
|
293 |
+
infer_speedup, cond=cond
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
for i in reversed(range(0, t, infer_speedup)):
|
297 |
+
x = self.p_sample_plms(
|
298 |
+
x, torch.full((b,), i, device=device, dtype=torch.long),
|
299 |
+
infer_speedup, cond=cond
|
300 |
+
)
|
301 |
+
else:
|
302 |
+
raise NotImplementedError(method)
|
303 |
+
else:
|
304 |
+
if use_tqdm:
|
305 |
+
for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
|
306 |
+
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
307 |
+
else:
|
308 |
+
for i in reversed(range(0, t)):
|
309 |
+
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
310 |
+
x = x.squeeze(1).transpose(1, 2) # [B, T, M]
|
311 |
+
return self.denorm_spec(x)
|
312 |
+
|
313 |
+
def norm_spec(self, x):
|
314 |
+
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
315 |
+
|
316 |
+
def denorm_spec(self, x):
|
317 |
+
return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
|
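A compact sketch (not part of the original source) of the two call modes of GaussianDiffusion.forward above, shown with a trivial stand-in denoiser so the tensor shapes are visible; in the repository the real denoise_fn is supplied by the surrounding modules, so everything below is illustrative only.

import torch

class DummyDenoiser(torch.nn.Module):
    # same call pattern as the expected denoise_fn: x [B,1,M,T], t [B], cond [B,C,T]
    def forward(self, x, t, cond=None):
        return torch.zeros_like(x)

diff = GaussianDiffusion(DummyDenoiser(), out_dims=128, timesteps=1000, k_step=300)

B, T, C, M = 1, 100, 256, 128
cond   = torch.randn(B, T, C)   # frame-level condition, transposed to B x C x T inside forward()
gt_mel = torch.randn(B, T, M)   # ground-truth mel spectrogram

loss = diff(cond, gt_spec=gt_mel, infer=False)                    # training: scalar diffusion loss
mel  = diff(cond, infer=True, method='pndm', infer_speedup=10,    # inference: B x T x M mel
            use_tqdm=False)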
DDSP-SVC/diffusion/dpm_solver_pytorch.py
ADDED
@@ -0,0 +1,1201 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class NoiseScheduleVP:
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
schedule='discrete',
|
10 |
+
betas=None,
|
11 |
+
alphas_cumprod=None,
|
12 |
+
continuous_beta_0=0.1,
|
13 |
+
continuous_beta_1=20.,
|
14 |
+
):
|
15 |
+
"""Create a wrapper class for the forward SDE (VP type).
|
16 |
+
|
17 |
+
***
|
18 |
+
Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
|
19 |
+
We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
|
20 |
+
***
|
21 |
+
|
22 |
+
The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
|
23 |
+
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
|
24 |
+
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
|
25 |
+
|
26 |
+
log_alpha_t = self.marginal_log_mean_coeff(t)
|
27 |
+
sigma_t = self.marginal_std(t)
|
28 |
+
lambda_t = self.marginal_lambda(t)
|
29 |
+
|
30 |
+
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
|
31 |
+
|
32 |
+
t = self.inverse_lambda(lambda_t)
|
33 |
+
|
34 |
+
===============================================================
|
35 |
+
|
36 |
+
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
|
37 |
+
|
38 |
+
1. For discrete-time DPMs:
|
39 |
+
|
40 |
+
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
|
41 |
+
t_i = (i + 1) / N
|
42 |
+
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
|
43 |
+
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
|
47 |
+
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
|
48 |
+
|
49 |
+
Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
|
50 |
+
|
51 |
+
**Important**: Please pay special attention for the args for `alphas_cumprod`:
|
52 |
+
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
|
53 |
+
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
|
54 |
+
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
|
55 |
+
alpha_{t_n} = \sqrt{\hat{alpha_n}},
|
56 |
+
and
|
57 |
+
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
|
58 |
+
|
59 |
+
|
60 |
+
2. For continuous-time DPMs:
|
61 |
+
|
62 |
+
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
|
63 |
+
schedule are the default settings in DDPM and improved-DDPM:
|
64 |
+
|
65 |
+
Args:
|
66 |
+
beta_min: A `float` number. The smallest beta for the linear schedule.
|
67 |
+
beta_max: A `float` number. The largest beta for the linear schedule.
|
68 |
+
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
|
69 |
+
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
|
70 |
+
T: A `float` number. The ending time of the forward process.
|
71 |
+
|
72 |
+
===============================================================
|
73 |
+
|
74 |
+
Args:
|
75 |
+
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
|
76 |
+
'linear' or 'cosine' for continuous-time DPMs.
|
77 |
+
Returns:
|
78 |
+
A wrapper object of the forward SDE (VP type).
|
79 |
+
|
80 |
+
===============================================================
|
81 |
+
|
82 |
+
Example:
|
83 |
+
|
84 |
+
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
|
85 |
+
>>> ns = NoiseScheduleVP('discrete', betas=betas)
|
86 |
+
|
87 |
+
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
|
88 |
+
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
89 |
+
|
90 |
+
# For continuous-time DPMs (VPSDE), linear schedule:
|
91 |
+
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
|
92 |
+
|
93 |
+
"""
|
94 |
+
|
95 |
+
if schedule not in ['discrete', 'linear', 'cosine']:
|
96 |
+
raise ValueError(
|
97 |
+
"Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
|
98 |
+
schedule))
|
99 |
+
|
100 |
+
self.schedule = schedule
|
101 |
+
if schedule == 'discrete':
|
102 |
+
if betas is not None:
|
103 |
+
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
104 |
+
else:
|
105 |
+
assert alphas_cumprod is not None
|
106 |
+
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
107 |
+
self.total_N = len(log_alphas)
|
108 |
+
self.T = 1.
|
109 |
+
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
|
110 |
+
self.log_alpha_array = log_alphas.reshape((1, -1,))
|
111 |
+
else:
|
112 |
+
self.total_N = 1000
|
113 |
+
self.beta_0 = continuous_beta_0
|
114 |
+
self.beta_1 = continuous_beta_1
|
115 |
+
self.cosine_s = 0.008
|
116 |
+
self.cosine_beta_max = 999.
|
117 |
+
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
|
118 |
+
1. + self.cosine_s) / math.pi - self.cosine_s
|
119 |
+
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
|
120 |
+
self.schedule = schedule
|
121 |
+
if schedule == 'cosine':
|
122 |
+
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
|
123 |
+
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
|
124 |
+
self.T = 0.9946
|
125 |
+
else:
|
126 |
+
self.T = 1.
|
127 |
+
|
128 |
+
def marginal_log_mean_coeff(self, t):
|
129 |
+
"""
|
130 |
+
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
131 |
+
"""
|
132 |
+
if self.schedule == 'discrete':
|
133 |
+
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
|
134 |
+
self.log_alpha_array.to(t.device)).reshape((-1))
|
135 |
+
elif self.schedule == 'linear':
|
136 |
+
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
137 |
+
elif self.schedule == 'cosine':
|
138 |
+
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
139 |
+
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
140 |
+
return log_alpha_t
|
141 |
+
|
142 |
+
def marginal_alpha(self, t):
|
143 |
+
"""
|
144 |
+
Compute alpha_t of a given continuous-time label t in [0, T].
|
145 |
+
"""
|
146 |
+
return torch.exp(self.marginal_log_mean_coeff(t))
|
147 |
+
|
148 |
+
def marginal_std(self, t):
|
149 |
+
"""
|
150 |
+
Compute sigma_t of a given continuous-time label t in [0, T].
|
151 |
+
"""
|
152 |
+
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
153 |
+
|
154 |
+
def marginal_lambda(self, t):
|
155 |
+
"""
|
156 |
+
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
157 |
+
"""
|
158 |
+
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
159 |
+
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
160 |
+
return log_mean_coeff - log_std
|
161 |
+
|
162 |
+
def inverse_lambda(self, lamb):
|
163 |
+
"""
|
164 |
+
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
165 |
+
"""
|
166 |
+
if self.schedule == 'linear':
|
167 |
+
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
168 |
+
Delta = self.beta_0 ** 2 + tmp
|
169 |
+
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
170 |
+
elif self.schedule == 'discrete':
|
171 |
+
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
172 |
+
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
|
173 |
+
torch.flip(self.t_array.to(lamb.device), [1]))
|
174 |
+
return t.reshape((-1,))
|
175 |
+
else:
|
176 |
+
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
177 |
+
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
|
178 |
+
1. + self.cosine_s) / math.pi - self.cosine_s
|
179 |
+
t = t_fn(log_alpha)
|
180 |
+
return t
|
181 |
+
|
182 |
+
|
183 |
+
def model_wrapper(
|
184 |
+
model,
|
185 |
+
noise_schedule,
|
186 |
+
model_type="noise",
|
187 |
+
model_kwargs={},
|
188 |
+
guidance_type="uncond",
|
189 |
+
condition=None,
|
190 |
+
unconditional_condition=None,
|
191 |
+
guidance_scale=1.,
|
192 |
+
classifier_fn=None,
|
193 |
+
classifier_kwargs={},
|
194 |
+
):
|
195 |
+
"""Create a wrapper function for the noise prediction model.
|
196 |
+
|
197 |
+
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
198 |
+
first wrap the model function into a noise prediction model that accepts continuous time as input.
|
199 |
+
|
200 |
+
We support four types of diffusion models by setting `model_type`:
|
201 |
+
|
202 |
+
1. "noise": noise prediction model. (Trained by predicting noise).
|
203 |
+
|
204 |
+
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
205 |
+
|
206 |
+
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
207 |
+
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
208 |
+
|
209 |
+
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
210 |
+
arXiv preprint arXiv:2202.00512 (2022).
|
211 |
+
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
212 |
+
arXiv preprint arXiv:2210.02303 (2022).
|
213 |
+
|
214 |
+
4. "score": marginal score function. (Trained by denoising score matching).
|
215 |
+
Note that the score function and the noise prediction model follow a simple relationship:
|
216 |
+
```
|
217 |
+
noise(x_t, t) = -sigma_t * score(x_t, t)
|
218 |
+
```
|
219 |
+
|
220 |
+
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
221 |
+
1. "uncond": unconditional sampling by DPMs.
|
222 |
+
The input `model` has the following format:
|
223 |
+
``
|
224 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
225 |
+
``
|
226 |
+
|
227 |
+
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
228 |
+
The input `model` has the following format:
|
229 |
+
``
|
230 |
+
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
231 |
+
``
|
232 |
+
|
233 |
+
The input `classifier_fn` has the following format:
|
234 |
+
``
|
235 |
+
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
236 |
+
``
|
237 |
+
|
238 |
+
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
239 |
+
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
240 |
+
|
241 |
+
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
242 |
+
The input `model` has the following format:
|
243 |
+
``
|
244 |
+
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
245 |
+
``
|
246 |
+
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
247 |
+
|
248 |
+
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
249 |
+
arXiv preprint arXiv:2207.12598 (2022).
|
250 |
+
|
251 |
+
|
252 |
+
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
253 |
+
or continuous-time labels (i.e. epsilon to T).
|
254 |
+
|
255 |
+
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
|
256 |
+
``
|
257 |
+
def model_fn(x, t_continuous) -> noise:
|
258 |
+
t_input = get_model_input_time(t_continuous)
|
259 |
+
return noise_pred(model, x, t_input, **model_kwargs)
|
260 |
+
``
|
261 |
+
where `t_continuous` is the continuous time label (i.e. epsilon to T). We use this `model_fn` for DPM-Solver.
|
262 |
+
|
263 |
+
===============================================================
|
264 |
+
|
265 |
+
Args:
|
266 |
+
model: A diffusion model with the corresponding format described above.
|
267 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
268 |
+
model_type: A `str`. The parameterization type of the diffusion model.
|
269 |
+
"noise" or "x_start" or "v" or "score".
|
270 |
+
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
271 |
+
guidance_type: A `str`. The type of the guidance for sampling.
|
272 |
+
"uncond" or "classifier" or "classifier-free".
|
273 |
+
condition: A pytorch tensor. The condition for the guided sampling.
|
274 |
+
Only used for "classifier" or "classifier-free" guidance type.
|
275 |
+
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
276 |
+
Only used for "classifier-free" guidance type.
|
277 |
+
guidance_scale: A `float`. The scale for the guided sampling.
|
278 |
+
classifier_fn: A classifier function. Only used for the classifier guidance.
|
279 |
+
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
280 |
+
Returns:
|
281 |
+
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
282 |
+
"""
|
283 |
+
|
284 |
+
def get_model_input_time(t_continuous):
|
285 |
+
"""
|
286 |
+
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
287 |
+
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
288 |
+
For continuous-time DPMs, we just use `t_continuous`.
|
289 |
+
"""
|
290 |
+
if noise_schedule.schedule == 'discrete':
|
291 |
+
return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
|
292 |
+
else:
|
293 |
+
return t_continuous
|
294 |
+
|
295 |
+
def noise_pred_fn(x, t_continuous, cond=None):
|
296 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
297 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
298 |
+
t_input = get_model_input_time(t_continuous)
|
299 |
+
if cond is None:
|
300 |
+
output = model(x, t_input, **model_kwargs)
|
301 |
+
else:
|
302 |
+
output = model(x, t_input, cond, **model_kwargs)
|
303 |
+
if model_type == "noise":
|
304 |
+
return output
|
305 |
+
elif model_type == "x_start":
|
306 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
307 |
+
dims = x.dim()
|
308 |
+
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
|
309 |
+
elif model_type == "v":
|
310 |
+
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
311 |
+
dims = x.dim()
|
312 |
+
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
|
313 |
+
elif model_type == "score":
|
314 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
315 |
+
dims = x.dim()
|
316 |
+
return -expand_dims(sigma_t, dims) * output
|
317 |
+
|
318 |
+
def cond_grad_fn(x, t_input):
|
319 |
+
"""
|
320 |
+
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
321 |
+
"""
|
322 |
+
with torch.enable_grad():
|
323 |
+
x_in = x.detach().requires_grad_(True)
|
324 |
+
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
325 |
+
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
326 |
+
|
327 |
+
def model_fn(x, t_continuous):
|
328 |
+
"""
|
329 |
+
The noise prediction model function that is used for DPM-Solver.
|
330 |
+
"""
|
331 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
332 |
+
t_continuous = t_continuous.expand((x.shape[0]))
|
333 |
+
if guidance_type == "uncond":
|
334 |
+
return noise_pred_fn(x, t_continuous)
|
335 |
+
elif guidance_type == "classifier":
|
336 |
+
assert classifier_fn is not None
|
337 |
+
t_input = get_model_input_time(t_continuous)
|
338 |
+
cond_grad = cond_grad_fn(x, t_input)
|
339 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
340 |
+
noise = noise_pred_fn(x, t_continuous)
|
341 |
+
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
|
342 |
+
elif guidance_type == "classifier-free":
|
343 |
+
if guidance_scale == 1. or unconditional_condition is None:
|
344 |
+
return noise_pred_fn(x, t_continuous, cond=condition)
|
345 |
+
else:
|
346 |
+
x_in = torch.cat([x] * 2)
|
347 |
+
t_in = torch.cat([t_continuous] * 2)
|
348 |
+
c_in = torch.cat([unconditional_condition, condition])
|
349 |
+
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
350 |
+
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
351 |
+
|
352 |
+
assert model_type in ["noise", "x_start", "v"]
|
353 |
+
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
354 |
+
return model_fn
|
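For illustration, with a hypothetical conditional noise-prediction network `net(x, t_input, cond)` and the `ns` schedule from the earlier sketch, classifier-free guidance would be wired up roughly as follows (all variable names here are placeholders, not values fixed by this file):

>>> model_fn = model_wrapper(
...     net, ns,
...     model_type="noise",
...     guidance_type="classifier-free",
...     condition=cond,                      # conditioning features used during training
...     unconditional_condition=uncond,      # the "empty" condition the model was also trained on
...     guidance_scale=2.0,
... )
>>> # model_fn(x, t_continuous) now returns the guided noise prediction consumed by DPM_Solver.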
355 |
+
|
356 |
+
|
357 |
+
class DPM_Solver:
|
358 |
+
def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
|
359 |
+
"""Construct a DPM-Solver.
|
360 |
+
|
361 |
+
We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
|
362 |
+
If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
|
363 |
+
If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
|
364 |
+
In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
|
365 |
+
The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
|
366 |
+
|
367 |
+
Args:
|
368 |
+
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
|
369 |
+
``
|
370 |
+
def model_fn(x, t_continuous):
|
371 |
+
return noise
|
372 |
+
``
|
373 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
374 |
+
predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
|
375 |
+
thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
|
376 |
+
max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
|
377 |
+
|
378 |
+
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
|
379 |
+
"""
|
380 |
+
self.model = model_fn
|
381 |
+
self.noise_schedule = noise_schedule
|
382 |
+
self.predict_x0 = predict_x0
|
383 |
+
self.thresholding = thresholding
|
384 |
+
self.max_val = max_val
|
385 |
+
|
386 |
+
def noise_prediction_fn(self, x, t):
|
387 |
+
"""
|
388 |
+
Return the noise prediction model.
|
389 |
+
"""
|
390 |
+
return self.model(x, t)
|
391 |
+
|
392 |
+
def data_prediction_fn(self, x, t):
|
393 |
+
"""
|
394 |
+
Return the data prediction model (with thresholding).
|
395 |
+
"""
|
396 |
+
noise = self.noise_prediction_fn(x, t)
|
397 |
+
dims = x.dim()
|
398 |
+
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
399 |
+
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
400 |
+
if self.thresholding:
|
401 |
+
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
402 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
403 |
+
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
404 |
+
x0 = torch.clamp(x0, -s, s) / s
|
405 |
+
return x0
|
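The thresholding branch above is the "dynamic thresholding" of Imagen: per sample, the p = 0.995 quantile of |x0| is taken as a clipping bound s (never smaller than `max_val`), and x0 is clipped to [-s, s] and rescaled by s. A standalone numerical sketch of the same operation, with purely illustrative shapes:

>>> x0 = 3.0 * torch.randn(4, 1, 128, 100)                           # hypothetical over-saturated x0 prediction
>>> s = torch.quantile(x0.abs().reshape(x0.shape[0], -1), 0.995, dim=1)
>>> s = torch.maximum(s, torch.ones_like(s)).view(-1, 1, 1, 1)
>>> x0_clipped = torch.clamp(x0, -s, s) / s                          # values now lie in [-1, 1]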
406 |
+
|
407 |
+
def model_fn(self, x, t):
|
408 |
+
"""
|
409 |
+
Convert the model to the noise prediction model or the data prediction model.
|
410 |
+
"""
|
411 |
+
if self.predict_x0:
|
412 |
+
return self.data_prediction_fn(x, t)
|
413 |
+
else:
|
414 |
+
return self.noise_prediction_fn(x, t)
|
415 |
+
|
416 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
417 |
+
"""Compute the intermediate time steps for sampling.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
421 |
+
- 'logSNR': uniform logSNR for the time steps.
|
422 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
|
423 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
|
424 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
425 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
426 |
+
N: A `int`. The total number of the spacing of the time steps.
|
427 |
+
device: A torch device.
|
428 |
+
Returns:
|
429 |
+
A pytorch tensor of the time steps, with the shape (N + 1,).
|
430 |
+
"""
|
431 |
+
if skip_type == 'logSNR':
|
432 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
433 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
434 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
435 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
436 |
+
elif skip_type == 'time_uniform':
|
437 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
438 |
+
elif skip_type == 'time_quadratic':
|
439 |
+
t_order = 2
|
440 |
+
t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
|
441 |
+
return t
|
442 |
+
else:
|
443 |
+
raise ValueError(
|
444 |
+
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
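For intuition: 'time_uniform' spaces the N + 1 steps evenly in t, 'time_quadratic' spaces them evenly in sqrt(t) (so the evaluations cluster near t_0), and 'logSNR' spaces them evenly in the half-logSNR lambda_t.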
445 |
+
|
446 |
+
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
447 |
+
"""
|
448 |
+
Get the order of each step for sampling by the singlestep DPM-Solver.
|
449 |
+
|
450 |
+
We combine DPM-Solver-1, 2 and 3 to use up all the function evaluations; this combination is named "DPM-Solver-fast".
|
451 |
+
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
|
452 |
+
- If order == 1:
|
453 |
+
We take `steps` of DPM-Solver-1 (i.e. DDIM).
|
454 |
+
- If order == 2:
|
455 |
+
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
|
456 |
+
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
|
457 |
+
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
458 |
+
- If order == 3:
|
459 |
+
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
|
460 |
+
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
461 |
+
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
|
462 |
+
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
|
463 |
+
|
464 |
+
============================================
|
465 |
+
Args:
|
466 |
+
order: A `int`. The max order for the solver (2 or 3).
|
467 |
+
steps: A `int`. The total number of function evaluations (NFE).
|
468 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
469 |
+
- 'logSNR': uniform logSNR for the time steps.
|
470 |
+
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
|
471 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
|
472 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
473 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
474 |
+
device: A torch device.
|
475 |
+
Returns:
|
476 |
+
orders: A list of the solver order of each step.
|
477 |
+
"""
|
478 |
+
if order == 3:
|
479 |
+
K = steps // 3 + 1
|
480 |
+
if steps % 3 == 0:
|
481 |
+
orders = [3, ] * (K - 2) + [2, 1]
|
482 |
+
elif steps % 3 == 1:
|
483 |
+
orders = [3, ] * (K - 1) + [1]
|
484 |
+
else:
|
485 |
+
orders = [3, ] * (K - 1) + [2]
|
486 |
+
elif order == 2:
|
487 |
+
if steps % 2 == 0:
|
488 |
+
K = steps // 2
|
489 |
+
orders = [2, ] * K
|
490 |
+
else:
|
491 |
+
K = steps // 2 + 1
|
492 |
+
orders = [2, ] * (K - 1) + [1]
|
493 |
+
elif order == 1:
|
494 |
+
K = 1
|
495 |
+
orders = [1, ] * steps
|
496 |
+
else:
|
497 |
+
raise ValueError("'order' must be '1' or '2' or '3'.")
|
498 |
+
if skip_type == 'logSNR':
|
499 |
+
# To reproduce the results in DPM-Solver paper
|
500 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
501 |
+
else:
|
502 |
+
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
|
503 |
+
torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
|
504 |
+
return timesteps_outer, orders
|
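A concrete instance of the scheme described above: for steps = 20 and order = 3, steps % 3 == 2, so K = 7 and the returned orders are six DPM-Solver-3 steps followed by one DPM-Solver-2 step (6 * 3 + 2 = 20 function evaluations); for steps = 10 and order = 2 the result is simply five DPM-Solver-2 steps.

>>> # steps=20, order=3  ->  orders == [3, 3, 3, 3, 3, 3, 2], K == 7
>>> # steps=10, order=2  ->  orders == [2, 2, 2, 2, 2],       K == 5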
505 |
+
|
506 |
+
def denoise_fn(self, x, s):
|
507 |
+
"""
|
508 |
+
Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
|
509 |
+
"""
|
510 |
+
return self.data_prediction_fn(x, s)
|
511 |
+
|
512 |
+
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
|
513 |
+
"""
|
514 |
+
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
|
515 |
+
|
516 |
+
Args:
|
517 |
+
x: A pytorch tensor. The initial value at time `s`.
|
518 |
+
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
|
519 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
520 |
+
model_s: A pytorch tensor. The model function evaluated at time `s`.
|
521 |
+
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
|
522 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s`.
|
523 |
+
Returns:
|
524 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
525 |
+
"""
|
526 |
+
ns = self.noise_schedule
|
527 |
+
dims = x.dim()
|
528 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
529 |
+
h = lambda_t - lambda_s
|
530 |
+
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
|
531 |
+
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
|
532 |
+
alpha_t = torch.exp(log_alpha_t)
|
533 |
+
|
534 |
+
if self.predict_x0:
|
535 |
+
phi_1 = torch.expm1(-h)
|
536 |
+
if model_s is None:
|
537 |
+
model_s = self.model_fn(x, s)
|
538 |
+
x_t = (
|
539 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
540 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
541 |
+
)
|
542 |
+
if return_intermediate:
|
543 |
+
return x_t, {'model_s': model_s}
|
544 |
+
else:
|
545 |
+
return x_t
|
546 |
+
else:
|
547 |
+
phi_1 = torch.expm1(h)
|
548 |
+
if model_s is None:
|
549 |
+
model_s = self.model_fn(x, s)
|
550 |
+
x_t = (
|
551 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
552 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
553 |
+
)
|
554 |
+
if return_intermediate:
|
555 |
+
return x_t, {'model_s': model_s}
|
556 |
+
else:
|
557 |
+
return x_t
|
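In formula form, with h = lambda_t - lambda_s, the two branches above are the first-order exponential-integrator updates

    x_t = (alpha_t / alpha_s) * x - sigma_t * (exp(h) - 1) * model_s        (noise prediction, DPM-Solver)
    x_t = (sigma_t / sigma_s) * x - alpha_t * (exp(-h) - 1) * model_s       (data prediction, DPM-Solver++)

which coincide with a DDIM step written in the (alpha_t, sigma_t, lambda_t) parameterization.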
558 |
+
|
559 |
+
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
|
560 |
+
solver_type='dpm_solver'):
|
561 |
+
"""
|
562 |
+
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
|
563 |
+
|
564 |
+
Args:
|
565 |
+
x: A pytorch tensor. The initial value at time `s`.
|
566 |
+
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
|
567 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
568 |
+
r1: A `float`. The hyperparameter of the second-order solver.
|
569 |
+
model_s: A pytorch tensor. The model function evaluated at time `s`.
|
570 |
+
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
|
571 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
|
572 |
+
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
|
573 |
+
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
|
574 |
+
Returns:
|
575 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
576 |
+
"""
|
577 |
+
if solver_type not in ['dpm_solver', 'taylor']:
|
578 |
+
raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
|
579 |
+
if r1 is None:
|
580 |
+
r1 = 0.5
|
581 |
+
ns = self.noise_schedule
|
582 |
+
dims = x.dim()
|
583 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
584 |
+
h = lambda_t - lambda_s
|
585 |
+
lambda_s1 = lambda_s + r1 * h
|
586 |
+
s1 = ns.inverse_lambda(lambda_s1)
|
587 |
+
log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
|
588 |
+
s1), ns.marginal_log_mean_coeff(t)
|
589 |
+
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
|
590 |
+
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
|
591 |
+
|
592 |
+
if self.predict_x0:
|
593 |
+
phi_11 = torch.expm1(-r1 * h)
|
594 |
+
phi_1 = torch.expm1(-h)
|
595 |
+
|
596 |
+
if model_s is None:
|
597 |
+
model_s = self.model_fn(x, s)
|
598 |
+
x_s1 = (
|
599 |
+
expand_dims(sigma_s1 / sigma_s, dims) * x
|
600 |
+
- expand_dims(alpha_s1 * phi_11, dims) * model_s
|
601 |
+
)
|
602 |
+
model_s1 = self.model_fn(x_s1, s1)
|
603 |
+
if solver_type == 'dpm_solver':
|
604 |
+
x_t = (
|
605 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
606 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
607 |
+
- (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
|
608 |
+
)
|
609 |
+
elif solver_type == 'taylor':
|
610 |
+
x_t = (
|
611 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
612 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
613 |
+
+ (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
|
614 |
+
model_s1 - model_s)
|
615 |
+
)
|
616 |
+
else:
|
617 |
+
phi_11 = torch.expm1(r1 * h)
|
618 |
+
phi_1 = torch.expm1(h)
|
619 |
+
|
620 |
+
if model_s is None:
|
621 |
+
model_s = self.model_fn(x, s)
|
622 |
+
x_s1 = (
|
623 |
+
expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
|
624 |
+
- expand_dims(sigma_s1 * phi_11, dims) * model_s
|
625 |
+
)
|
626 |
+
model_s1 = self.model_fn(x_s1, s1)
|
627 |
+
if solver_type == 'dpm_solver':
|
628 |
+
x_t = (
|
629 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
630 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
631 |
+
- (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
|
632 |
+
)
|
633 |
+
elif solver_type == 'taylor':
|
634 |
+
x_t = (
|
635 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
636 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
637 |
+
- (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
|
638 |
+
)
|
639 |
+
if return_intermediate:
|
640 |
+
return x_t, {'model_s': model_s, 'model_s1': model_s1}
|
641 |
+
else:
|
642 |
+
return x_t
|
643 |
+
|
644 |
+
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
|
645 |
+
return_intermediate=False, solver_type='dpm_solver'):
|
646 |
+
"""
|
647 |
+
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
|
648 |
+
|
649 |
+
Args:
|
650 |
+
x: A pytorch tensor. The initial value at time `s`.
|
651 |
+
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
|
652 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
653 |
+
r1: A `float`. The hyperparameter of the third-order solver.
|
654 |
+
r2: A `float`. The hyperparameter of the third-order solver.
|
655 |
+
model_s: A pytorch tensor. The model function evaluated at time `s`.
|
656 |
+
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
|
657 |
+
model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
|
658 |
+
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
|
659 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
|
660 |
+
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
|
661 |
+
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
|
662 |
+
Returns:
|
663 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
664 |
+
"""
|
665 |
+
if solver_type not in ['dpm_solver', 'taylor']:
|
666 |
+
raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
|
667 |
+
if r1 is None:
|
668 |
+
r1 = 1. / 3.
|
669 |
+
if r2 is None:
|
670 |
+
r2 = 2. / 3.
|
671 |
+
ns = self.noise_schedule
|
672 |
+
dims = x.dim()
|
673 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
674 |
+
h = lambda_t - lambda_s
|
675 |
+
lambda_s1 = lambda_s + r1 * h
|
676 |
+
lambda_s2 = lambda_s + r2 * h
|
677 |
+
s1 = ns.inverse_lambda(lambda_s1)
|
678 |
+
s2 = ns.inverse_lambda(lambda_s2)
|
679 |
+
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
|
680 |
+
s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
|
681 |
+
sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
|
682 |
+
s2), ns.marginal_std(t)
|
683 |
+
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
|
684 |
+
|
685 |
+
if self.predict_x0:
|
686 |
+
phi_11 = torch.expm1(-r1 * h)
|
687 |
+
phi_12 = torch.expm1(-r2 * h)
|
688 |
+
phi_1 = torch.expm1(-h)
|
689 |
+
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
|
690 |
+
phi_2 = phi_1 / h + 1.
|
691 |
+
phi_3 = phi_2 / h - 0.5
|
692 |
+
|
693 |
+
if model_s is None:
|
694 |
+
model_s = self.model_fn(x, s)
|
695 |
+
if model_s1 is None:
|
696 |
+
x_s1 = (
|
697 |
+
expand_dims(sigma_s1 / sigma_s, dims) * x
|
698 |
+
- expand_dims(alpha_s1 * phi_11, dims) * model_s
|
699 |
+
)
|
700 |
+
model_s1 = self.model_fn(x_s1, s1)
|
701 |
+
x_s2 = (
|
702 |
+
expand_dims(sigma_s2 / sigma_s, dims) * x
|
703 |
+
- expand_dims(alpha_s2 * phi_12, dims) * model_s
|
704 |
+
+ r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
|
705 |
+
)
|
706 |
+
model_s2 = self.model_fn(x_s2, s2)
|
707 |
+
if solver_type == 'dpm_solver':
|
708 |
+
x_t = (
|
709 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
710 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
711 |
+
+ (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
|
712 |
+
)
|
713 |
+
elif solver_type == 'taylor':
|
714 |
+
D1_0 = (1. / r1) * (model_s1 - model_s)
|
715 |
+
D1_1 = (1. / r2) * (model_s2 - model_s)
|
716 |
+
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
717 |
+
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
718 |
+
x_t = (
|
719 |
+
expand_dims(sigma_t / sigma_s, dims) * x
|
720 |
+
- expand_dims(alpha_t * phi_1, dims) * model_s
|
721 |
+
+ expand_dims(alpha_t * phi_2, dims) * D1
|
722 |
+
- expand_dims(alpha_t * phi_3, dims) * D2
|
723 |
+
)
|
724 |
+
else:
|
725 |
+
phi_11 = torch.expm1(r1 * h)
|
726 |
+
phi_12 = torch.expm1(r2 * h)
|
727 |
+
phi_1 = torch.expm1(h)
|
728 |
+
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
|
729 |
+
phi_2 = phi_1 / h - 1.
|
730 |
+
phi_3 = phi_2 / h - 0.5
|
731 |
+
|
732 |
+
if model_s is None:
|
733 |
+
model_s = self.model_fn(x, s)
|
734 |
+
if model_s1 is None:
|
735 |
+
x_s1 = (
|
736 |
+
expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
|
737 |
+
- expand_dims(sigma_s1 * phi_11, dims) * model_s
|
738 |
+
)
|
739 |
+
model_s1 = self.model_fn(x_s1, s1)
|
740 |
+
x_s2 = (
|
741 |
+
expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
|
742 |
+
- expand_dims(sigma_s2 * phi_12, dims) * model_s
|
743 |
+
- r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
|
744 |
+
)
|
745 |
+
model_s2 = self.model_fn(x_s2, s2)
|
746 |
+
if solver_type == 'dpm_solver':
|
747 |
+
x_t = (
|
748 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
749 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
750 |
+
- (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
|
751 |
+
)
|
752 |
+
elif solver_type == 'taylor':
|
753 |
+
D1_0 = (1. / r1) * (model_s1 - model_s)
|
754 |
+
D1_1 = (1. / r2) * (model_s2 - model_s)
|
755 |
+
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
756 |
+
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
757 |
+
x_t = (
|
758 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
|
759 |
+
- expand_dims(sigma_t * phi_1, dims) * model_s
|
760 |
+
- expand_dims(sigma_t * phi_2, dims) * D1
|
761 |
+
- expand_dims(sigma_t * phi_3, dims) * D2
|
762 |
+
)
|
763 |
+
|
764 |
+
if return_intermediate:
|
765 |
+
return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
|
766 |
+
else:
|
767 |
+
return x_t
|
768 |
+
|
769 |
+
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
|
770 |
+
"""
|
771 |
+
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
|
772 |
+
|
773 |
+
Args:
|
774 |
+
x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
|
775 |
+
model_prev_list: A list of pytorch tensor. The previous computed model values.
|
776 |
+
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
|
777 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
778 |
+
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
|
779 |
+
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
|
780 |
+
Returns:
|
781 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
782 |
+
"""
|
783 |
+
if solver_type not in ['dpm_solver', 'taylor']:
|
784 |
+
raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
|
785 |
+
ns = self.noise_schedule
|
786 |
+
dims = x.dim()
|
787 |
+
model_prev_1, model_prev_0 = model_prev_list
|
788 |
+
t_prev_1, t_prev_0 = t_prev_list
|
789 |
+
lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
|
790 |
+
t_prev_0), ns.marginal_lambda(t)
|
791 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
792 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
793 |
+
alpha_t = torch.exp(log_alpha_t)
|
794 |
+
|
795 |
+
h_0 = lambda_prev_0 - lambda_prev_1
|
796 |
+
h = lambda_t - lambda_prev_0
|
797 |
+
r0 = h_0 / h
|
798 |
+
D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
|
799 |
+
if self.predict_x0:
|
800 |
+
if solver_type == 'dpm_solver':
|
801 |
+
x_t = (
|
802 |
+
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
803 |
+
- expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
|
804 |
+
- 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
|
805 |
+
)
|
806 |
+
elif solver_type == 'taylor':
|
807 |
+
x_t = (
|
808 |
+
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
809 |
+
- expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
|
810 |
+
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
|
811 |
+
)
|
812 |
+
else:
|
813 |
+
if solver_type == 'dpm_solver':
|
814 |
+
x_t = (
|
815 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
816 |
+
- expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
|
817 |
+
- 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
|
818 |
+
)
|
819 |
+
elif solver_type == 'taylor':
|
820 |
+
x_t = (
|
821 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
822 |
+
- expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
|
823 |
+
- expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
|
824 |
+
)
|
825 |
+
return x_t
|
826 |
+
|
827 |
+
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
|
828 |
+
"""
|
829 |
+
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
|
830 |
+
|
831 |
+
Args:
|
832 |
+
x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
|
833 |
+
model_prev_list: A list of pytorch tensor. The previous computed model values.
|
834 |
+
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
|
835 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
836 |
+
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
|
837 |
+
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
|
838 |
+
Returns:
|
839 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
840 |
+
"""
|
841 |
+
ns = self.noise_schedule
|
842 |
+
dims = x.dim()
|
843 |
+
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
|
844 |
+
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
|
845 |
+
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
|
846 |
+
t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
|
847 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
848 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
849 |
+
alpha_t = torch.exp(log_alpha_t)
|
850 |
+
|
851 |
+
h_1 = lambda_prev_1 - lambda_prev_2
|
852 |
+
h_0 = lambda_prev_0 - lambda_prev_1
|
853 |
+
h = lambda_t - lambda_prev_0
|
854 |
+
r0, r1 = h_0 / h, h_1 / h
|
855 |
+
D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
|
856 |
+
D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
|
857 |
+
D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
|
858 |
+
D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
|
859 |
+
if self.predict_x0:
|
860 |
+
x_t = (
|
861 |
+
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
862 |
+
- expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
|
863 |
+
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
|
864 |
+
- expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
|
865 |
+
)
|
866 |
+
else:
|
867 |
+
x_t = (
|
868 |
+
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
869 |
+
- expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
|
870 |
+
- expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
|
871 |
+
- expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
|
872 |
+
)
|
873 |
+
return x_t
|
874 |
+
|
875 |
+
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
|
876 |
+
r2=None):
|
877 |
+
"""
|
878 |
+
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
|
879 |
+
|
880 |
+
Args:
|
881 |
+
x: A pytorch tensor. The initial value at time `s`.
|
882 |
+
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
|
883 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
884 |
+
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
|
885 |
+
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
|
886 |
+
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
|
887 |
+
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
|
888 |
+
r1: A `float`. The hyperparameter of the second-order or third-order solver.
|
889 |
+
r2: A `float`. The hyperparameter of the third-order solver.
|
890 |
+
Returns:
|
891 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
892 |
+
"""
|
893 |
+
if order == 1:
|
894 |
+
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
|
895 |
+
elif order == 2:
|
896 |
+
return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
|
897 |
+
solver_type=solver_type, r1=r1)
|
898 |
+
elif order == 3:
|
899 |
+
return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
|
900 |
+
solver_type=solver_type, r1=r1, r2=r2)
|
901 |
+
else:
|
902 |
+
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
903 |
+
|
904 |
+
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
|
905 |
+
"""
|
906 |
+
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
|
907 |
+
|
908 |
+
Args:
|
909 |
+
x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
|
910 |
+
model_prev_list: A list of pytorch tensor. The previous computed model values.
|
911 |
+
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
|
912 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
913 |
+
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
|
914 |
+
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
|
915 |
+
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
|
916 |
+
Returns:
|
917 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
918 |
+
"""
|
919 |
+
if order == 1:
|
920 |
+
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
|
921 |
+
elif order == 2:
|
922 |
+
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
|
923 |
+
elif order == 3:
|
924 |
+
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
|
925 |
+
else:
|
926 |
+
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
927 |
+
|
928 |
+
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
|
929 |
+
solver_type='dpm_solver'):
|
930 |
+
"""
|
931 |
+
The adaptive step size solver based on singlestep DPM-Solver.
|
932 |
+
|
933 |
+
Args:
|
934 |
+
x: A pytorch tensor. The initial value at time `t_T`.
|
935 |
+
order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
|
936 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
937 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
938 |
+
h_init: A `float`. The initial step size (for logSNR).
|
939 |
+
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
|
940 |
+
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
|
941 |
+
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
|
942 |
+
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
|
943 |
+
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
|
944 |
+
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
|
945 |
+
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
|
946 |
+
Returns:
|
947 |
+
x_0: A pytorch tensor. The approximated solution at time `t_0`.
|
948 |
+
|
949 |
+
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
|
950 |
+
"""
|
951 |
+
ns = self.noise_schedule
|
952 |
+
s = t_T * torch.ones((x.shape[0],)).to(x)
|
953 |
+
lambda_s = ns.marginal_lambda(s)
|
954 |
+
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
|
955 |
+
h = h_init * torch.ones_like(s).to(x)
|
956 |
+
x_prev = x
|
957 |
+
nfe = 0
|
958 |
+
if order == 2:
|
959 |
+
r1 = 0.5
|
960 |
+
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
|
961 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
962 |
+
solver_type=solver_type,
|
963 |
+
**kwargs)
|
964 |
+
elif order == 3:
|
965 |
+
r1, r2 = 1. / 3., 2. / 3.
|
966 |
+
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
|
967 |
+
return_intermediate=True,
|
968 |
+
solver_type=solver_type)
|
969 |
+
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
|
970 |
+
solver_type=solver_type,
|
971 |
+
**kwargs)
|
972 |
+
else:
|
973 |
+
raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
|
974 |
+
while torch.abs((s - t_0)).mean() > t_err:
|
975 |
+
t = ns.inverse_lambda(lambda_s + h)
|
976 |
+
x_lower, lower_noise_kwargs = lower_update(x, s, t)
|
977 |
+
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
|
978 |
+
delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
|
979 |
+
norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
|
980 |
+
E = norm_fn((x_higher - x_lower) / delta).max()
|
981 |
+
if torch.all(E <= 1.):
|
982 |
+
x = x_higher
|
983 |
+
s = t
|
984 |
+
x_prev = x_lower
|
985 |
+
lambda_s = ns.marginal_lambda(s)
|
986 |
+
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
|
987 |
+
nfe += order
|
988 |
+
print('adaptive solver nfe', nfe)
|
989 |
+
return x
|
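Through the public API, the adaptive solver is reached via `sample(..., method='adaptive')`; it ignores `steps` and is driven by the tolerances, e.g. (sketch; `solver` and `x_T` are placeholders as in the usage example after `sample` below):

>>> x_0 = solver.sample(x_T, order=3, method='adaptive', atol=0.0078, rtol=0.05)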
990 |
+
|
991 |
+
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
992 |
+
method='singlestep', denoise=False, solver_type='dpm_solver', atol=0.0078,
|
993 |
+
rtol=0.05,
|
994 |
+
):
|
995 |
+
"""
|
996 |
+
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
|
997 |
+
|
998 |
+
=====================================================
|
999 |
+
|
1000 |
+
We support the following algorithms for both noise prediction model and data prediction model:
|
1001 |
+
- 'singlestep':
|
1002 |
+
Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
|
1003 |
+
We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
|
1004 |
+
The total number of function evaluations (NFE) == `steps`.
|
1005 |
+
Given a fixed NFE == `steps`, the sampling procedure is:
|
1006 |
+
- If `order` == 1:
|
1007 |
+
- Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
|
1008 |
+
- If `order` == 2:
|
1009 |
+
- Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
|
1010 |
+
- If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
|
1011 |
+
- If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
|
1012 |
+
- If `order` == 3:
|
1013 |
+
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
|
1014 |
+
- If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
|
1015 |
+
- If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
|
1016 |
+
- If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
|
1017 |
+
- 'multistep':
|
1018 |
+
Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
|
1019 |
+
We initialize the first `order` values by lower order multistep solvers.
|
1020 |
+
Given a fixed NFE == `steps`, the sampling procedure is:
|
1021 |
+
Denote K = steps.
|
1022 |
+
- If `order` == 1:
|
1023 |
+
- We use K steps of DPM-Solver-1 (i.e. DDIM).
|
1024 |
+
- If `order` == 2:
|
1025 |
+
- We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
|
1026 |
+
- If `order` == 3:
|
1027 |
+
- We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
|
1028 |
+
- 'singlestep_fixed':
|
1029 |
+
Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
|
1030 |
+
We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
|
1031 |
+
- 'adaptive':
|
1032 |
+
Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
|
1033 |
+
We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
|
1034 |
+
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
|
1035 |
+
(NFE) and the sample quality.
|
1036 |
+
- If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
|
1037 |
+
- If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
|
1038 |
+
|
1039 |
+
=====================================================
|
1040 |
+
|
1041 |
+
Some advice on choosing the algorithm:
|
1042 |
+
- For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
|
1043 |
+
Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
|
1044 |
+
e.g.
|
1045 |
+
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
|
1046 |
+
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
|
1047 |
+
skip_type='time_uniform', method='singlestep')
|
1048 |
+
- For **guided sampling with large guidance scale** by DPMs:
|
1049 |
+
Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
|
1050 |
+
e.g.
|
1051 |
+
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
|
1052 |
+
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
|
1053 |
+
skip_type='time_uniform', method='multistep')
|
1054 |
+
|
1055 |
+
We support three types of `skip_type`:
|
1056 |
+
- 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**
|
1057 |
+
- 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
|
1058 |
+
- 'time_quadratic': quadratic time for the time steps.
|
1059 |
+
|
1060 |
+
=====================================================
|
1061 |
+
Args:
|
1062 |
+
x: A pytorch tensor. The initial value at time `t_start`
|
1063 |
+
e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
|
1064 |
+
steps: A `int`. The total number of function evaluations (NFE).
|
1065 |
+
t_start: A `float`. The starting time of the sampling.
|
1066 |
+
If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
|
1067 |
+
t_end: A `float`. The ending time of the sampling.
|
1068 |
+
If `t_end` is None, we use 1. / self.noise_schedule.total_N.
|
1069 |
+
e.g. if total_N == 1000, we have `t_end` == 1e-3.
|
1070 |
+
For discrete-time DPMs:
|
1071 |
+
- We recommend `t_end` == 1. / self.noise_schedule.total_N.
|
1072 |
+
For continuous-time DPMs:
|
1073 |
+
- We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
|
1074 |
+
order: A `int`. The order of DPM-Solver.
|
1075 |
+
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
|
1076 |
+
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
|
1077 |
+
denoise: A `bool`. Whether to denoise at the final step. Default is False.
|
1078 |
+
If `denoise` is True, the total NFE is (`steps` + 1).
|
1079 |
+
solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
|
1080 |
+
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
1081 |
+
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
|
1082 |
+
Returns:
|
1083 |
+
x_end: A pytorch tensor. The approximated solution at time `t_end`.
|
1084 |
+
|
1085 |
+
"""
|
1086 |
+
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
1087 |
+
t_T = self.noise_schedule.T if t_start is None else t_start
|
1088 |
+
device = x.device
|
1089 |
+
if method == 'adaptive':
|
1090 |
+
with torch.no_grad():
|
1091 |
+
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
|
1092 |
+
solver_type=solver_type)
|
1093 |
+
elif method == 'multistep':
|
1094 |
+
assert steps >= order
|
1095 |
+
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
1096 |
+
assert timesteps.shape[0] - 1 == steps
|
1097 |
+
with torch.no_grad():
|
1098 |
+
vec_t = timesteps[0].expand((x.shape[0]))
|
1099 |
+
model_prev_list = [self.model_fn(x, vec_t)]
|
1100 |
+
t_prev_list = [vec_t]
|
1101 |
+
# Init the first `order` values by lower order multistep DPM-Solver.
|
1102 |
+
for init_order in range(1, order):
|
1103 |
+
vec_t = timesteps[init_order].expand(x.shape[0])
|
1104 |
+
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
|
1105 |
+
solver_type=solver_type)
|
1106 |
+
model_prev_list.append(self.model_fn(x, vec_t))
|
1107 |
+
t_prev_list.append(vec_t)
|
1108 |
+
# Compute the remaining values by `order`-th order multistep DPM-Solver.
|
1109 |
+
for step in range(order, steps + 1):
|
1110 |
+
vec_t = timesteps[step].expand(x.shape[0])
|
1111 |
+
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, order,
|
1112 |
+
solver_type=solver_type)
|
1113 |
+
for i in range(order - 1):
|
1114 |
+
t_prev_list[i] = t_prev_list[i + 1]
|
1115 |
+
model_prev_list[i] = model_prev_list[i + 1]
|
1116 |
+
t_prev_list[-1] = vec_t
|
1117 |
+
# We do not need to evaluate the final model value.
|
1118 |
+
if step < steps:
|
1119 |
+
model_prev_list[-1] = self.model_fn(x, vec_t)
|
1120 |
+
elif method in ['singlestep', 'singlestep_fixed']:
|
1121 |
+
if method == 'singlestep':
|
1122 |
+
timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
|
1123 |
+
skip_type=skip_type,
|
1124 |
+
t_T=t_T, t_0=t_0,
|
1125 |
+
device=device)
|
1126 |
+
elif method == 'singlestep_fixed':
|
1127 |
+
K = steps // order
|
1128 |
+
orders = [order, ] * K
|
1129 |
+
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
|
1130 |
+
for i, order in enumerate(orders):
|
1131 |
+
t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
|
1132 |
+
timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
|
1133 |
+
N=order, device=device)
|
1134 |
+
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
|
1135 |
+
vec_s, vec_t = t_T_inner.repeat(x.shape[0]), t_0_inner.repeat(x.shape[0])
|
1136 |
+
h = lambda_inner[-1] - lambda_inner[0]
|
1137 |
+
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
|
1138 |
+
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
|
1139 |
+
x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
|
1140 |
+
if denoise:
|
1141 |
+
x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
1142 |
+
return x
|
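Putting the pieces together, a typical fast-sampling call with this file might look roughly like the sketch below; the denoiser `net`, its training `betas`, the data shape and the 20-step budget are all illustrative assumptions rather than values fixed by this code:

>>> ns = NoiseScheduleVP('discrete', betas=betas)
>>> model_fn = model_wrapper(net, ns, model_type="noise", guidance_type="uncond")
>>> solver = DPM_Solver(model_fn, ns, predict_x0=True)                 # DPM-Solver++ variant
>>> x_T = torch.randn(batch_size, *data_shape, device=device)          # pure noise at t = T
>>> x_0 = solver.sample(x_T, steps=20, order=2, skip_type='time_uniform', method='multistep')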
1143 |
+
|
1144 |
+
|
1145 |
+
#############################################################
|
1146 |
+
# other utility functions
|
1147 |
+
#############################################################
|
1148 |
+
|
1149 |
+
def interpolate_fn(x, xp, yp):
|
1150 |
+
"""
|
1151 |
+
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
1152 |
+
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
1153 |
+
The function f(x) is well-defined on the whole x-axis. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
|
1154 |
+
|
1155 |
+
Args:
|
1156 |
+
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
1157 |
+
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
1158 |
+
yp: PyTorch tensor with shape [C, K].
|
1159 |
+
Returns:
|
1160 |
+
The function values f(x), with shape [N, C].
|
1161 |
+
"""
|
1162 |
+
N, K = x.shape[0], xp.shape[1]
|
1163 |
+
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
1164 |
+
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
1165 |
+
x_idx = torch.argmin(x_indices, dim=2)
|
1166 |
+
cand_start_idx = x_idx - 1
|
1167 |
+
start_idx = torch.where(
|
1168 |
+
torch.eq(x_idx, 0),
|
1169 |
+
torch.tensor(1, device=x.device),
|
1170 |
+
torch.where(
|
1171 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
1172 |
+
),
|
1173 |
+
)
|
1174 |
+
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
1175 |
+
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
1176 |
+
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
1177 |
+
start_idx2 = torch.where(
|
1178 |
+
torch.eq(x_idx, 0),
|
1179 |
+
torch.tensor(0, device=x.device),
|
1180 |
+
torch.where(
|
1181 |
+
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
1182 |
+
),
|
1183 |
+
)
|
1184 |
+
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
1185 |
+
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
1186 |
+
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
1187 |
+
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
1188 |
+
return cand
|
1189 |
+
|
1190 |
+
|
1191 |
+
def expand_dims(v, dims):
|
1192 |
+
"""
|
1193 |
+
Expand the tensor `v` to the dim `dims`.
|
1194 |
+
|
1195 |
+
Args:
|
1196 |
+
`v`: a PyTorch tensor with shape [N].
|
1197 |
+
`dims`: an `int`, the target number of dimensions.
|
1198 |
+
Returns:
|
1199 |
+
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
1200 |
+
"""
|
1201 |
+
return v[(...,) + (None,) * (dims - 1)]
|
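A minimal usage sketch for the two utilities above (not part of the commit), assuming the DDSP-SVC directory is the working directory so that `diffusion.dpm_solver_pytorch` is importable:

import torch
from diffusion.dpm_solver_pytorch import expand_dims, interpolate_fn

# expand_dims: [N] -> [N, 1, ..., 1] with `dims` total dimensions.
v = torch.randn(4)
print(expand_dims(v, 3).shape)           # torch.Size([4, 1, 1])

# interpolate_fn: differentiable piecewise-linear interpolation over keypoints.
xp = torch.tensor([[0.0, 1.0, 2.0]])     # [C=1, K=3] keypoint x values
yp = torch.tensor([[0.0, 10.0, 20.0]])   # [C=1, K=3] keypoint y values
x = torch.tensor([[0.5], [1.5], [3.0]])  # [N=3, C=1] query points
print(interpolate_fn(x, xp, yp))         # ~[[5.], [15.], [30.]]; x=3.0 is linearly extrapolated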
DDSP-SVC/diffusion/infer_gt_mel.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from diffusion.unit2mel import load_model_vocoder
|
5 |
+
|
6 |
+
|
7 |
+
class DiffGtMel:
|
8 |
+
def __init__(self, project_path=None, device=None):
|
9 |
+
self.project_path = project_path
|
10 |
+
if device is not None:
|
11 |
+
self.device = device
|
12 |
+
else:
|
13 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
14 |
+
self.model = None
|
15 |
+
self.vocoder = None
|
16 |
+
self.args = None
|
17 |
+
|
18 |
+
def flush_model(self, project_path, ddsp_config=None):
|
19 |
+
if (self.model is None) or (project_path != self.project_path):
|
20 |
+
model, vocoder, args = load_model_vocoder(project_path, device=self.device)
|
21 |
+
if self.check_args(ddsp_config, args):
|
22 |
+
self.model = model
|
23 |
+
self.vocoder = vocoder
|
24 |
+
self.args = args
|
25 |
+
|
26 |
+
def check_args(self, args1, args2):
|
27 |
+
if args1.data.block_size != args2.data.block_size:
|
28 |
+
raise ValueError("DDSP与DIFF模型的block_size不一致")
|
29 |
+
if args1.data.sampling_rate != args2.data.sampling_rate:
|
30 |
+
raise ValueError("DDSP与DIFF模型的sampling_rate不一致")
|
31 |
+
if args1.data.encoder != args2.data.encoder:
|
32 |
+
raise ValueError("DDSP与DIFF模型的encoder不一致")
|
33 |
+
return True
|
34 |
+
|
35 |
+
def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, use_dpm=True,
|
36 |
+
spk_mix_dict=None, start_frame=0):
|
37 |
+
input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate)
|
38 |
+
if use_dpm:
|
39 |
+
method = 'dpm-solver'
|
40 |
+
else:
|
41 |
+
method = 'pndm'
|
42 |
+
out_mel = self.model(
|
43 |
+
hubert,
|
44 |
+
f0,
|
45 |
+
volume,
|
46 |
+
spk_id=spk_id,
|
47 |
+
spk_mix_dict=spk_mix_dict,
|
48 |
+
gt_spec=input_mel,
|
49 |
+
infer=True,
|
50 |
+
infer_speedup=acc,
|
51 |
+
method=method,
|
52 |
+
k_step=k_step,
|
53 |
+
use_tqdm=False)
|
54 |
+
if start_frame > 0:
|
55 |
+
out_mel = out_mel[:, start_frame:, :]
|
56 |
+
f0 = f0[:, start_frame:, :]
|
57 |
+
output = self.vocoder.infer(out_mel, f0)
|
58 |
+
if start_frame > 0:
|
59 |
+
output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0))
|
60 |
+
return output
|
61 |
+
|
62 |
+
def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, use_dpm=True, silence_front=0,
|
63 |
+
use_silence=False, spk_mix_dict=None):
|
64 |
+
start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size)
|
65 |
+
if use_silence:
|
66 |
+
audio = audio[:, start_frame * self.vocoder.vocoder_hop_size:]
|
67 |
+
f0 = f0[:, start_frame:, :]
|
68 |
+
hubert = hubert[:, start_frame:, :]
|
69 |
+
volume = volume[:, start_frame:, :]
|
70 |
+
_start_frame = 0
|
71 |
+
else:
|
72 |
+
_start_frame = start_frame
|
73 |
+
audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step,
|
74 |
+
use_dpm=use_dpm, spk_mix_dict=spk_mix_dict, start_frame=_start_frame)
|
75 |
+
if use_silence:
|
76 |
+
if start_frame > 0:
|
77 |
+
audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0))
|
78 |
+
return audio
|
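A small numeric sketch (not part of the commit) of the silence_front bookkeeping used by DiffGtMel.infer above; the 44.1 kHz sample rate and 512-sample hop are assumptions about the NSF-HiFiGAN checkpoint:

vocoder_sample_rate = 44100   # assumed
vocoder_hop_size = 512        # assumed
silence_front = 1.0           # seconds of leading silence to skip

start_frame = int(silence_front * vocoder_sample_rate / vocoder_hop_size)
print(start_frame)                      # 86 mel frames are skipped during diffusion
print(start_frame * vocoder_hop_size)   # 44032 samples of silence are padded back onto the output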
DDSP-SVC/diffusion/solver.py
ADDED
@@ -0,0 +1,171 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import librosa
|
6 |
+
from logger.saver import Saver
|
7 |
+
from logger import utils
|
8 |
+
|
9 |
+
def test(args, model, vocoder, loader_test, saver):
|
10 |
+
print(' [*] testing...')
|
11 |
+
model.eval()
|
12 |
+
|
13 |
+
# losses
|
14 |
+
test_loss = 0.
|
15 |
+
|
16 |
+
# initialization
|
17 |
+
num_batches = len(loader_test)
|
18 |
+
rtf_all = []
|
19 |
+
|
20 |
+
# run
|
21 |
+
with torch.no_grad():
|
22 |
+
for bidx, data in enumerate(loader_test):
|
23 |
+
fn = data['name'][0]
|
24 |
+
print('--------')
|
25 |
+
print('{}/{} - {}'.format(bidx, num_batches, fn))
|
26 |
+
|
27 |
+
# unpack data
|
28 |
+
for k in data.keys():
|
29 |
+
if k != 'name':
|
30 |
+
data[k] = data[k].to(args.device)
|
31 |
+
print('>>', data['name'][0])
|
32 |
+
|
33 |
+
# forward
|
34 |
+
st_time = time.time()
|
35 |
+
mel = model(
|
36 |
+
data['units'],
|
37 |
+
data['f0'],
|
38 |
+
data['volume'],
|
39 |
+
data['spk_id'],
|
40 |
+
gt_spec=None,
|
41 |
+
infer=True,
|
42 |
+
infer_speedup=args.infer.speedup,
|
43 |
+
method=args.infer.method)
|
44 |
+
signal = vocoder.infer(mel, data['f0'])
|
45 |
+
ed_time = time.time()
|
46 |
+
|
47 |
+
# RTF
|
48 |
+
run_time = ed_time - st_time
|
49 |
+
song_time = signal.shape[-1] / args.data.sampling_rate
|
50 |
+
rtf = run_time / song_time
|
51 |
+
print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
|
52 |
+
rtf_all.append(rtf)
|
53 |
+
|
54 |
+
# loss
|
55 |
+
for i in range(args.train.batch_size):
|
56 |
+
loss = model(
|
57 |
+
data['units'],
|
58 |
+
data['f0'],
|
59 |
+
data['volume'],
|
60 |
+
data['spk_id'],
|
61 |
+
gt_spec=data['mel'],
|
62 |
+
infer=False)
|
63 |
+
test_loss += loss.item()
|
64 |
+
|
65 |
+
# log mel
|
66 |
+
saver.log_spec(data['name'][0], data['mel'], mel)
|
67 |
+
|
68 |
+
# log audio
|
69 |
+
path_audio = os.path.join(args.data.valid_path, 'audio', data['name'][0]) + '.wav'
|
70 |
+
audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate)
|
71 |
+
if len(audio.shape) > 1:
|
72 |
+
audio = librosa.to_mono(audio)
|
73 |
+
audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
|
74 |
+
saver.log_audio({fn+'/gt.wav': audio, fn+'/pred.wav': signal})
|
75 |
+
|
76 |
+
# report
|
77 |
+
test_loss /= args.train.batch_size
|
78 |
+
test_loss /= num_batches
|
79 |
+
|
80 |
+
# check
|
81 |
+
print(' [test_loss] test_loss:', test_loss)
|
82 |
+
print(' Real Time Factor', np.mean(rtf_all))
|
83 |
+
return test_loss
|
84 |
+
|
85 |
+
|
86 |
+
def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
|
87 |
+
# saver
|
88 |
+
saver = Saver(args, initial_global_step=initial_global_step)
|
89 |
+
|
90 |
+
# model size
|
91 |
+
params_count = utils.get_network_paras_amount({'model': model})
|
92 |
+
saver.log_info('--- model size ---')
|
93 |
+
saver.log_info(params_count)
|
94 |
+
|
95 |
+
# run
|
96 |
+
num_batches = len(loader_train)
|
97 |
+
model.train()
|
98 |
+
saver.log_info('======= start training =======')
|
99 |
+
for epoch in range(args.train.epochs):
|
100 |
+
for batch_idx, data in enumerate(loader_train):
|
101 |
+
saver.global_step_increment()
|
102 |
+
optimizer.zero_grad()
|
103 |
+
|
104 |
+
# unpack data
|
105 |
+
for k in data.keys():
|
106 |
+
if k != 'name':
|
107 |
+
data[k] = data[k].to(args.device)
|
108 |
+
|
109 |
+
# forward
|
110 |
+
loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
|
111 |
+
aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False)
|
112 |
+
|
113 |
+
# handle nan loss
|
114 |
+
if torch.isnan(loss):
|
115 |
+
raise ValueError(' [x] nan loss ')
|
116 |
+
else:
|
117 |
+
# backpropagate
|
118 |
+
loss.backward()
|
119 |
+
optimizer.step()
|
120 |
+
scheduler.step()
|
121 |
+
|
122 |
+
# log loss
|
123 |
+
if saver.global_step % args.train.interval_log == 0:
|
124 |
+
current_lr = optimizer.param_groups[0]['lr']
|
125 |
+
saver.log_info(
|
126 |
+
'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
|
127 |
+
epoch,
|
128 |
+
batch_idx,
|
129 |
+
num_batches,
|
130 |
+
args.env.expdir,
|
131 |
+
args.train.interval_log/saver.get_interval_time(),
|
132 |
+
current_lr,
|
133 |
+
loss.item(),
|
134 |
+
saver.get_total_time(),
|
135 |
+
saver.global_step
|
136 |
+
)
|
137 |
+
)
|
138 |
+
|
139 |
+
saver.log_value({
|
140 |
+
'train/loss': loss.item()
|
141 |
+
})
|
142 |
+
|
143 |
+
saver.log_value({
|
144 |
+
'train/lr': current_lr
|
145 |
+
})
|
146 |
+
|
147 |
+
# validation
|
148 |
+
if saver.global_step % args.train.interval_val == 0:
|
149 |
+
# save latest
|
150 |
+
saver.save_model(model, optimizer, postfix=f'{saver.global_step}')
|
151 |
+
last_val_step = saver.global_step - args.train.interval_val
|
152 |
+
if last_val_step % args.train.interval_force_save != 0:
|
153 |
+
saver.delete_model(postfix=f'{last_val_step}')
|
154 |
+
|
155 |
+
# run testing set
|
156 |
+
|
157 |
+
test_loss = test(args, model, vocoder, loader_test, saver)
|
158 |
+
|
159 |
+
saver.log_info(
|
160 |
+
' --- <validation> --- \nloss: {:.3f}. '.format(
|
161 |
+
test_loss,
|
162 |
+
)
|
163 |
+
)
|
164 |
+
|
165 |
+
saver.log_value({
|
166 |
+
'validation/loss': test_loss
|
167 |
+
})
|
168 |
+
|
169 |
+
model.train()
|
170 |
+
|
171 |
+
|
DDSP-SVC/diffusion/unit2mel.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
from .diffusion import GaussianDiffusion
|
7 |
+
from .wavenet import WaveNet
|
8 |
+
from .vocoder import Vocoder
|
9 |
+
|
10 |
+
class DotDict(dict):
|
11 |
+
def __getattr__(*args):
|
12 |
+
val = dict.get(*args)
|
13 |
+
return DotDict(val) if type(val) is dict else val
|
14 |
+
|
15 |
+
__setattr__ = dict.__setitem__
|
16 |
+
__delattr__ = dict.__delitem__
|
17 |
+
|
18 |
+
|
19 |
+
def load_model_vocoder(
|
20 |
+
model_path,
|
21 |
+
device='cpu'):
|
22 |
+
config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
|
23 |
+
with open(config_file, "r") as config:
|
24 |
+
args = yaml.safe_load(config)
|
25 |
+
args = DotDict(args)
|
26 |
+
|
27 |
+
# load vocoder
|
28 |
+
vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)
|
29 |
+
|
30 |
+
# load model
|
31 |
+
model = Unit2Mel(
|
32 |
+
args.data.encoder_out_channels,
|
33 |
+
args.model.n_spk,
|
34 |
+
args.model.use_pitch_aug,
|
35 |
+
vocoder.dimension,
|
36 |
+
args.model.n_layers,
|
37 |
+
args.model.n_chans,
|
38 |
+
args.model.n_hidden)
|
39 |
+
|
40 |
+
print(' [Loading] ' + model_path)
|
41 |
+
ckpt = torch.load(model_path, map_location=torch.device(device))
|
42 |
+
model.to(device)
|
43 |
+
model.load_state_dict(ckpt['model'])
|
44 |
+
model.eval()
|
45 |
+
return model, vocoder, args
|
46 |
+
|
47 |
+
|
48 |
+
class Unit2Mel(nn.Module):
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
input_channel,
|
52 |
+
n_spk,
|
53 |
+
use_pitch_aug=False,
|
54 |
+
out_dims=128,
|
55 |
+
n_layers=20,
|
56 |
+
n_chans=384,
|
57 |
+
n_hidden=256):
|
58 |
+
super().__init__()
|
59 |
+
self.unit_embed = nn.Linear(input_channel, n_hidden)
|
60 |
+
self.f0_embed = nn.Linear(1, n_hidden)
|
61 |
+
self.volume_embed = nn.Linear(1, n_hidden)
|
62 |
+
if use_pitch_aug:
|
63 |
+
self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
|
64 |
+
else:
|
65 |
+
self.aug_shift_embed = None
|
66 |
+
self.n_spk = n_spk
|
67 |
+
if n_spk is not None and n_spk > 1:
|
68 |
+
self.spk_embed = nn.Embedding(n_spk, n_hidden)
|
69 |
+
|
70 |
+
# diffusion
|
71 |
+
self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims)
|
72 |
+
|
73 |
+
def forward(self, units, f0, volume, spk_id = None, spk_mix_dict = None, aug_shift = None,
|
74 |
+
gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
|
75 |
+
|
76 |
+
'''
|
77 |
+
input:
|
78 |
+
B x n_frames x n_unit
|
79 |
+
return:
|
80 |
+
dict of B x n_frames x feat
|
81 |
+
'''
|
82 |
+
|
83 |
+
x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume)
|
84 |
+
if self.n_spk is not None and self.n_spk > 1:
|
85 |
+
if spk_mix_dict is not None:
|
86 |
+
for k, v in spk_mix_dict.items():
|
87 |
+
spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
|
88 |
+
x = x + v * self.spk_embed(spk_id_torch - 1)
|
89 |
+
else:
|
90 |
+
x = x + self.spk_embed(spk_id - 1)
|
91 |
+
if self.aug_shift_embed is not None and aug_shift is not None:
|
92 |
+
x = x + self.aug_shift_embed(aug_shift / 5)
|
93 |
+
x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm)
|
94 |
+
|
95 |
+
return x
|
96 |
+
|
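A short sketch of the DotDict wrapper above (not part of the commit); it exposes the parsed YAML config through attribute access, assuming the DDSP-SVC directory is the working directory:

from diffusion.unit2mel import DotDict

args = DotDict({'data': {'encoder': 'hubertsoft', 'block_size': 512}})
print(args.data.encoder)      # 'hubertsoft'
print(args.data.block_size)   # 512
print(args.data.missing_key)  # None (dict.get semantics, so no KeyError)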
DDSP-SVC/diffusion/vocoder.py
ADDED
@@ -0,0 +1,87 @@
1 |
+
import torch
|
2 |
+
from nsf_hifigan.nvSTFT import STFT
|
3 |
+
from nsf_hifigan.models import load_model
|
4 |
+
from torchaudio.transforms import Resample
|
5 |
+
|
6 |
+
|
7 |
+
class Vocoder:
|
8 |
+
def __init__(self, vocoder_type, vocoder_ckpt, device = None):
|
9 |
+
if device is None:
|
10 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
11 |
+
self.device = device
|
12 |
+
|
13 |
+
if vocoder_type == 'nsf-hifigan':
|
14 |
+
self.vocoder = NsfHifiGAN(vocoder_ckpt, device = device)
|
15 |
+
elif vocoder_type == 'nsf-hifigan-log10':
|
16 |
+
self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device = device)
|
17 |
+
else:
|
18 |
+
raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
|
19 |
+
|
20 |
+
self.resample_kernel = {}
|
21 |
+
self.vocoder_sample_rate = self.vocoder.sample_rate()
|
22 |
+
self.vocoder_hop_size = self.vocoder.hop_size()
|
23 |
+
self.dimension = self.vocoder.dimension()
|
24 |
+
|
25 |
+
def extract(self, audio, sample_rate, keyshift=0):
|
26 |
+
|
27 |
+
# resample
|
28 |
+
if sample_rate == self.vocoder_sample_rate:
|
29 |
+
audio_res = audio
|
30 |
+
else:
|
31 |
+
key_str = str(sample_rate)
|
32 |
+
if key_str not in self.resample_kernel:
|
33 |
+
self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device)
|
34 |
+
audio_res = self.resample_kernel[key_str](audio)
|
35 |
+
|
36 |
+
# extract
|
37 |
+
mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins
|
38 |
+
return mel
|
39 |
+
|
40 |
+
def infer(self, mel, f0):
|
41 |
+
f0 = f0[:,:mel.size(1),0] # B, n_frames
|
42 |
+
audio = self.vocoder(mel, f0)
|
43 |
+
return audio
|
44 |
+
|
45 |
+
|
46 |
+
class NsfHifiGAN(torch.nn.Module):
|
47 |
+
def __init__(self, model_path, device=None):
|
48 |
+
super().__init__()
|
49 |
+
if device is None:
|
50 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
51 |
+
self.device = device
|
52 |
+
print('| Load HifiGAN: ', model_path)
|
53 |
+
self.model, self.h = load_model(model_path, device=self.device)
|
54 |
+
self.stft = STFT(
|
55 |
+
self.h.sampling_rate,
|
56 |
+
self.h.num_mels,
|
57 |
+
self.h.n_fft,
|
58 |
+
self.h.win_size,
|
59 |
+
self.h.hop_size,
|
60 |
+
self.h.fmin,
|
61 |
+
self.h.fmax)
|
62 |
+
|
63 |
+
def sample_rate(self):
|
64 |
+
return self.h.sampling_rate
|
65 |
+
|
66 |
+
def hop_size(self):
|
67 |
+
return self.h.hop_size
|
68 |
+
|
69 |
+
def dimension(self):
|
70 |
+
return self.h.num_mels
|
71 |
+
|
72 |
+
def extract(self, audio, keyshift=0):
|
73 |
+
mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins
|
74 |
+
return mel
|
75 |
+
|
76 |
+
def forward(self, mel, f0):
|
77 |
+
with torch.no_grad():
|
78 |
+
c = mel.transpose(1, 2)
|
79 |
+
audio = self.model(c, f0)
|
80 |
+
return audio
|
81 |
+
|
82 |
+
class NsfHifiGANLog10(NsfHifiGAN):
|
83 |
+
def forward(self, mel, f0):
|
84 |
+
with torch.no_grad():
|
85 |
+
c = 0.434294 * mel.transpose(1, 2)
|
86 |
+
audio = self.model(c, f0)
|
87 |
+
return audio
|
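One detail worth noting in NsfHifiGANLog10 above: the 0.434294 factor is simply the change-of-base constant 1 / ln(10), converting a natural-log mel spectrogram into the log10 mel that this vocoder variant expects.

import math
print(1.0 / math.log(10.0))   # 0.4342944819032518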
DDSP-SVC/diffusion/wavenet.py
ADDED
@@ -0,0 +1,108 @@
1 |
+
import math
|
2 |
+
from math import sqrt
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import Mish
|
8 |
+
|
9 |
+
|
10 |
+
class Conv1d(torch.nn.Conv1d):
|
11 |
+
def __init__(self, *args, **kwargs):
|
12 |
+
super().__init__(*args, **kwargs)
|
13 |
+
nn.init.kaiming_normal_(self.weight)
|
14 |
+
|
15 |
+
|
16 |
+
class SinusoidalPosEmb(nn.Module):
|
17 |
+
def __init__(self, dim):
|
18 |
+
super().__init__()
|
19 |
+
self.dim = dim
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
device = x.device
|
23 |
+
half_dim = self.dim // 2
|
24 |
+
emb = math.log(10000) / (half_dim - 1)
|
25 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
26 |
+
emb = x[:, None] * emb[None, :]
|
27 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
28 |
+
return emb
|
29 |
+
|
30 |
+
|
31 |
+
class ResidualBlock(nn.Module):
|
32 |
+
def __init__(self, encoder_hidden, residual_channels, dilation):
|
33 |
+
super().__init__()
|
34 |
+
self.residual_channels = residual_channels
|
35 |
+
self.dilated_conv = nn.Conv1d(
|
36 |
+
residual_channels,
|
37 |
+
2 * residual_channels,
|
38 |
+
kernel_size=3,
|
39 |
+
padding=dilation,
|
40 |
+
dilation=dilation
|
41 |
+
)
|
42 |
+
self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
|
43 |
+
self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1)
|
44 |
+
self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)
|
45 |
+
|
46 |
+
def forward(self, x, conditioner, diffusion_step):
|
47 |
+
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
|
48 |
+
conditioner = self.conditioner_projection(conditioner)
|
49 |
+
y = x + diffusion_step
|
50 |
+
|
51 |
+
y = self.dilated_conv(y) + conditioner
|
52 |
+
|
53 |
+
# Using torch.split instead of torch.chunk to avoid using onnx::Slice
|
54 |
+
gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
|
55 |
+
y = torch.sigmoid(gate) * torch.tanh(filter)
|
56 |
+
|
57 |
+
y = self.output_projection(y)
|
58 |
+
|
59 |
+
# Using torch.split instead of torch.chunk to avoid using onnx::Slice
|
60 |
+
residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
|
61 |
+
return (x + residual) / math.sqrt(2.0), skip
|
62 |
+
|
63 |
+
|
64 |
+
class WaveNet(nn.Module):
|
65 |
+
def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256):
|
66 |
+
super().__init__()
|
67 |
+
self.input_projection = Conv1d(in_dims, n_chans, 1)
|
68 |
+
self.diffusion_embedding = SinusoidalPosEmb(n_chans)
|
69 |
+
self.mlp = nn.Sequential(
|
70 |
+
nn.Linear(n_chans, n_chans * 4),
|
71 |
+
Mish(),
|
72 |
+
nn.Linear(n_chans * 4, n_chans)
|
73 |
+
)
|
74 |
+
self.residual_layers = nn.ModuleList([
|
75 |
+
ResidualBlock(
|
76 |
+
encoder_hidden=n_hidden,
|
77 |
+
residual_channels=n_chans,
|
78 |
+
dilation=1
|
79 |
+
)
|
80 |
+
for i in range(n_layers)
|
81 |
+
])
|
82 |
+
self.skip_projection = Conv1d(n_chans, n_chans, 1)
|
83 |
+
self.output_projection = Conv1d(n_chans, in_dims, 1)
|
84 |
+
nn.init.zeros_(self.output_projection.weight)
|
85 |
+
|
86 |
+
def forward(self, spec, diffusion_step, cond):
|
87 |
+
"""
|
88 |
+
:param spec: [B, 1, M, T]
|
89 |
+
:param diffusion_step: [B, 1]
|
90 |
+
:param cond: [B, M, T]
|
91 |
+
:return:
|
92 |
+
"""
|
93 |
+
x = spec.squeeze(1)
|
94 |
+
x = self.input_projection(x) # [B, residual_channel, T]
|
95 |
+
|
96 |
+
x = F.relu(x)
|
97 |
+
diffusion_step = self.diffusion_embedding(diffusion_step)
|
98 |
+
diffusion_step = self.mlp(diffusion_step)
|
99 |
+
skip = []
|
100 |
+
for layer in self.residual_layers:
|
101 |
+
x, skip_connection = layer(x, cond, diffusion_step)
|
102 |
+
skip.append(skip_connection)
|
103 |
+
|
104 |
+
x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
|
105 |
+
x = self.skip_projection(x)
|
106 |
+
x = F.relu(x)
|
107 |
+
x = self.output_projection(x) # [B, mel_bins, T]
|
108 |
+
return x[:, None, :, :]
|
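A quick shape sanity-check for the WaveNet denoiser above (not part of the commit); the small layer and channel counts are chosen only to keep the example fast, and the repo root is assumed to be on PYTHONPATH:

import torch
from diffusion.wavenet import WaveNet

net = WaveNet(in_dims=128, n_layers=2, n_chans=64, n_hidden=256)
spec = torch.randn(1, 1, 128, 100)            # [B, 1, mel_bins, T] noisy spectrogram
step = torch.randint(0, 1000, (1,)).float()   # diffusion timestep, one per batch item
cond = torch.randn(1, 256, 100)               # [B, n_hidden, T] conditioning from Unit2Mel
print(net(spec, step, cond).shape)            # torch.Size([1, 1, 128, 100])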
DDSP-SVC/draw.py
ADDED
@@ -0,0 +1,102 @@
1 |
+
import numpy as np
|
2 |
+
import tqdm
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
import wave
|
7 |
+
|
8 |
+
WAV_MIN_LENGTH = 2 # wav文件的最短时长 / The minimum duration of wav files
|
9 |
+
SAMPLE_RATE = 1 # 抽取文件数量的百分比 / The percentage of files to be extracted
|
10 |
+
SAMPLE_MIN = 2 # 抽取的文件数量下限 / The lower limit of the number of files to be extracted
|
11 |
+
SAMPLE_MAX = 10 # 抽取的文件数量上限 / The upper limit of the number of files to be extracted
|
12 |
+
|
13 |
+
|
14 |
+
# Define a function that checks whether a wav file is longer than the minimum duration
|
15 |
+
def check_duration(wav_file):
|
16 |
+
# Open the wav file
|
17 |
+
f = wave.open(wav_file, "rb")
|
18 |
+
# Get the frame count and frame rate
|
19 |
+
frames = f.getnframes()
|
20 |
+
rate = f.getframerate()
|
21 |
+
# Compute the duration in seconds
|
22 |
+
duration = frames / float(rate)
|
23 |
+
# Close the file
|
24 |
+
f.close()
|
25 |
+
# Return whether the duration exceeds the minimum length
|
26 |
+
return duration > WAV_MIN_LENGTH
|
27 |
+
|
28 |
+
# Define a function that randomly extracts a proportion of wav files from a source directory and moves them to a target directory, preserving the directory structure
|
29 |
+
def split_data(src_dir, dst_dir, ratio):
|
30 |
+
# Create the target directory if it does not exist
|
31 |
+
if not os.path.exists(dst_dir):
|
32 |
+
os.makedirs(dst_dir)
|
33 |
+
|
34 |
+
# Collect all subdirectories and wav file names under the source directory
|
35 |
+
subdirs, files, subfiles = [], [], []
|
36 |
+
for item in os.listdir(src_dir):
|
37 |
+
item_path = os.path.join(src_dir, item)
|
38 |
+
if os.path.isdir(item_path):
|
39 |
+
subdirs.append(item)
|
40 |
+
for subitem in os.listdir(item_path):
|
41 |
+
subitem_path = os.path.join(item_path, subitem)
|
42 |
+
if os.path.isfile(subitem_path) and subitem.endswith(".wav"):
|
43 |
+
subfiles.append(subitem)
|
44 |
+
elif os.path.isfile(item_path) and item.endswith(".wav"):
|
45 |
+
files.append(item)
|
46 |
+
|
47 |
+
# If there are no wav files at all under the source directory, report an error and return
|
48 |
+
if len(files) == 0:
|
49 |
+
if len(subfiles) == 0:
|
50 |
+
print(f"Error: No wav files found in {src_dir}")
|
51 |
+
return
|
52 |
+
|
53 |
+
# Compute how many wav files need to be extracted
|
54 |
+
num_files = int(len(files) * ratio)
|
55 |
+
num_files = max(SAMPLE_MIN, min(SAMPLE_MAX, num_files))
|
56 |
+
|
57 |
+
# Shuffle the file name list and take the first num_files entries as the selection
|
58 |
+
np.random.shuffle(files)
|
59 |
+
selected_files = files[:num_files]
|
60 |
+
|
61 |
+
# Create a progress bar to show the script's progress
|
62 |
+
pbar = tqdm.tqdm(total=num_files)
|
63 |
+
|
64 |
+
# Go through each selected file name and check that it is longer than 2 seconds
|
65 |
+
for file in selected_files:
|
66 |
+
src_file = os.path.join(src_dir, file)
|
67 |
+
# If the source file is not longer than 2 seconds, print its name and skip it
|
68 |
+
if not check_duration(src_file):
|
69 |
+
print(f"Skipped {src_file} because its duration is less than 2 seconds.")
|
70 |
+
continue
|
71 |
+
# Build the full source and destination paths, move the file, and update the progress bar
|
72 |
+
dst_file = os.path.join(dst_dir, file)
|
73 |
+
shutil.move(src_file, dst_file)
|
74 |
+
pbar.update(1)
|
75 |
+
|
76 |
+
pbar.close()
|
77 |
+
|
78 |
+
# Recurse into every subdirectory of the source directory (if any)
|
79 |
+
for subdir in subdirs:
|
80 |
+
# Build the subdirectory's full paths under the source and target directories
|
81 |
+
src_subdir = os.path.join(src_dir, subdir)
|
82 |
+
dst_subdir = os.path.join(dst_dir, subdir)
|
83 |
+
# Recursively call this function so the wav files in the subdirectory get the same treatment, preserving the structure
|
84 |
+
split_data(src_subdir, dst_subdir, ratio)
|
85 |
+
|
86 |
+
# Define the main function that sets up the paths and calls the function above
|
87 |
+
|
88 |
+
def main():
|
89 |
+
root_dir = os.path.abspath('.')
|
90 |
+
dst_dir = root_dir + "/data/val/audio"
|
91 |
+
# Extraction ratio; defaults to 1 (percent)
|
92 |
+
ratio = float(SAMPLE_RATE) / 100
|
93 |
+
|
94 |
+
# The source directory is fixed to /data/train/audio under the root directory
|
95 |
+
src_dir = root_dir + "/data/train/audio"
|
96 |
+
|
97 |
+
# Call split_data to extract wav files from the source directory and move them to the target directory, preserving the structure
|
98 |
+
split_data(src_dir, dst_dir, ratio)
|
99 |
+
|
100 |
+
# Run the main function when this module is executed directly
|
101 |
+
if __name__ == "__main__":
|
102 |
+
main()
|
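Note: as written, draw.py is a small validation-split helper. It randomly moves SAMPLE_RATE percent of the wav files (clamped to between SAMPLE_MIN and SAMPLE_MAX files, skipping clips shorter than WAV_MIN_LENGTH seconds) from data/train/audio into data/val/audio, recursing into per-speaker subfolders and preserving the directory layout; matplotlib is imported but never used here.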
DDSP-SVC/encoder/hubert/model.py
ADDED
@@ -0,0 +1,293 @@
1 |
+
import copy
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
import random
|
4 |
+
|
5 |
+
from sklearn.cluster import KMeans
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
11 |
+
|
12 |
+
URLS = {
|
13 |
+
"hubert-discrete": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-discrete-e9416457.pt",
|
14 |
+
"hubert-soft": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt",
|
15 |
+
"kmeans100": "https://github.com/bshall/hubert/releases/download/v0.1/kmeans100-50f36a95.pt",
|
16 |
+
}
|
17 |
+
|
18 |
+
|
19 |
+
class Hubert(nn.Module):
|
20 |
+
def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
|
21 |
+
super().__init__()
|
22 |
+
self._mask = mask
|
23 |
+
self.feature_extractor = FeatureExtractor()
|
24 |
+
self.feature_projection = FeatureProjection()
|
25 |
+
self.positional_embedding = PositionalConvEmbedding()
|
26 |
+
self.norm = nn.LayerNorm(768)
|
27 |
+
self.dropout = nn.Dropout(0.1)
|
28 |
+
self.encoder = TransformerEncoder(
|
29 |
+
nn.TransformerEncoderLayer(
|
30 |
+
768, 12, 3072, activation="gelu", batch_first=True
|
31 |
+
),
|
32 |
+
12,
|
33 |
+
)
|
34 |
+
self.proj = nn.Linear(768, 256)
|
35 |
+
|
36 |
+
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
|
37 |
+
self.label_embedding = nn.Embedding(num_label_embeddings, 256)
|
38 |
+
|
39 |
+
def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
40 |
+
mask = None
|
41 |
+
if self.training and self._mask:
|
42 |
+
mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
|
43 |
+
x[mask] = self.masked_spec_embed.to(x.dtype)
|
44 |
+
return x, mask
|
45 |
+
|
46 |
+
def encode(
|
47 |
+
self, x: torch.Tensor, layer: Optional[int] = None
|
48 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
49 |
+
x = self.feature_extractor(x)
|
50 |
+
x = self.feature_projection(x.transpose(1, 2))
|
51 |
+
x, mask = self.mask(x)
|
52 |
+
x = x + self.positional_embedding(x)
|
53 |
+
x = self.dropout(self.norm(x))
|
54 |
+
x = self.encoder(x, output_layer=layer)
|
55 |
+
return x, mask
|
56 |
+
|
57 |
+
def logits(self, x: torch.Tensor) -> torch.Tensor:
|
58 |
+
logits = torch.cosine_similarity(
|
59 |
+
x.unsqueeze(2),
|
60 |
+
self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
|
61 |
+
dim=-1,
|
62 |
+
)
|
63 |
+
return logits / 0.1
|
64 |
+
|
65 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
66 |
+
x, mask = self.encode(x)
|
67 |
+
x = self.proj(x)
|
68 |
+
logits = self.logits(x)
|
69 |
+
return logits, mask
|
70 |
+
|
71 |
+
|
72 |
+
class HubertSoft(Hubert):
|
73 |
+
def __init__(self):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
@torch.inference_mode()
|
77 |
+
def units(self, wav: torch.Tensor) -> torch.Tensor:
|
78 |
+
wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
|
79 |
+
x, _ = self.encode(wav)
|
80 |
+
return self.proj(x)
|
81 |
+
|
82 |
+
|
83 |
+
class HubertDiscrete(Hubert):
|
84 |
+
def __init__(self, kmeans):
|
85 |
+
super().__init__(504)
|
86 |
+
self.kmeans = kmeans
|
87 |
+
|
88 |
+
@torch.inference_mode()
|
89 |
+
def units(self, wav: torch.Tensor) -> torch.LongTensor:
|
90 |
+
wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
|
91 |
+
x, _ = self.encode(wav, layer=7)
|
92 |
+
x = self.kmeans.predict(x.squeeze().cpu().numpy())
|
93 |
+
return torch.tensor(x, dtype=torch.long, device=wav.device)
|
94 |
+
|
95 |
+
|
96 |
+
class FeatureExtractor(nn.Module):
|
97 |
+
def __init__(self):
|
98 |
+
super().__init__()
|
99 |
+
self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
|
100 |
+
self.norm0 = nn.GroupNorm(512, 512)
|
101 |
+
self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
102 |
+
self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
103 |
+
self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
104 |
+
self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
|
105 |
+
self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
|
106 |
+
self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)
|
107 |
+
|
108 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
109 |
+
x = F.gelu(self.norm0(self.conv0(x)))
|
110 |
+
x = F.gelu(self.conv1(x))
|
111 |
+
x = F.gelu(self.conv2(x))
|
112 |
+
x = F.gelu(self.conv3(x))
|
113 |
+
x = F.gelu(self.conv4(x))
|
114 |
+
x = F.gelu(self.conv5(x))
|
115 |
+
x = F.gelu(self.conv6(x))
|
116 |
+
return x
|
117 |
+
|
118 |
+
|
119 |
+
class FeatureProjection(nn.Module):
|
120 |
+
def __init__(self):
|
121 |
+
super().__init__()
|
122 |
+
self.norm = nn.LayerNorm(512)
|
123 |
+
self.projection = nn.Linear(512, 768)
|
124 |
+
self.dropout = nn.Dropout(0.1)
|
125 |
+
|
126 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
127 |
+
x = self.norm(x)
|
128 |
+
x = self.projection(x)
|
129 |
+
x = self.dropout(x)
|
130 |
+
return x
|
131 |
+
|
132 |
+
|
133 |
+
class PositionalConvEmbedding(nn.Module):
|
134 |
+
def __init__(self):
|
135 |
+
super().__init__()
|
136 |
+
self.conv = nn.Conv1d(
|
137 |
+
768,
|
138 |
+
768,
|
139 |
+
kernel_size=128,
|
140 |
+
padding=128 // 2,
|
141 |
+
groups=16,
|
142 |
+
)
|
143 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
144 |
+
|
145 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
146 |
+
x = self.conv(x.transpose(1, 2))
|
147 |
+
x = F.gelu(x[:, :, :-1])
|
148 |
+
return x.transpose(1, 2)
|
149 |
+
|
150 |
+
|
151 |
+
class TransformerEncoder(nn.Module):
|
152 |
+
def __init__(
|
153 |
+
self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
|
154 |
+
) -> None:
|
155 |
+
super(TransformerEncoder, self).__init__()
|
156 |
+
self.layers = nn.ModuleList(
|
157 |
+
[copy.deepcopy(encoder_layer) for _ in range(num_layers)]
|
158 |
+
)
|
159 |
+
self.num_layers = num_layers
|
160 |
+
|
161 |
+
def forward(
|
162 |
+
self,
|
163 |
+
src: torch.Tensor,
|
164 |
+
mask: torch.Tensor = None,
|
165 |
+
src_key_padding_mask: torch.Tensor = None,
|
166 |
+
output_layer: Optional[int] = None,
|
167 |
+
) -> torch.Tensor:
|
168 |
+
output = src
|
169 |
+
for layer in self.layers[:output_layer]:
|
170 |
+
output = layer(
|
171 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
|
172 |
+
)
|
173 |
+
return output
|
174 |
+
|
175 |
+
|
176 |
+
def _compute_mask(
|
177 |
+
shape: Tuple[int, int],
|
178 |
+
mask_prob: float,
|
179 |
+
mask_length: int,
|
180 |
+
device: torch.device,
|
181 |
+
min_masks: int = 0,
|
182 |
+
) -> torch.Tensor:
|
183 |
+
batch_size, sequence_length = shape
|
184 |
+
|
185 |
+
if mask_length < 1:
|
186 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
187 |
+
|
188 |
+
if mask_length > sequence_length:
|
189 |
+
raise ValueError(
|
190 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
|
191 |
+
)
|
192 |
+
|
193 |
+
# compute number of masked spans in batch
|
194 |
+
num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
|
195 |
+
num_masked_spans = max(num_masked_spans, min_masks)
|
196 |
+
|
197 |
+
# make sure num masked indices <= sequence_length
|
198 |
+
if num_masked_spans * mask_length > sequence_length:
|
199 |
+
num_masked_spans = sequence_length // mask_length
|
200 |
+
|
201 |
+
# SpecAugment mask to fill
|
202 |
+
mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
|
203 |
+
|
204 |
+
# uniform distribution to sample from, make sure that offset samples are < sequence_length
|
205 |
+
uniform_dist = torch.ones(
|
206 |
+
(batch_size, sequence_length - (mask_length - 1)), device=device
|
207 |
+
)
|
208 |
+
|
209 |
+
# get random indices to mask
|
210 |
+
mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
|
211 |
+
|
212 |
+
# expand masked indices to masked spans
|
213 |
+
mask_indices = (
|
214 |
+
mask_indices.unsqueeze(dim=-1)
|
215 |
+
.expand((batch_size, num_masked_spans, mask_length))
|
216 |
+
.reshape(batch_size, num_masked_spans * mask_length)
|
217 |
+
)
|
218 |
+
offsets = (
|
219 |
+
torch.arange(mask_length, device=device)[None, None, :]
|
220 |
+
.expand((batch_size, num_masked_spans, mask_length))
|
221 |
+
.reshape(batch_size, num_masked_spans * mask_length)
|
222 |
+
)
|
223 |
+
mask_idxs = mask_indices + offsets
|
224 |
+
|
225 |
+
# scatter indices to mask
|
226 |
+
mask = mask.scatter(1, mask_idxs, True)
|
227 |
+
|
228 |
+
return mask
|
229 |
+
|
230 |
+
|
231 |
+
def hubert_discrete(
|
232 |
+
pretrained: bool = True,
|
233 |
+
progress: bool = True,
|
234 |
+
) -> HubertDiscrete:
|
235 |
+
r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
|
236 |
+
Args:
|
237 |
+
pretrained (bool): load pretrained weights into the model
|
238 |
+
progress (bool): show progress bar when downloading model
|
239 |
+
"""
|
240 |
+
kmeans = kmeans100(pretrained=pretrained, progress=progress)
|
241 |
+
hubert = HubertDiscrete(kmeans)
|
242 |
+
if pretrained:
|
243 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
244 |
+
URLS["hubert-discrete"], progress=progress
|
245 |
+
)
|
246 |
+
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
|
247 |
+
hubert.load_state_dict(checkpoint)
|
248 |
+
hubert.eval()
|
249 |
+
return hubert
|
250 |
+
|
251 |
+
|
252 |
+
def hubert_soft(
|
253 |
+
pretrained: bool = True,
|
254 |
+
progress: bool = True,
|
255 |
+
) -> HubertSoft:
|
256 |
+
r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
|
257 |
+
Args:
|
258 |
+
pretrained (bool): load pretrained weights into the model
|
259 |
+
progress (bool): show progress bar when downloading model
|
260 |
+
"""
|
261 |
+
hubert = HubertSoft()
|
262 |
+
if pretrained:
|
263 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
264 |
+
URLS["hubert-soft"], progress=progress
|
265 |
+
)
|
266 |
+
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
|
267 |
+
hubert.load_state_dict(checkpoint)
|
268 |
+
hubert.eval()
|
269 |
+
return hubert
|
270 |
+
|
271 |
+
|
272 |
+
def _kmeans(
|
273 |
+
num_clusters: int, pretrained: bool = True, progress: bool = True
|
274 |
+
) -> KMeans:
|
275 |
+
kmeans = KMeans(num_clusters)
|
276 |
+
if pretrained:
|
277 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
278 |
+
URLS[f"kmeans{num_clusters}"], progress=progress
|
279 |
+
)
|
280 |
+
kmeans.__dict__["n_features_in_"] = checkpoint["n_features_in_"]
|
281 |
+
kmeans.__dict__["_n_threads"] = checkpoint["_n_threads"]
|
282 |
+
kmeans.__dict__["cluster_centers_"] = checkpoint["cluster_centers_"].numpy()
|
283 |
+
return kmeans
|
284 |
+
|
285 |
+
|
286 |
+
def kmeans100(pretrained: bool = True, progress: bool = True) -> KMeans:
|
287 |
+
r"""
|
288 |
+
k-means checkpoint for HuBERT-Discrete with 100 clusters.
|
289 |
+
Args:
|
290 |
+
pretrained (bool): load pretrained weights into the model
|
291 |
+
progress (bool): show progress bar when downloading model
|
292 |
+
"""
|
293 |
+
return _kmeans(100, pretrained, progress)
|
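A quick shape check for HubertSoft.units (not part of the commit). pretrained=False keeps the weights random so nothing is downloaded; the point is only to illustrate the expected tensor layout, assuming the repo root is on PYTHONPATH:

import torch
from encoder.hubert.model import hubert_soft

hubert = hubert_soft(pretrained=False).eval()
wav = torch.randn(1, 1, 16000)   # [B, 1, samples], 1 second at 16 kHz
units = hubert.units(wav)        # one 256-dim unit vector per 320-sample hop
print(units.shape)               # torch.Size([1, 50, 256])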
DDSP-SVC/enhancer.py
ADDED
@@ -0,0 +1,115 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from nsf_hifigan.nvSTFT import STFT
|
5 |
+
from nsf_hifigan.models import load_model
|
6 |
+
from torchaudio.transforms import Resample
|
7 |
+
|
8 |
+
class Enhancer:
|
9 |
+
def __init__(self, enhancer_type, enhancer_ckpt, device=None):
|
10 |
+
if device is None:
|
11 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
12 |
+
self.device = device
|
13 |
+
|
14 |
+
if enhancer_type == 'nsf-hifigan':
|
15 |
+
self.enhancer = NsfHifiGAN(enhancer_ckpt, device=self.device)
|
16 |
+
else:
|
17 |
+
raise ValueError(f" [x] Unknown enhancer: {enhancer_type}")
|
18 |
+
|
19 |
+
self.resample_kernel = {}
|
20 |
+
self.enhancer_sample_rate = self.enhancer.sample_rate()
|
21 |
+
self.enhancer_hop_size = self.enhancer.hop_size()
|
22 |
+
|
23 |
+
def enhance(self,
|
24 |
+
audio, # 1, T
|
25 |
+
sample_rate,
|
26 |
+
f0, # 1, n_frames, 1
|
27 |
+
hop_size,
|
28 |
+
adaptive_key = 0,
|
29 |
+
silence_front = 0
|
30 |
+
):
|
31 |
+
# enhancer start time
|
32 |
+
start_frame = int(silence_front * sample_rate / hop_size)
|
33 |
+
real_silence_front = start_frame * hop_size / sample_rate
|
34 |
+
audio = audio[:, int(np.round(real_silence_front * sample_rate)) : ]
|
35 |
+
f0 = f0[: , start_frame :, :]
|
36 |
+
|
37 |
+
# adaptive parameters
|
38 |
+
if adaptive_key == 'auto':
|
39 |
+
adaptive_key = 12 * np.log2(float(torch.max(f0) / 760))
|
40 |
+
adaptive_key = max(0, np.ceil(adaptive_key))
|
41 |
+
print('auto_adaptive_key: ' + str(int(adaptive_key)))
|
42 |
+
else:
|
43 |
+
adaptive_key = float(adaptive_key)
|
44 |
+
|
45 |
+
adaptive_factor = 2 ** ( -adaptive_key / 12)
|
46 |
+
adaptive_sample_rate = 100 * int(np.round(self.enhancer_sample_rate / adaptive_factor / 100))
|
47 |
+
real_factor = self.enhancer_sample_rate / adaptive_sample_rate
|
48 |
+
|
49 |
+
# resample the ddsp output
|
50 |
+
if sample_rate == adaptive_sample_rate:
|
51 |
+
audio_res = audio
|
52 |
+
else:
|
53 |
+
key_str = str(sample_rate) + str(adaptive_sample_rate)
|
54 |
+
if key_str not in self.resample_kernel:
|
55 |
+
self.resample_kernel[key_str] = Resample(sample_rate, adaptive_sample_rate, lowpass_filter_width = 128).to(self.device)
|
56 |
+
audio_res = self.resample_kernel[key_str](audio)
|
57 |
+
|
58 |
+
n_frames = int(audio_res.size(-1) // self.enhancer_hop_size + 1)
|
59 |
+
|
60 |
+
# resample f0
|
61 |
+
if hop_size == self.enhancer_hop_size and sample_rate == self.enhancer_sample_rate and sample_rate == adaptive_sample_rate:
|
62 |
+
f0_res = f0.squeeze(-1) # 1, n_frames
|
63 |
+
else:
|
64 |
+
f0_np = f0.squeeze(0).squeeze(-1).cpu().numpy()
|
65 |
+
f0_np *= real_factor
|
66 |
+
time_org = (hop_size / sample_rate) * np.arange(len(f0_np)) / real_factor
|
67 |
+
time_frame = (self.enhancer_hop_size / self.enhancer_sample_rate) * np.arange(n_frames)
|
68 |
+
f0_res = np.interp(time_frame, time_org, f0_np, left=f0_np[0], right=f0_np[-1])
|
69 |
+
f0_res = torch.from_numpy(f0_res).unsqueeze(0).float().to(self.device) # 1, n_frames
|
70 |
+
|
71 |
+
# enhance
|
72 |
+
enhanced_audio, enhancer_sample_rate = self.enhancer(audio_res, f0_res)
|
73 |
+
|
74 |
+
# resample the enhanced output
|
75 |
+
if adaptive_sample_rate != enhancer_sample_rate:
|
76 |
+
key_str = str(adaptive_sample_rate) + str(enhancer_sample_rate)
|
77 |
+
if key_str not in self.resample_kernel:
|
78 |
+
self.resample_kernel[key_str] = Resample(adaptive_sample_rate, enhancer_sample_rate, lowpass_filter_width = 128).to(self.device)
|
79 |
+
enhanced_audio = self.resample_kernel[key_str](enhanced_audio)
|
80 |
+
|
81 |
+
# pad the silence frames
|
82 |
+
if start_frame > 0:
|
83 |
+
enhanced_audio = F.pad(enhanced_audio, (int(np.round(enhancer_sample_rate * real_silence_front)), 0))
|
84 |
+
|
85 |
+
return enhanced_audio, enhancer_sample_rate
|
86 |
+
|
87 |
+
|
88 |
+
class NsfHifiGAN(torch.nn.Module):
|
89 |
+
def __init__(self, model_path, device=None):
|
90 |
+
super().__init__()
|
91 |
+
if device is None:
|
92 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
93 |
+
self.device = device
|
94 |
+
print('| Load HifiGAN: ', model_path)
|
95 |
+
self.model, self.h = load_model(model_path, device=self.device)
|
96 |
+
self.stft = STFT(
|
97 |
+
self.h.sampling_rate,
|
98 |
+
self.h.num_mels,
|
99 |
+
self.h.n_fft,
|
100 |
+
self.h.win_size,
|
101 |
+
self.h.hop_size,
|
102 |
+
self.h.fmin,
|
103 |
+
self.h.fmax)
|
104 |
+
|
105 |
+
def sample_rate(self):
|
106 |
+
return self.h.sampling_rate
|
107 |
+
|
108 |
+
def hop_size(self):
|
109 |
+
return self.h.hop_size
|
110 |
+
|
111 |
+
def forward(self, audio, f0):
|
112 |
+
with torch.no_grad():
|
113 |
+
mel = self.stft.get_mel(audio)
|
114 |
+
enhanced_audio = self.model(mel, f0[:,:mel.size(-1)])
|
115 |
+
return enhanced_audio, self.h.sampling_rate
|
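A numeric sketch (not part of the commit) of the adaptive-key resampling in Enhancer.enhance above: the DDSP output is resampled so the enhancer effectively hears it transposed down by roughly adaptive_key semitones, which keeps very high notes inside the vocoder's range; the 44.1 kHz enhancer rate is an assumption.

import numpy as np

enhancer_sample_rate = 44100   # assumed NSF-HiFiGAN rate
adaptive_key = 4               # semitones of extra headroom
adaptive_factor = 2 ** (-adaptive_key / 12)
adaptive_sample_rate = 100 * int(np.round(enhancer_sample_rate / adaptive_factor / 100))
real_factor = enhancer_sample_rate / adaptive_sample_rate
print(adaptive_sample_rate, round(real_factor, 4))   # 55600 0.7932 (f0 is scaled by real_factor)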
DDSP-SVC/exp/.gitignore
ADDED
@@ -0,0 +1,2 @@
1 |
+
*
|
2 |
+
!.gitignore
|
DDSP-SVC/flask_api.py
ADDED
@@ -0,0 +1,178 @@
1 |
+
import io
|
2 |
+
import logging
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import slicer
|
6 |
+
import soundfile as sf
|
7 |
+
import librosa
|
8 |
+
from flask import Flask, request, send_file
|
9 |
+
from flask_cors import CORS
|
10 |
+
|
11 |
+
from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
|
12 |
+
from ddsp.core import upsample
|
13 |
+
from enhancer import Enhancer
|
14 |
+
|
15 |
+
|
16 |
+
app = Flask(__name__)
|
17 |
+
|
18 |
+
CORS(app)
|
19 |
+
|
20 |
+
logging.getLogger("numba").setLevel(logging.WARNING)
|
21 |
+
|
22 |
+
|
23 |
+
@app.route("/voiceChangeModel", methods=["POST"])
|
24 |
+
def voice_change_model():
|
25 |
+
request_form = request.form
|
26 |
+
wave_file = request.files.get("sample", None)
|
27 |
+
# get fSafePrefixPadLength
|
28 |
+
f_safe_prefix_pad_length = float(request_form.get("fSafePrefixPadLength", 0))
|
29 |
+
print("f_safe_prefix_pad_length:"+str(f_safe_prefix_pad_length))
|
30 |
+
# Pitch-shift information
|
31 |
+
f_pitch_change = float(request_form.get("fPitchChange", 0))
|
32 |
+
# Get the spk_id
|
33 |
+
int_speak_id = int(request_form.get("sSpeakId", 0))
|
34 |
+
if enable_spk_id_cover:
|
35 |
+
int_speak_id = spk_id
|
36 |
+
# print("说话人:" + str(int_speak_id))
|
37 |
+
# Sample rate required by the DAW
|
38 |
+
daw_sample = int(float(request_form.get("sampleRate", 0)))
|
39 |
+
# Get the wav file over HTTP and convert it
|
40 |
+
input_wav_read = io.BytesIO(wave_file.read())
|
41 |
+
# Model inference
|
42 |
+
_audio, _model_sr = svc_model.infer(input_wav_read, f_pitch_change, int_speak_id, f_safe_prefix_pad_length)
|
43 |
+
tar_audio = librosa.resample(_audio, _model_sr, daw_sample)
|
44 |
+
# Return the audio
|
45 |
+
out_wav_path = io.BytesIO()
|
46 |
+
sf.write(out_wav_path, tar_audio, daw_sample, format="wav")
|
47 |
+
out_wav_path.seek(0)
|
48 |
+
return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)
|
49 |
+
|
50 |
+
|
51 |
+
class SvcDDSP:
|
52 |
+
def __init__(self, model_path, vocoder_based_enhancer, enhancer_adaptive_key, input_pitch_extractor,
|
53 |
+
f0_min, f0_max, threhold, spk_id, spk_mix_dict, enable_spk_id_cover):
|
54 |
+
self.model_path = model_path
|
55 |
+
self.vocoder_based_enhancer = vocoder_based_enhancer
|
56 |
+
self.enhancer_adaptive_key = enhancer_adaptive_key
|
57 |
+
self.input_pitch_extractor = input_pitch_extractor
|
58 |
+
self.f0_min = f0_min
|
59 |
+
self.f0_max = f0_max
|
60 |
+
self.threhold = threhold
|
61 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
62 |
+
self.spk_id = spk_id
|
63 |
+
self.spk_mix_dict = spk_mix_dict
|
64 |
+
self.enable_spk_id_cover = enable_spk_id_cover
|
65 |
+
|
66 |
+
# load ddsp model
|
67 |
+
self.model, self.args = load_model(self.model_path, device=self.device)
|
68 |
+
|
69 |
+
# load units encoder
|
70 |
+
if self.args.data.encoder == 'cnhubertsoftfish':
|
71 |
+
cnhubertsoft_gate = self.args.data.cnhubertsoft_gate
|
72 |
+
else:
|
73 |
+
cnhubertsoft_gate = 10
|
74 |
+
self.units_encoder = Units_Encoder(
|
75 |
+
self.args.data.encoder,
|
76 |
+
self.args.data.encoder_ckpt,
|
77 |
+
self.args.data.encoder_sample_rate,
|
78 |
+
self.args.data.encoder_hop_size,
|
79 |
+
cnhubertsoft_gate=cnhubertsoft_gate,
|
80 |
+
device=self.device)
|
81 |
+
|
82 |
+
# load enhancer
|
83 |
+
if self.vocoder_based_enhancer:
|
84 |
+
self.enhancer = Enhancer(self.args.enhancer.type, self.args.enhancer.ckpt, device=self.device)
|
85 |
+
|
86 |
+
def infer(self, input_wav, pitch_adjust, speaker_id, safe_prefix_pad_length):
|
87 |
+
print("Infer!")
|
88 |
+
# load input
|
89 |
+
audio, sample_rate = librosa.load(input_wav, sr=None, mono=True)
|
90 |
+
if len(audio.shape) > 1:
|
91 |
+
audio = librosa.to_mono(audio)
|
92 |
+
hop_size = self.args.data.block_size * sample_rate / self.args.data.sampling_rate
|
93 |
+
|
94 |
+
# safe front silence
|
95 |
+
if safe_prefix_pad_length > 0.03:
|
96 |
+
silence_front = safe_prefix_pad_length - 0.03
|
97 |
+
else:
|
98 |
+
silence_front = 0
|
99 |
+
|
100 |
+
# extract f0
|
101 |
+
pitch_extractor = F0_Extractor(
|
102 |
+
self.input_pitch_extractor,
|
103 |
+
sample_rate,
|
104 |
+
hop_size,
|
105 |
+
float(self.f0_min),
|
106 |
+
float(self.f0_max))
|
107 |
+
f0 = pitch_extractor.extract(audio, uv_interp=True, device=self.device, silence_front=silence_front)
|
108 |
+
f0 = torch.from_numpy(f0).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
109 |
+
f0 = f0 * 2 ** (float(pitch_adjust) / 12)
|
110 |
+
|
111 |
+
# extract volume
|
112 |
+
volume_extractor = Volume_Extractor(hop_size)
|
113 |
+
volume = volume_extractor.extract(audio)
|
114 |
+
mask = (volume > 10 ** (float(self.threhold) / 20)).astype('float')
|
115 |
+
mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
|
116 |
+
mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)])
|
117 |
+
mask = torch.from_numpy(mask).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
118 |
+
mask = upsample(mask, self.args.data.block_size).squeeze(-1)
|
119 |
+
volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
120 |
+
|
121 |
+
# extract units
|
122 |
+
audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
|
123 |
+
units = self.units_encoder.encode(audio_t, sample_rate, hop_size)
|
124 |
+
|
125 |
+
# spk_id or spk_mix_dict
|
126 |
+
if self.enable_spk_id_cover:
|
127 |
+
spk_id = self.spk_id
|
128 |
+
else:
|
129 |
+
spk_id = speaker_id
|
130 |
+
spk_id = torch.LongTensor(np.array([[spk_id]])).to(self.device)
|
131 |
+
|
132 |
+
# forward and return the output
|
133 |
+
with torch.no_grad():
|
134 |
+
output, _, (s_h, s_n) = self.model(units, f0, volume, spk_id = spk_id, spk_mix_dict = self.spk_mix_dict)
|
135 |
+
output *= mask
|
136 |
+
if self.vocoder_based_enhancer:
|
137 |
+
output, output_sample_rate = self.enhancer.enhance(
|
138 |
+
output,
|
139 |
+
self.args.data.sampling_rate,
|
140 |
+
f0,
|
141 |
+
self.args.data.block_size,
|
142 |
+
adaptive_key = self.enhancer_adaptive_key,
|
143 |
+
silence_front = silence_front)
|
144 |
+
else:
|
145 |
+
output_sample_rate = self.args.data.sampling_rate
|
146 |
+
|
147 |
+
output = output.squeeze().cpu().numpy()
|
148 |
+
return output, output_sample_rate
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == "__main__":
|
152 |
+
# For ddsp-svc, only the parameters below need to be provided.
|
153 |
+
# Designed to work with 串串香火锅's https://github.com/zhaohui8969/VST_NetProcess-. The latest version is recommended.
|
154 |
+
# The Flask part comes from code written by 小狼 for diff-svc.
|
155 |
+
# The config file must be in the same directory as the model.
|
156 |
+
checkpoint_path = "exp/multi_speaker/model_300000.pt"
|
157 |
+
# Whether to enhance the output with the pretrained vocoder-based enhancer; this raises the hardware requirements.
|
158 |
+
use_vocoder_based_enhancer = True
|
159 |
+
# Used together with the enhancer: 0 gives high audio quality within the normal vocal range (up to G5); values above 0 prevent very high notes from cracking.
|
160 |
+
enhancer_adaptive_key = 0
|
161 |
+
# f0 extractor; options are parselmouth, dio, harvest, crepe
|
162 |
+
select_pitch_extractor = 'crepe'
|
163 |
+
# f0 range limits (Hz)
|
164 |
+
limit_f0_min = 50
|
165 |
+
limit_f0_max = 1100
|
166 |
+
# Volume response threshold (dB)
|
167 |
+
threhold = -60
|
168 |
+
# Default speaker, and whether the default speaker takes priority over the value passed in from the VST.
|
169 |
+
spk_id = 1
|
170 |
+
enable_spk_id_cover = True
|
171 |
+
# Speaker mix dictionary (timbre blending feature)
|
172 |
+
# Setting this to a non-None dictionary overrides spk_id
|
173 |
+
spk_mix_dict = None # e.g. {1:0.5, 2:0.5} mixes the timbres of speaker 1 and speaker 2 in a 0.5:0.5 ratio
|
174 |
+
svc_model = SvcDDSP(checkpoint_path, use_vocoder_based_enhancer, enhancer_adaptive_key, select_pitch_extractor,
|
175 |
+
limit_f0_min, limit_f0_max, threhold, spk_id, spk_mix_dict, enable_spk_id_cover)
|
176 |
+
|
177 |
+
# This corresponds to the VST plugin; the port must match.
|
178 |
+
app.run(port=6844, host="0.0.0.0", debug=False, threaded=False)
|
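A hypothetical client for the /voiceChangeModel endpoint above (not part of the commit), using only the form fields the handler reads; the wav path and the local port are assumptions:

import requests

with open("input.wav", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:6844/voiceChangeModel",
        files={"sample": f},
        data={
            "fPitchChange": 0,          # semitone shift
            "sSpeakId": 1,              # ignored while enable_spk_id_cover is True
            "sampleRate": 44100,        # sample rate the DAW expects back
            "fSafePrefixPadLength": 0,  # seconds of safe leading padding
        },
    )
with open("output.wav", "wb") as f:
    f.write(resp.content)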
DDSP-SVC/gui.py
ADDED
@@ -0,0 +1,483 @@
1 |
+
import PySimpleGUI as sg
|
2 |
+
import sounddevice as sd
|
3 |
+
import torch, librosa, threading, pickle
|
4 |
+
from enhancer import Enhancer
|
5 |
+
import numpy as np
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from torchaudio.transforms import Resample
|
8 |
+
from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
|
9 |
+
from ddsp.core import upsample
|
10 |
+
import time
|
11 |
+
import gui_locale
|
12 |
+
|
13 |
+
|
14 |
+
def phase_vocoder(a, b, fade_out, fade_in):
|
15 |
+
fa = torch.fft.rfft(a)
|
16 |
+
fb = torch.fft.rfft(b)
|
17 |
+
absab = torch.abs(fa) + torch.abs(fb)
|
18 |
+
n = a.shape[0]
|
19 |
+
if n % 2 == 0:
|
20 |
+
absab[1:-1] *= 2
|
21 |
+
else:
|
22 |
+
absab[1:] *= 2
|
23 |
+
phia = torch.angle(fa)
|
24 |
+
phib = torch.angle(fb)
|
25 |
+
deltaphase = phib - phia
|
26 |
+
deltaphase = deltaphase - 2 * np.pi * torch.floor(deltaphase / 2 / np.pi + 0.5)
|
27 |
+
w = 2 * np.pi * torch.arange(n // 2 + 1).to(a) + deltaphase
|
28 |
+
t = torch.arange(n).unsqueeze(-1).to(a) / n
|
29 |
+
result = a * (fade_out ** 2) + b * (fade_in ** 2) + torch.sum(absab * torch.cos(w * t + phia),
|
30 |
+
-1) * fade_out * fade_in / n
|
31 |
+
return result
|
32 |
+
|
33 |
+
|
34 |
+
class SvcDDSP:
|
35 |
+
def __init__(self) -> None:
|
36 |
+
self.model = None
|
37 |
+
self.units_encoder = None
|
38 |
+
self.encoder_type = None
|
39 |
+
self.encoder_ckpt = None
|
40 |
+
self.enhancer = None
|
41 |
+
self.enhancer_type = None
|
42 |
+
self.enhancer_ckpt = None
|
43 |
+
|
44 |
+
def update_model(self, model_path):
|
45 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
46 |
+
|
47 |
+
# load ddsp model
|
48 |
+
if self.model is None or self.model_path != model_path:
|
49 |
+
self.model, self.args = load_model(model_path, device=self.device)
|
50 |
+
self.model_path = model_path
|
51 |
+
|
52 |
+
# load units encoder
|
53 |
+
if self.units_encoder is None or self.args.data.encoder != self.encoder_type or self.args.data.encoder_ckpt != self.encoder_ckpt:
|
54 |
+
if self.args.data.encoder == 'cnhubertsoftfish':
|
55 |
+
cnhubertsoft_gate = self.args.data.cnhubertsoft_gate
|
56 |
+
else:
|
57 |
+
cnhubertsoft_gate = 10
|
58 |
+
self.units_encoder = Units_Encoder(
|
59 |
+
self.args.data.encoder,
|
60 |
+
self.args.data.encoder_ckpt,
|
61 |
+
self.args.data.encoder_sample_rate,
|
62 |
+
self.args.data.encoder_hop_size,
|
63 |
+
cnhubertsoft_gate=cnhubertsoft_gate,
|
64 |
+
device=self.device)
|
65 |
+
self.encoder_type = self.args.data.encoder
|
66 |
+
self.encoder_ckpt = self.args.data.encoder_ckpt
|
67 |
+
|
68 |
+
# load enhancer
|
69 |
+
if self.enhancer is None or self.args.enhancer.type != self.enhancer_type or self.args.enhancer.ckpt != self.enhancer_ckpt:
|
70 |
+
self.enhancer = Enhancer(self.args.enhancer.type, self.args.enhancer.ckpt, device=self.device)
|
71 |
+
self.enhancer_type = self.args.enhancer.type
|
72 |
+
self.enhancer_ckpt = self.args.enhancer.ckpt
|
73 |
+
|
74 |
+
def infer(self,
|
75 |
+
audio,
|
76 |
+
sample_rate,
|
77 |
+
spk_id=1,
|
78 |
+
threhold=-45,
|
79 |
+
pitch_adjust=0,
|
80 |
+
use_spk_mix=False,
|
81 |
+
spk_mix_dict=None,
|
82 |
+
use_enhancer=True,
|
83 |
+
enhancer_adaptive_key='auto',
|
84 |
+
pitch_extractor_type='crepe',
|
85 |
+
f0_min=50,
|
86 |
+
f0_max=1100,
|
87 |
+
safe_prefix_pad_length=0,
|
88 |
+
):
|
89 |
+
print("Infering...")
|
90 |
+
# load input
|
91 |
+
# audio, sample_rate = librosa.load(input_wav, sr=None, mono=True)
|
92 |
+
hop_size = self.args.data.block_size * sample_rate / self.args.data.sampling_rate
|
93 |
+
# safe front silence
|
94 |
+
if safe_prefix_pad_length > 0.03:
|
95 |
+
silence_front = safe_prefix_pad_length - 0.03
|
96 |
+
else:
|
97 |
+
silence_front = 0
|
98 |
+
|
99 |
+
# extract f0
|
100 |
+
pitch_extractor = F0_Extractor(
|
101 |
+
pitch_extractor_type,
|
102 |
+
sample_rate,
|
103 |
+
hop_size,
|
104 |
+
float(f0_min),
|
105 |
+
float(f0_max))
|
106 |
+
f0 = pitch_extractor.extract(audio, uv_interp=True, device=self.device, silence_front=silence_front)
|
107 |
+
f0 = torch.from_numpy(f0).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
108 |
+
f0 = f0 * 2 ** (float(pitch_adjust) / 12)
|
109 |
+
|
110 |
+
# extract volume
|
111 |
+
volume_extractor = Volume_Extractor(hop_size)
|
112 |
+
volume = volume_extractor.extract(audio)
|
113 |
+
mask = (volume > 10 ** (float(threhold) / 20)).astype('float')
|
114 |
+
mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
|
115 |
+
mask = np.array([np.max(mask[n: n + 9]) for n in range(len(mask) - 8)])
|
116 |
+
mask = torch.from_numpy(mask).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
117 |
+
mask = upsample(mask, self.args.data.block_size).squeeze(-1)
|
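# The mask above acts as a volume gate: frames whose level exceeds the dB
# threshold are kept, the gate is dilated by 4 frames on each side via a
# 9-frame max filter, and the result is upsampled to sample resolution so it
# can be multiplied onto the model output.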
118 |
+
volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
119 |
+
|
120 |
+
# extract units
|
121 |
+
audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
|
122 |
+
units = self.units_encoder.encode(audio_t, sample_rate, hop_size)
|
123 |
+
|
124 |
+
# spk_id or spk_mix_dict
|
125 |
+
spk_id = torch.LongTensor(np.array([[spk_id]])).to(self.device)
|
126 |
+
dictionary = None
|
127 |
+
if use_spk_mix:
|
128 |
+
dictionary = spk_mix_dict
|
129 |
+
|
130 |
+
# forward and return the output
|
131 |
+
with torch.no_grad():
|
132 |
+
output, _, (s_h, s_n) = self.model(units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary)
|
133 |
+
output *= mask
|
134 |
+
if use_enhancer:
|
135 |
+
output, output_sample_rate = self.enhancer.enhance(
|
136 |
+
output,
|
137 |
+
self.args.data.sampling_rate,
|
138 |
+
f0,
|
139 |
+
self.args.data.block_size,
|
140 |
+
adaptive_key=enhancer_adaptive_key,
|
141 |
+
silence_front=silence_front)
|
142 |
+
else:
|
143 |
+
output_sample_rate = self.args.data.sampling_rate
|
144 |
+
|
145 |
+
output = output.squeeze()
|
146 |
+
return output, output_sample_rate
|
147 |
+
|
148 |
+
|
149 |
+
class Config:
|
150 |
+
def __init__(self) -> None:
|
151 |
+
self.samplerate = 44100 # Hz
|
152 |
+
self.block_time = 1.5 # s
|
153 |
+
self.f_pitch_change: float = 0.0 # float(request_form.get("fPitchChange", 0))
|
154 |
+
self.spk_id = 1 # default speaker ID
|
155 |
+
self.spk_mix_dict = None # {1:0.5, 2:0.5} mixes the timbres of speaker 1 and speaker 2 at a 0.5:0.5 ratio
|
156 |
+
self.use_vocoder_based_enhancer = True
|
157 |
+
self.use_phase_vocoder = True
|
158 |
+
self.checkpoint_path = ''
|
159 |
+
self.threhold = -35
|
160 |
+
self.buffer_num = 2
|
161 |
+
self.crossfade_time = 0.03
|
162 |
+
self.select_pitch_extractor = 'harvest' # F0 extractor: ["parselmouth", "dio", "harvest", "crepe"]
|
163 |
+
self.use_spk_mix = False
|
164 |
+
self.sounddevices = ['', '']
|
165 |
+
|
166 |
+
def save(self, path):
|
167 |
+
with open(path + '\\config.pkl', 'wb') as f:
|
168 |
+
pickle.dump(vars(self), f)
|
169 |
+
|
170 |
+
def load(self, path) -> bool:
|
171 |
+
try:
|
172 |
+
with open(path + '\\config.pkl', 'rb') as f:
|
173 |
+
self.update(pickle.load(f))
|
174 |
+
return True
|
175 |
+
except:
|
176 |
+
print('config.pkl does not exist')
|
177 |
+
return False
|
178 |
+
|
179 |
+
def update(self, data_dict):
|
180 |
+
for key, value in data_dict.items():
|
181 |
+
setattr(self, key, value)
|
182 |
+
|
183 |
+
|
184 |
+
class GUI:
|
185 |
+
def __init__(self) -> None:
|
186 |
+
self.config = Config()
|
187 |
+
self.flag_vc: bool = False # voice conversion thread flag
|
188 |
+
self.block_frame = 0
|
189 |
+
self.crossfade_frame = 0
|
190 |
+
self.sola_search_frame = 0
|
191 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
192 |
+
self.svc_model: SvcDDSP = SvcDDSP()
|
193 |
+
self.fade_in_window: np.ndarray = None # numpy array used for crossfade computation
|
194 |
+
self.fade_out_window: np.ndarray = None # numpy array used for crossfade computation
|
195 |
+
self.input_wav: np.ndarray = None # buffer holding the normalized input audio
|
196 |
+
self.output_wav: np.ndarray = None # buffer holding the normalized output audio
|
197 |
+
self.sola_buffer: torch.Tensor = None # stores the crossfade tail of the previous output
|
198 |
+
self.f0_mode_list = ["parselmouth", "dio", "harvest", "crepe"] # F0 extractors
|
199 |
+
self.f_safe_prefix_pad_length: float = 0.0
|
200 |
+
self.resample_kernel = {}
|
201 |
+
self.launcher() # start
|
202 |
+
|
203 |
+
def launcher(self):
|
204 |
+
'''Load the window'''
|
205 |
+
input_devices, output_devices, _, _ = self.get_devices()
|
206 |
+
sg.theme('DarkAmber') # set theme
|
207 |
+
# UI layout
|
208 |
+
layout = [
|
209 |
+
[sg.Frame(layout=[
|
210 |
+
[sg.Input(key='sg_model', default_text='exp\\multi_speaker\\model_300000.pt'),
|
211 |
+
sg.FileBrowse(i18n('选择模型文件'), key='choose_model')]
|
212 |
+
], title=i18n('模型:.pt格式(自动识别同目录下config.yaml)')),
|
213 |
+
sg.Frame(layout=[
|
214 |
+
[sg.Text(i18n('选择配置文件所在目录')), sg.Input(key='config_file_dir', default_text='exp'),
|
215 |
+
sg.FolderBrowse(i18n('打开文件夹'), key='choose_config')],
|
216 |
+
[sg.Button(i18n('读取配置文件'), key='load_config'), sg.Button(i18n('保存配置文件'), key='save_config')]
|
217 |
+
], title=i18n('快速配置文件'))
|
218 |
+
],
|
219 |
+
[sg.Frame(layout=[
|
220 |
+
[sg.Text(i18n("输入设备")),
|
221 |
+
sg.Combo(input_devices, key='sg_input_device', default_value=input_devices[sd.default.device[0]],
|
222 |
+
enable_events=True)],
|
223 |
+
[sg.Text(i18n("输出设备")),
|
224 |
+
sg.Combo(output_devices, key='sg_output_device', default_value=output_devices[sd.default.device[1]],
|
225 |
+
enable_events=True)]
|
226 |
+
], title=i18n('音频设备'))
|
227 |
+
],
|
228 |
+
[sg.Frame(layout=[
|
229 |
+
[sg.Text(i18n("说话人id")), sg.Input(key='spk_id', default_text='1')],
|
230 |
+
[sg.Text(i18n("响应阈值")),
|
231 |
+
sg.Slider(range=(-60, 0), orientation='h', key='threhold', resolution=1, default_value=-45,
|
232 |
+
enable_events=True)],
|
233 |
+
[sg.Text(i18n("变调")),
|
234 |
+
sg.Slider(range=(-24, 24), orientation='h', key='pitch', resolution=1, default_value=0,
|
235 |
+
enable_events=True)],
|
236 |
+
[sg.Text(i18n("采样率")), sg.Input(key='samplerate', default_text='44100')],
|
237 |
+
[sg.Checkbox(text=i18n('启用捏音色功能'), default=False, key='spk_mix', enable_events=True),
|
238 |
+
sg.Button(i18n("设置混合音色"), key='set_spk_mix')]
|
239 |
+
], title=i18n('普通设置')),
|
240 |
+
sg.Frame(layout=[
|
241 |
+
[sg.Text(i18n("音频切分大小")),
|
242 |
+
sg.Slider(range=(0.05, 3.0), orientation='h', key='block', resolution=0.01, default_value=0.3,
|
243 |
+
enable_events=True)],
|
244 |
+
[sg.Text(i18n("交叉淡化时长")),
|
245 |
+
sg.Slider(range=(0.01, 0.15), orientation='h', key='crossfade', resolution=0.01,
|
246 |
+
default_value=0.04, enable_events=True)],
|
247 |
+
[sg.Text(i18n("使用历史区块数量")),
|
248 |
+
sg.Slider(range=(1, 20), orientation='h', key='buffernum', resolution=1, default_value=4,
|
249 |
+
enable_events=True)],
|
250 |
+
[sg.Text(i18n("f0预测模式")),
|
251 |
+
sg.Combo(values=self.f0_mode_list, key='f0_mode', default_value=self.f0_mode_list[2],
|
252 |
+
enable_events=True)],
|
253 |
+
[sg.Checkbox(text=i18n('启用增强器'), default=True, key='use_enhancer', enable_events=True),
|
254 |
+
sg.Checkbox(text=i18n('启用相位声码器'), default=False, key='use_phase_vocoder', enable_events=True)]
|
255 |
+
], title=i18n('性能设置')),
|
256 |
+
],
|
257 |
+
[sg.Button(i18n("开始音频转换"), key="start_vc"), sg.Button(i18n("停止音频转换"), key="stop_vc"),
|
258 |
+
sg.Text(i18n('推理所用时间(ms):')), sg.Text('0', key='infer_time')]
|
259 |
+
]
|
260 |
+
|
261 |
+
# create the window
|
262 |
+
self.window = sg.Window('DDSP - GUI', layout, finalize=True)
|
263 |
+
self.window['spk_id'].bind('<Return>', '')
|
264 |
+
self.window['samplerate'].bind('<Return>', '')
|
265 |
+
self.event_handler()
|
266 |
+
|
267 |
+
def event_handler(self):
|
268 |
+
'''Event handling'''
|
269 |
+
while True: # event loop
|
270 |
+
event, values = self.window.read()
|
271 |
+
print('event: ' + str(event))
|
272 |
+
if event == sg.WINDOW_CLOSED: # the user closed the window
|
273 |
+
self.flag_vc = False
|
274 |
+
exit()
|
275 |
+
elif event == 'start_vc' and self.flag_vc == False:
|
276 |
+
# set_values follows the same order as the layout
|
277 |
+
self.set_values(values)
|
278 |
+
print('crossfade_time:' + str(self.config.crossfade_time))
|
279 |
+
print("buffer_num:" + str(self.config.buffer_num))
|
280 |
+
print("samplerate:" + str(self.config.samplerate))
|
281 |
+
print('block_time:' + str(self.config.block_time))
|
282 |
+
print("prefix_pad_length:" + str(self.f_safe_prefix_pad_length))
|
283 |
+
print("mix_mode:" + str(self.config.spk_mix_dict))
|
284 |
+
print("enhancer:" + str(self.config.use_vocoder_based_enhancer))
|
285 |
+
print('using_cuda:' + str(torch.cuda.is_available()))
|
286 |
+
self.start_vc()
|
287 |
+
elif event == 'spk_id':
|
288 |
+
self.config.spk_id = int(values['spk_id'])
|
289 |
+
elif event == 'threhold':
|
290 |
+
self.config.threhold = values['threhold']
|
291 |
+
elif event == 'pitch':
|
292 |
+
self.config.f_pitch_change = values['pitch']
|
293 |
+
elif event == 'spk_mix':
|
294 |
+
self.config.use_spk_mix = values['spk_mix']
|
295 |
+
elif event == 'set_spk_mix':
|
296 |
+
spk_mix = sg.popup_get_text(message='示例:1:0.3,2:0.5,3:0.2', title="设置混合音色,支持多人")
|
297 |
+
if spk_mix != None:
|
298 |
+
self.config.spk_mix_dict = eval("{" + spk_mix.replace(',', ',').replace(':', ':') + "}")
|
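# The popup string (e.g. "1:0.3,2:0.5,3:0.2") has full-width commas/colons
# normalised to ASCII and is then parsed with eval() into a
# speaker-id -> weight dictionary.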
299 |
+
elif event == 'f0_mode':
|
300 |
+
self.config.select_pitch_extractor = values['f0_mode']
|
301 |
+
elif event == 'use_enhancer':
|
302 |
+
self.config.use_vocoder_based_enhancer = values['use_enhancer']
|
303 |
+
elif event == 'use_phase_vocoder':
|
304 |
+
self.config.use_phase_vocoder = values['use_phase_vocoder']
|
305 |
+
elif event == 'load_config' and self.flag_vc == False:
|
306 |
+
if self.config.load(values['config_file_dir']):
|
307 |
+
self.update_values()
|
308 |
+
elif event == 'save_config' and self.flag_vc == False:
|
309 |
+
self.set_values(values)
|
310 |
+
self.config.save(values['config_file_dir'])
|
311 |
+
elif event != 'start_vc' and self.flag_vc == True:
|
312 |
+
self.flag_vc = False
|
313 |
+
|
314 |
+
def set_values(self, values):
|
315 |
+
self.set_devices(values["sg_input_device"], values['sg_output_device'])
|
316 |
+
self.config.sounddevices = [values["sg_input_device"], values['sg_output_device']]
|
317 |
+
self.config.checkpoint_path = values['sg_model']
|
318 |
+
self.config.spk_id = int(values['spk_id'])
|
319 |
+
self.config.threhold = values['threhold']
|
320 |
+
self.config.f_pitch_change = values['pitch']
|
321 |
+
self.config.samplerate = int(values['samplerate'])
|
322 |
+
self.config.block_time = float(values['block'])
|
323 |
+
self.config.crossfade_time = float(values['crossfade'])
|
324 |
+
self.config.buffer_num = int(values['buffernum'])
|
325 |
+
self.config.select_pitch_extractor = values['f0_mode']
|
326 |
+
self.config.use_vocoder_based_enhancer = values['use_enhancer']
|
327 |
+
self.config.use_phase_vocoder = values['use_phase_vocoder']
|
328 |
+
self.config.use_spk_mix = values['spk_mix']
|
329 |
+
self.block_frame = int(self.config.block_time * self.config.samplerate)
|
330 |
+
self.crossfade_frame = int(self.config.crossfade_time * self.config.samplerate)
|
331 |
+
self.sola_search_frame = int(0.01 * self.config.samplerate)
|
332 |
+
self.last_delay_frame = int(0.02 * self.config.samplerate)
|
333 |
+
self.input_frames = max(
|
334 |
+
self.block_frame + self.crossfade_frame + self.sola_search_frame + 2 * self.last_delay_frame,
|
335 |
+
(1 + self.config.buffer_num) * self.block_frame)
|
336 |
+
self.f_safe_prefix_pad_length = self.config.block_time * self.config.buffer_num - self.config.crossfade_time - 0.01 - 0.02
|
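# input_frames keeps enough history for one block plus the crossfade, SOLA
# search and delay margins, or buffer_num extra blocks, whichever is larger;
# the safe prefix pad is that buffered history minus the crossfade, SOLA
# search (0.01 s) and delay (0.02 s) times.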
337 |
+
|
338 |
+
def update_values(self):
|
339 |
+
self.window['sg_model'].update(self.config.checkpoint_path)
|
340 |
+
self.window['sg_input_device'].update(self.config.sounddevices[0])
|
341 |
+
self.window['sg_output_device'].update(self.config.sounddevices[1])
|
342 |
+
self.window['spk_id'].update(self.config.spk_id)
|
343 |
+
self.window['threhold'].update(self.config.threhold)
|
344 |
+
self.window['pitch'].update(self.config.f_pitch_change)
|
345 |
+
self.window['samplerate'].update(self.config.samplerate)
|
346 |
+
self.window['spk_mix'].update(self.config.use_spk_mix)
|
347 |
+
self.window['block'].update(self.config.block_time)
|
348 |
+
self.window['crossfade'].update(self.config.crossfade_time)
|
349 |
+
self.window['buffernum'].update(self.config.buffer_num)
|
350 |
+
self.window['f0_mode'].update(self.config.select_pitch_extractor)
|
351 |
+
self.window['use_enhancer'].update(self.config.use_vocoder_based_enhancer)
|
352 |
+
|
353 |
+
def start_vc(self):
|
354 |
+
'''Start audio conversion'''
|
355 |
+
torch.cuda.empty_cache()
|
356 |
+
self.flag_vc = True
|
357 |
+
self.input_wav = np.zeros(self.input_frames, dtype='float32')
|
358 |
+
self.sola_buffer = torch.zeros(self.crossfade_frame, device=self.device)
|
359 |
+
self.fade_in_window = torch.sin(
|
360 |
+
np.pi * torch.arange(0, 1, 1 / self.crossfade_frame, device=self.device) / 2) ** 2
|
361 |
+
self.fade_out_window = 1 - self.fade_in_window
|
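# fade_in_window is a squared half-sine ramp and fade_out_window its
# complement, so the two windows sum to 1 across the crossfade region.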
362 |
+
self.svc_model.update_model(self.config.checkpoint_path)
|
363 |
+
thread_vc = threading.Thread(target=self.soundinput)
|
364 |
+
thread_vc.start()
|
365 |
+
|
366 |
+
def soundinput(self):
|
367 |
+
'''
|
368 |
+
Receive audio input
|
369 |
+
'''
|
370 |
+
with sd.Stream(callback=self.audio_callback, blocksize=self.block_frame, samplerate=self.config.samplerate,
|
371 |
+
dtype='float32'):
|
372 |
+
while self.flag_vc:
|
373 |
+
time.sleep(self.config.block_time)
|
374 |
+
print('Audio block passed.')
|
375 |
+
print('ENDing VC')
|
376 |
+
|
377 |
+
def audio_callback(self, indata: np.ndarray, outdata: np.ndarray, frames, times, status):
|
378 |
+
'''
|
379 |
+
Audio processing
|
380 |
+
'''
|
381 |
+
start_time = time.perf_counter()
|
382 |
+
print("\nStarting callback")
|
383 |
+
self.input_wav[:] = np.roll(self.input_wav, -self.block_frame)
|
384 |
+
self.input_wav[-self.block_frame:] = librosa.to_mono(indata.T)
|
385 |
+
|
386 |
+
# infer
|
387 |
+
_audio, _model_sr = self.svc_model.infer(
|
388 |
+
self.input_wav,
|
389 |
+
self.config.samplerate,
|
390 |
+
spk_id=self.config.spk_id,
|
391 |
+
threhold=self.config.threhold,
|
392 |
+
pitch_adjust=self.config.f_pitch_change,
|
393 |
+
use_spk_mix=self.config.use_spk_mix,
|
394 |
+
spk_mix_dict=self.config.spk_mix_dict,
|
395 |
+
use_enhancer=self.config.use_vocoder_based_enhancer,
|
396 |
+
pitch_extractor_type=self.config.select_pitch_extractor,
|
397 |
+
safe_prefix_pad_length=self.f_safe_prefix_pad_length,
|
398 |
+
)
|
399 |
+
|
400 |
+
# debug sola
|
401 |
+
'''
|
402 |
+
_audio, _model_sr = self.input_wav, self.config.samplerate
|
403 |
+
rs = int(np.random.uniform(-200,200))
|
404 |
+
print('debug_random_shift: ' + str(rs))
|
405 |
+
_audio = np.roll(_audio, rs)
|
406 |
+
_audio = torch.from_numpy(_audio).to(self.device)
|
407 |
+
'''
|
408 |
+
|
409 |
+
if _model_sr != self.config.samplerate:
|
410 |
+
key_str = str(_model_sr) + '_' + str(self.config.samplerate)
|
411 |
+
if key_str not in self.resample_kernel:
|
412 |
+
self.resample_kernel[key_str] = Resample(_model_sr, self.config.samplerate,
|
413 |
+
lowpass_filter_width=128).to(self.device)
|
414 |
+
_audio = self.resample_kernel[key_str](_audio)
|
415 |
+
temp_wav = _audio[
|
416 |
+
- self.block_frame - self.crossfade_frame - self.sola_search_frame - self.last_delay_frame: - self.last_delay_frame]
|
417 |
+
|
418 |
+
# sola shift
|
419 |
+
conv_input = temp_wav[None, None, : self.crossfade_frame + self.sola_search_frame]
|
420 |
+
cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
|
421 |
+
cor_den = torch.sqrt(
|
422 |
+
F.conv1d(conv_input ** 2, torch.ones(1, 1, self.crossfade_frame, device=self.device)) + 1e-8)
|
423 |
+
sola_shift = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
424 |
+
temp_wav = temp_wav[sola_shift: sola_shift + self.block_frame + self.crossfade_frame]
|
425 |
+
print('sola_shift: ' + str(int(sola_shift)))
|
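# SOLA alignment: the normalised cross-correlation between the new chunk and
# the saved sola_buffer is maximised to find the offset where the two
# waveforms line up best before crossfading.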
426 |
+
|
427 |
+
# phase vocoder
|
428 |
+
if self.config.use_phase_vocoder:
|
429 |
+
temp_wav[: self.crossfade_frame] = phase_vocoder(
|
430 |
+
self.sola_buffer,
|
431 |
+
temp_wav[: self.crossfade_frame],
|
432 |
+
self.fade_out_window,
|
433 |
+
self.fade_in_window)
|
434 |
+
else:
|
435 |
+
temp_wav[: self.crossfade_frame] *= self.fade_in_window
|
436 |
+
temp_wav[: self.crossfade_frame] += self.sola_buffer * self.fade_out_window
|
437 |
+
|
438 |
+
self.sola_buffer = temp_wav[- self.crossfade_frame:]
|
439 |
+
|
440 |
+
outdata[:] = temp_wav[: - self.crossfade_frame, None].repeat(1, 2).cpu().numpy()
|
441 |
+
end_time = time.perf_counter()
|
442 |
+
print('infer_time: ' + str(end_time - start_time))
|
443 |
+
self.window['infer_time'].update(int((end_time - start_time) * 1000))
|
444 |
+
|
445 |
+
def get_devices(self, update: bool = True):
|
446 |
+
'''Get the device list'''
|
447 |
+
if update:
|
448 |
+
sd._terminate()
|
449 |
+
sd._initialize()
|
450 |
+
devices = sd.query_devices()
|
451 |
+
hostapis = sd.query_hostapis()
|
452 |
+
for hostapi in hostapis:
|
453 |
+
for device_idx in hostapi["devices"]:
|
454 |
+
devices[device_idx]["hostapi_name"] = hostapi["name"]
|
455 |
+
input_devices = [
|
456 |
+
f"{d['name']} ({d['hostapi_name']})"
|
457 |
+
for d in devices
|
458 |
+
if d["max_input_channels"] > 0
|
459 |
+
]
|
460 |
+
output_devices = [
|
461 |
+
f"{d['name']} ({d['hostapi_name']})"
|
462 |
+
for d in devices
|
463 |
+
if d["max_output_channels"] > 0
|
464 |
+
]
|
465 |
+
input_devices_indices = [d["index"] for d in devices if d["max_input_channels"] > 0]
|
466 |
+
output_devices_indices = [
|
467 |
+
d["index"] for d in devices if d["max_output_channels"] > 0
|
468 |
+
]
|
469 |
+
return input_devices, output_devices, input_devices_indices, output_devices_indices
|
470 |
+
|
471 |
+
def set_devices(self, input_device, output_device):
|
472 |
+
'''Set the input and output devices'''
|
473 |
+
input_devices, output_devices, input_device_indices, output_device_indices = self.get_devices()
|
474 |
+
sd.default.device[0] = input_device_indices[input_devices.index(input_device)]
|
475 |
+
sd.default.device[1] = output_device_indices[output_devices.index(output_device)]
|
476 |
+
print("input device:" + str(sd.default.device[0]) + ":" + str(input_device))
|
477 |
+
print("output device:" + str(sd.default.device[1]) + ":" + str(output_device))
|
478 |
+
|
479 |
+
|
480 |
+
|
481 |
+
if __name__ == "__main__":
|
482 |
+
i18n = gui_locale.I18nAuto()
|
483 |
+
gui = GUI()
|
DDSP-SVC/gui_diff.py
ADDED
@@ -0,0 +1,576 @@
1 |
+
import PySimpleGUI as sg
|
2 |
+
import sounddevice as sd
|
3 |
+
import torch, librosa, threading, pickle
|
4 |
+
from enhancer import Enhancer
|
5 |
+
import numpy as np
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from torchaudio.transforms import Resample
|
8 |
+
import torchaudio
|
9 |
+
from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
|
10 |
+
from ddsp.core import upsample
|
11 |
+
import time
|
12 |
+
from gui_diff_locale import I18nAuto
|
13 |
+
from diffusion.infer_gt_mel import DiffGtMel
|
14 |
+
|
15 |
+
|
16 |
+
def phase_vocoder(a, b, fade_out, fade_in):
|
17 |
+
fa = torch.fft.rfft(a)
|
18 |
+
fb = torch.fft.rfft(b)
|
19 |
+
absab = torch.abs(fa) + torch.abs(fb)
|
20 |
+
n = a.shape[0]
|
21 |
+
if n % 2 == 0:
|
22 |
+
absab[1:-1] *= 2
|
23 |
+
else:
|
24 |
+
absab[1:] *= 2
|
25 |
+
phia = torch.angle(fa)
|
26 |
+
phib = torch.angle(fb)
|
27 |
+
deltaphase = phib - phia
|
28 |
+
deltaphase = deltaphase - 2 * np.pi * torch.floor(deltaphase / 2 / np.pi + 0.5)
|
29 |
+
w = 2 * np.pi * torch.arange(n // 2 + 1).to(a) + deltaphase
|
30 |
+
t = torch.arange(n).unsqueeze(-1).to(a) / n
|
31 |
+
result = a * (fade_out ** 2) + b * (fade_in ** 2) + torch.sum(absab * torch.cos(w * t + phia),
|
32 |
+
-1) * fade_out * fade_in / n
|
33 |
+
return result
|
34 |
+
|
35 |
+
|
36 |
+
class SvcDDSP:
|
37 |
+
def __init__(self) -> None:
|
38 |
+
self.model = None
|
39 |
+
self.units_encoder = None
|
40 |
+
self.encoder_type = None
|
41 |
+
self.encoder_ckpt = None
|
42 |
+
self.enhancer = None
|
43 |
+
self.enhancer_type = None
|
44 |
+
self.enhancer_ckpt = None
|
45 |
+
|
46 |
+
def update_model(self, model_path):
|
47 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
48 |
+
|
49 |
+
# load ddsp model
|
50 |
+
if self.model is None or self.model_path != model_path:
|
51 |
+
self.model, self.args = load_model(model_path, device=self.device)
|
52 |
+
self.model_path = model_path
|
53 |
+
|
54 |
+
# load units encoder
|
55 |
+
if self.units_encoder is None or self.args.data.encoder != self.encoder_type or self.args.data.encoder_ckpt != self.encoder_ckpt:
|
56 |
+
if self.args.data.encoder == 'cnhubertsoftfish':
|
57 |
+
cnhubertsoft_gate = self.args.data.cnhubertsoft_gate
|
58 |
+
else:
|
59 |
+
cnhubertsoft_gate = 10
|
60 |
+
self.units_encoder = Units_Encoder(
|
61 |
+
self.args.data.encoder,
|
62 |
+
self.args.data.encoder_ckpt,
|
63 |
+
self.args.data.encoder_sample_rate,
|
64 |
+
self.args.data.encoder_hop_size,
|
65 |
+
cnhubertsoft_gate=cnhubertsoft_gate,
|
66 |
+
device=self.device)
|
67 |
+
self.encoder_type = self.args.data.encoder
|
68 |
+
self.encoder_ckpt = self.args.data.encoder_ckpt
|
69 |
+
|
70 |
+
# load enhancer
|
71 |
+
if self.enhancer is None or self.args.enhancer.type != self.enhancer_type or self.args.enhancer.ckpt != self.enhancer_ckpt:
|
72 |
+
self.enhancer = Enhancer(self.args.enhancer.type, self.args.enhancer.ckpt, device=self.device)
|
73 |
+
self.enhancer_type = self.args.enhancer.type
|
74 |
+
self.enhancer_ckpt = self.args.enhancer.ckpt
|
75 |
+
|
76 |
+
def infer(self,
|
77 |
+
audio,
|
78 |
+
sample_rate,
|
79 |
+
spk_id=1,
|
80 |
+
threhold=-45,
|
81 |
+
pitch_adjust=0,
|
82 |
+
use_spk_mix=False,
|
83 |
+
spk_mix_dict=None,
|
84 |
+
use_enhancer=True,
|
85 |
+
enhancer_adaptive_key='auto',
|
86 |
+
pitch_extractor_type='crepe',
|
87 |
+
f0_min=50,
|
88 |
+
f0_max=1100,
|
89 |
+
safe_prefix_pad_length=0,
|
90 |
+
diff_model=None,
|
91 |
+
diff_acc=None,
|
92 |
+
diff_spk_id=None,
|
93 |
+
diff_use=False,
|
94 |
+
diff_use_dpm=False,
|
95 |
+
k_step=None,
|
96 |
+
diff_silence=False,
|
97 |
+
audio_alignment=False
|
98 |
+
):
|
99 |
+
print("Infering...")
|
100 |
+
# load input
|
101 |
+
# audio, sample_rate = librosa.load(input_wav, sr=None, mono=True)
|
102 |
+
hop_size = self.args.data.block_size * sample_rate / self.args.data.sampling_rate
|
103 |
+
if audio_alignment:
|
104 |
+
audio_length = len(audio)
|
105 |
+
# safe front silence
|
106 |
+
if safe_prefix_pad_length > 0.03:
|
107 |
+
silence_front = safe_prefix_pad_length - 0.03
|
108 |
+
else:
|
109 |
+
silence_front = 0
|
110 |
+
audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
|
111 |
+
|
112 |
+
# extract f0
|
113 |
+
pitch_extractor = F0_Extractor(
|
114 |
+
pitch_extractor_type,
|
115 |
+
sample_rate,
|
116 |
+
hop_size,
|
117 |
+
float(f0_min),
|
118 |
+
float(f0_max))
|
119 |
+
f0 = pitch_extractor.extract(audio, uv_interp=True, device=self.device, silence_front=silence_front)
|
120 |
+
f0 = torch.from_numpy(f0).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
121 |
+
f0 = f0 * 2 ** (float(pitch_adjust) / 12)
|
122 |
+
|
123 |
+
# extract volume
|
124 |
+
volume_extractor = Volume_Extractor(hop_size)
|
125 |
+
volume = volume_extractor.extract(audio)
|
126 |
+
mask = (volume > 10 ** (float(threhold) / 20)).astype('float')
|
127 |
+
mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
|
128 |
+
mask = np.array([np.max(mask[n: n + 9]) for n in range(len(mask) - 8)])
|
129 |
+
mask = torch.from_numpy(mask).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
130 |
+
mask = upsample(mask, self.args.data.block_size).squeeze(-1)
|
131 |
+
volume = torch.from_numpy(volume).float().to(self.device).unsqueeze(-1).unsqueeze(0)
|
132 |
+
|
133 |
+
# extract units
|
134 |
+
units = self.units_encoder.encode(audio_t, sample_rate, hop_size)
|
135 |
+
|
136 |
+
# spk_id or spk_mix_dict
|
137 |
+
spk_id = torch.LongTensor(np.array([[spk_id]])).to(self.device)
|
138 |
+
diff_spk_id = torch.LongTensor(np.array([[diff_spk_id]])).to(self.device)
|
139 |
+
dictionary = None
|
140 |
+
if use_spk_mix:
|
141 |
+
dictionary = spk_mix_dict
|
142 |
+
|
143 |
+
# forward and return the output
|
144 |
+
with torch.no_grad():
|
145 |
+
output, _, (s_h, s_n) = self.model(units, f0, volume, spk_id=spk_id, spk_mix_dict=dictionary)
|
146 |
+
if diff_use and diff_model is not None:
|
147 |
+
output = diff_model.infer(output, f0, units, volume, acc=diff_acc, spk_id=diff_spk_id,
|
148 |
+
k_step=k_step, use_dpm=diff_use_dpm, silence_front=silence_front, use_silence=diff_silence,
|
149 |
+
spk_mix_dict=dictionary)
|
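# When diffusion is enabled, the DDSP output is handed to the diffusion model
# for shallow-diffusion refinement with depth k_step and the given
# acceleration; in that case the enhancer branch below is skipped
# (use_enhancer and not diff_use).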
150 |
+
output *= mask
|
151 |
+
if use_enhancer and not diff_use:
|
152 |
+
output, output_sample_rate = self.enhancer.enhance(
|
153 |
+
output,
|
154 |
+
self.args.data.sampling_rate,
|
155 |
+
f0,
|
156 |
+
self.args.data.block_size,
|
157 |
+
adaptive_key=enhancer_adaptive_key,
|
158 |
+
silence_front=silence_front)
|
159 |
+
else:
|
160 |
+
output_sample_rate = self.args.data.sampling_rate
|
161 |
+
|
162 |
+
output = output.squeeze()
|
163 |
+
if audio_alignment:
|
164 |
+
output = output[:audio_length]
|
165 |
+
return output, output_sample_rate
|
166 |
+
|
167 |
+
|
168 |
+
class Config:
|
169 |
+
def __init__(self) -> None:
|
170 |
+
self.samplerate = 44100 # Hz
|
171 |
+
self.block_time = 1.5 # s
|
172 |
+
self.f_pitch_change: float = 0.0 # float(request_form.get("fPitchChange", 0))
|
173 |
+
self.spk_id = 1 # default speaker ID
|
174 |
+
self.spk_mix_dict = None # {1:0.5, 2:0.5} mixes the timbres of speaker 1 and speaker 2 at a 0.5:0.5 ratio
|
175 |
+
self.use_vocoder_based_enhancer = True
|
176 |
+
self.use_phase_vocoder = True
|
177 |
+
self.checkpoint_path = ''
|
178 |
+
self.threhold = -35
|
179 |
+
self.buffer_num = 2
|
180 |
+
self.crossfade_time = 0.03
|
181 |
+
self.select_pitch_extractor = 'harvest' # F0 extractor: ["parselmouth", "dio", "harvest", "crepe"]
|
182 |
+
self.use_spk_mix = False
|
183 |
+
self.sounddevices = ['', '']
|
184 |
+
self.diff_use = False
|
185 |
+
self.diff_project = ''
|
186 |
+
self.diff_acc = 10
|
187 |
+
self.diff_spk_id = 0
|
188 |
+
self.k_step = 100
|
189 |
+
self.diff_use_dpm = False
|
190 |
+
self.diff_silence = False
|
191 |
+
|
192 |
+
def save(self, path):
|
193 |
+
with open(path + '\\config.pkl', 'wb') as f:
|
194 |
+
pickle.dump(vars(self), f)
|
195 |
+
|
196 |
+
def load(self, path) -> bool:
|
197 |
+
try:
|
198 |
+
with open(path + '\\config.pkl', 'rb') as f:
|
199 |
+
self.update(pickle.load(f))
|
200 |
+
return True
|
201 |
+
except:
|
202 |
+
print('config.pkl does not exist')
|
203 |
+
return False
|
204 |
+
|
205 |
+
def update(self, data_dict):
|
206 |
+
for key, value in data_dict.items():
|
207 |
+
setattr(self, key, value)
|
208 |
+
|
209 |
+
class GUI:
|
210 |
+
def __init__(self) -> None:
|
211 |
+
self.config = Config()
|
212 |
+
self.flag_vc: bool = False # voice conversion thread flag
|
213 |
+
self.block_frame = 0
|
214 |
+
self.crossfade_frame = 0
|
215 |
+
self.sola_search_frame = 0
|
216 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
217 |
+
self.svc_model: SvcDDSP = SvcDDSP()
|
218 |
+
self.diff_model: DiffGtMel = DiffGtMel()
|
219 |
+
self.fade_in_window: np.ndarray = None # numpy array used for crossfade computation
|
220 |
+
self.fade_out_window: np.ndarray = None # numpy array used for crossfade computation
|
221 |
+
self.input_wav: np.ndarray = None # buffer holding the normalized input audio
|
222 |
+
self.output_wav: np.ndarray = None # buffer holding the normalized output audio
|
223 |
+
self.sola_buffer: torch.Tensor = None # stores the crossfade tail of the previous output
|
224 |
+
self.f0_mode_list = ["parselmouth", "dio", "harvest", "crepe"] # F0 extractors
|
225 |
+
self.f_safe_prefix_pad_length: float = 0.0
|
226 |
+
self.resample_kernel = {}
|
227 |
+
self.launcher() # start
|
228 |
+
|
229 |
+
def launcher(self):
|
230 |
+
'''Load the window'''
|
231 |
+
input_devices, output_devices, _, _ = self.get_devices()
|
232 |
+
sg.theme('DarkBlue12') # set theme
|
233 |
+
# UI layout
|
234 |
+
layout = [
|
235 |
+
[sg.Frame(layout=[
|
236 |
+
[sg.Input(key='sg_model', default_text='exp\\combsub-test\\model_300000.pt'),
|
237 |
+
sg.FileBrowse(i18n('选择模型文件'), key='choose_model')]
|
238 |
+
], title=i18n('模型:.pt格式(自动识别同目录下config.yaml)')),
|
239 |
+
sg.Frame(layout=[
|
240 |
+
[sg.Text(i18n('选择配置文件所在目录')), sg.Input(key='config_file_dir', default_text='exp'),
|
241 |
+
sg.FolderBrowse(i18n('打开文件夹'), key='choose_config')],
|
242 |
+
[sg.Button(i18n('读取配置文件'), key='load_config'),
|
243 |
+
sg.Button(i18n('保存配置文件'), key='save_config')]
|
244 |
+
], title=i18n('快速配置文件'))
|
245 |
+
],
|
246 |
+
[sg.Frame(layout=[
|
247 |
+
[sg.Text(i18n("输入设备")),
|
248 |
+
sg.Combo(input_devices, key='sg_input_device', default_value=input_devices[sd.default.device[0]],
|
249 |
+
enable_events=True)],
|
250 |
+
[sg.Text(i18n("输出设备")),
|
251 |
+
sg.Combo(output_devices, key='sg_output_device', default_value=output_devices[sd.default.device[1]],
|
252 |
+
enable_events=True)]
|
253 |
+
], title=i18n('音频设备'))
|
254 |
+
],
|
255 |
+
[sg.Frame(layout=[
|
256 |
+
[sg.Text(i18n("说话人id")), sg.Input(key='spk_id', default_text='1', size=8)],
|
257 |
+
[sg.Text(i18n("响应阈值")),
|
258 |
+
sg.Slider(range=(-60, 0), orientation='h', key='threhold', resolution=1, default_value=-45,
|
259 |
+
enable_events=True)],
|
260 |
+
[sg.Text(i18n("变调")),
|
261 |
+
sg.Slider(range=(-24, 24), orientation='h', key='pitch', resolution=1, default_value=0,
|
262 |
+
enable_events=True)],
|
263 |
+
[sg.Text(i18n("采样率")), sg.Input(key='samplerate', default_text='44100', size=8)],
|
264 |
+
[sg.Checkbox(text=i18n('启用捏音色功能'), default=False, key='spk_mix', enable_events=True),
|
265 |
+
sg.Button(i18n("设置混合音色"), key='set_spk_mix')]
|
266 |
+
], title=i18n('普通设置')),
|
267 |
+
sg.Frame(layout=[
|
268 |
+
[sg.Text(i18n("音频切分大小")),
|
269 |
+
sg.Slider(range=(0.05, 3.0), orientation='h', key='block', resolution=0.01, default_value=0.5,
|
270 |
+
enable_events=True)],
|
271 |
+
[sg.Text(i18n("交叉淡化时长")),
|
272 |
+
sg.Slider(range=(0.01, 0.15), orientation='h', key='crossfade', resolution=0.01,
|
273 |
+
default_value=0.04, enable_events=True)],
|
274 |
+
[sg.Text(i18n("使用历史区块数量")),
|
275 |
+
sg.Slider(range=(1, 20), orientation='h', key='buffernum', resolution=1, default_value=3,
|
276 |
+
enable_events=True)],
|
277 |
+
[sg.Text(i18n("f0预测模式")),
|
278 |
+
sg.Combo(values=self.f0_mode_list, key='f0_mode', default_value=self.f0_mode_list[2],
|
279 |
+
enable_events=True)],
|
280 |
+
[sg.Checkbox(text=i18n('启用增强器'), default=True, key='use_enhancer', enable_events=True),
|
281 |
+
sg.Checkbox(text=i18n('启用相位声码器'), default=False, key='use_phase_vocoder',
|
282 |
+
enable_events=True)]
|
283 |
+
], title=i18n('性能设置')),
|
284 |
+
sg.Frame(layout=[
|
285 |
+
[sg.Text(i18n("扩散模型文件"))],
|
286 |
+
[sg.Input(key='diff_project', default_text='exp\\diffusion-test\\model_400000.pt'),
|
287 |
+
sg.FileBrowse(i18n('选择模型文件'), key='choose_model')],
|
288 |
+
[sg.Text(i18n("扩散说话人id")), sg.Input(key='diff_spk_id', default_text='1', size=18)],
|
289 |
+
[sg.Text(i18n("扩散深度")), sg.Input(key='k_step', default_text='120', size=18)],
|
290 |
+
[sg.Text(i18n("扩散加速")), sg.Input(key='diff_acc', default_text='20', size=18)],
|
291 |
+
[sg.Checkbox(text=i18n('启用DPMs(推荐)'), default=False, key='diff_use_dpm', enable_events=True)],
|
292 |
+
[sg.Checkbox(text=i18n('启用扩散'), default=True, key='diff_use', enable_events=True),
|
293 |
+
sg.Checkbox(text=i18n('不扩散安全区(加速但损失效果)'), default=False, key='diff_silence', enable_events=True)]
|
294 |
+
], title=i18n('扩散设置')),
|
295 |
+
],
|
296 |
+
[sg.Button(i18n("开始音频转换"), key="start_vc"), sg.Button(i18n("停止音频转换"), key="stop_vc"),
|
297 |
+
sg.Text(i18n('推理所用时间(ms):')), sg.Text('0', key='infer_time')]
|
298 |
+
]
|
299 |
+
|
300 |
+
# create the window
|
301 |
+
self.window = sg.Window('DDSP - GUI', layout, finalize=True)
|
302 |
+
self.window['spk_id'].bind('<Return>', '')
|
303 |
+
self.window['samplerate'].bind('<Return>', '')
|
304 |
+
self.window['diff_spk_id'].bind('<Return>', '')
|
305 |
+
self.window['k_step'].bind('<Return>', '')
|
306 |
+
self.window['diff_acc'].bind('<Return>', '')
|
307 |
+
self.event_handler()
|
308 |
+
|
309 |
+
def event_handler(self):
|
310 |
+
'''Event handling'''
|
311 |
+
while True: # event loop
|
312 |
+
event, values = self.window.read()
|
313 |
+
if event == sg.WINDOW_CLOSED: # the user closed the window
|
314 |
+
self.flag_vc = False
|
315 |
+
exit()
|
316 |
+
|
317 |
+
print('event: ' + event)
|
318 |
+
|
319 |
+
if event == 'start_vc' and self.flag_vc == False:
|
320 |
+
# set_values follows the same order as the layout
|
321 |
+
self.set_values(values)
|
322 |
+
print('crossfade_time:' + str(self.config.crossfade_time))
|
323 |
+
print("buffer_num:" + str(self.config.buffer_num))
|
324 |
+
print("samplerate:" + str(self.config.samplerate))
|
325 |
+
print('block_time:' + str(self.config.block_time))
|
326 |
+
print("prefix_pad_length:" + str(self.f_safe_prefix_pad_length))
|
327 |
+
print("mix_mode:" + str(self.config.spk_mix_dict))
|
328 |
+
print("enhancer:" + str(self.config.use_vocoder_based_enhancer))
|
329 |
+
print("diffusion:" + str(self.config.diff_use))
|
330 |
+
print('using_cuda:' + str(torch.cuda.is_available()))
|
331 |
+
self.start_vc()
|
332 |
+
elif event == 'k_step':
|
333 |
+
if 1 <= int(values['k_step']) <= 1000:
|
334 |
+
self.config.k_step = int(values['k_step'])
|
335 |
+
else:
|
336 |
+
self.window['k_step'].update(1000)
|
337 |
+
elif event == 'diff_acc':
|
338 |
+
if self.config.k_step < int(values['diff_acc']):
|
339 |
+
self.config.diff_acc = int(self.config.k_step / 4)
|
340 |
+
else:
|
341 |
+
self.config.diff_acc = int(values['diff_acc'])
|
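# k_step is kept within 1..1000, and if the entered diff_acc exceeds k_step
# it is reset to k_step / 4 so the acceleration stays below the diffusion depth.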
342 |
+
elif event == 'diff_spk_id':
|
343 |
+
self.config.diff_spk_id = int(values['diff_spk_id'])
|
344 |
+
elif event == 'diff_use':
|
345 |
+
self.config.diff_use = values['diff_use']
|
346 |
+
self.window['use_enhancer'].update(False)
|
347 |
+
self.config.use_vocoder_based_enhancer=False
|
348 |
+
elif event == 'diff_silence':
|
349 |
+
self.config.diff_silence = values['diff_silence']
|
350 |
+
elif event == 'diff_use_dpm':
|
351 |
+
self.config.diff_use_dpm = values['diff_use_dpm']
|
352 |
+
elif event == 'spk_id':
|
353 |
+
self.config.spk_id = int(values['spk_id'])
|
354 |
+
elif event == 'threhold':
|
355 |
+
self.config.threhold = values['threhold']
|
356 |
+
elif event == 'pitch':
|
357 |
+
self.config.f_pitch_change = values['pitch']
|
358 |
+
elif event == 'spk_mix':
|
359 |
+
self.config.use_spk_mix = values['spk_mix']
|
360 |
+
elif event == 'set_spk_mix':
|
361 |
+
spk_mix = sg.popup_get_text(message='示例:1:0.3,2:0.5,3:0.2', title="设置混合音色,支持多人")
|
362 |
+
if spk_mix != None:
|
363 |
+
self.config.spk_mix_dict = eval("{" + spk_mix.replace(',', ',').replace(':', ':') + "}")
|
364 |
+
elif event == 'f0_mode':
|
365 |
+
self.config.select_pitch_extractor = values['f0_mode']
|
366 |
+
elif event == 'use_enhancer':
|
367 |
+
self.config.use_vocoder_based_enhancer = values['use_enhancer']
|
368 |
+
self.window['diff_use'].update(False)
|
369 |
+
self.config.diff_use = False
|
370 |
+
elif event == 'use_phase_vocoder':
|
371 |
+
self.config.use_phase_vocoder = values['use_phase_vocoder']
|
372 |
+
elif event == 'load_config' and self.flag_vc == False:
|
373 |
+
if self.config.load(values['config_file_dir']):
|
374 |
+
self.update_values()
|
375 |
+
elif event == 'save_config' and self.flag_vc == False:
|
376 |
+
self.set_values(values)
|
377 |
+
self.config.save(values['config_file_dir'])
|
378 |
+
elif event != 'start_vc' and self.flag_vc == True:
|
379 |
+
self.flag_vc = False
|
380 |
+
|
381 |
+
def set_values(self, values):
|
382 |
+
self.set_devices(values["sg_input_device"], values['sg_output_device'])
|
383 |
+
self.config.sounddevices = [values["sg_input_device"], values['sg_output_device']]
|
384 |
+
self.config.checkpoint_path = values['sg_model']
|
385 |
+
self.config.spk_id = int(values['spk_id'])
|
386 |
+
self.config.threhold = values['threhold']
|
387 |
+
self.config.f_pitch_change = values['pitch']
|
388 |
+
self.config.samplerate = int(values['samplerate'])
|
389 |
+
self.config.block_time = float(values['block'])
|
390 |
+
self.config.crossfade_time = float(values['crossfade'])
|
391 |
+
self.config.buffer_num = int(values['buffernum'])
|
392 |
+
self.config.select_pitch_extractor = values['f0_mode']
|
393 |
+
self.config.use_vocoder_based_enhancer = values['use_enhancer']
|
394 |
+
self.config.use_phase_vocoder = values['use_phase_vocoder']
|
395 |
+
self.config.use_spk_mix = values['spk_mix']
|
396 |
+
self.config.diff_use = values['diff_use']
|
397 |
+
self.config.diff_silence = values['diff_silence']
|
398 |
+
self.config.diff_use_dpm = values['diff_use_dpm']
|
399 |
+
self.config.diff_project = values['diff_project']
|
400 |
+
self.config.diff_acc = int(values['diff_acc'])
|
401 |
+
self.config.diff_spk_id = int(values['diff_spk_id'])
|
402 |
+
self.config.k_step = int(values['k_step'])
|
403 |
+
self.block_frame = int(self.config.block_time * self.config.samplerate)
|
404 |
+
self.crossfade_frame = int(self.config.crossfade_time * self.config.samplerate)
|
405 |
+
self.sola_search_frame = int(0.01 * self.config.samplerate)
|
406 |
+
self.last_delay_frame = int(0.02 * self.config.samplerate)
|
407 |
+
self.input_frames = max(
|
408 |
+
self.block_frame + self.crossfade_frame + self.sola_search_frame + 2 * self.last_delay_frame,
|
409 |
+
(1 + self.config.buffer_num) * self.block_frame)
|
410 |
+
self.f_safe_prefix_pad_length = self.config.block_time * self.config.buffer_num - self.config.crossfade_time - 0.01 - 0.02
|
411 |
+
|
412 |
+
def update_values(self):
|
413 |
+
self.window['sg_model'].update(self.config.checkpoint_path)
|
414 |
+
self.window['sg_input_device'].update(self.config.sounddevices[0])
|
415 |
+
self.window['sg_output_device'].update(self.config.sounddevices[1])
|
416 |
+
self.window['spk_id'].update(self.config.spk_id)
|
417 |
+
self.window['threhold'].update(self.config.threhold)
|
418 |
+
self.window['pitch'].update(self.config.f_pitch_change)
|
419 |
+
self.window['samplerate'].update(self.config.samplerate)
|
420 |
+
self.window['spk_mix'].update(self.config.use_spk_mix)
|
421 |
+
self.window['block'].update(self.config.block_time)
|
422 |
+
self.window['crossfade'].update(self.config.crossfade_time)
|
423 |
+
self.window['buffernum'].update(self.config.buffer_num)
|
424 |
+
self.window['f0_mode'].update(self.config.select_pitch_extractor)
|
425 |
+
self.window['use_enhancer'].update(self.config.use_vocoder_based_enhancer)
|
426 |
+
self.window['diff_use'].update(self.config.diff_use)
|
427 |
+
self.window['diff_silence'].update(self.config.diff_silence)
|
428 |
+
self.window['diff_use_dpm'].update(self.config.diff_use_dpm)
|
429 |
+
self.window['diff_project'].update(self.config.diff_project)
|
430 |
+
self.window['diff_acc'].update(self.config.diff_acc)
|
431 |
+
self.window['diff_spk_id'].update(self.config.diff_spk_id)
|
432 |
+
self.window['k_step'].update(self.config.k_step)
|
433 |
+
|
434 |
+
def start_vc(self):
|
435 |
+
'''Start audio conversion'''
|
436 |
+
torch.cuda.empty_cache()
|
437 |
+
self.flag_vc = True
|
438 |
+
self.input_wav = np.zeros(self.input_frames, dtype='float32')
|
439 |
+
self.sola_buffer = torch.zeros(self.crossfade_frame, device=self.device)
|
440 |
+
self.fade_in_window = torch.sin(
|
441 |
+
np.pi * torch.arange(0, 1, 1 / self.crossfade_frame, device=self.device) / 2) ** 2
|
442 |
+
self.fade_out_window = 1 - self.fade_in_window
|
443 |
+
self.svc_model.update_model(self.config.checkpoint_path)
|
444 |
+
if self.config.diff_use:
|
445 |
+
self.diff_model.flush_model(self.config.diff_project, ddsp_config=self.svc_model.args)
|
446 |
+
thread_vc = threading.Thread(target=self.soundinput)
|
447 |
+
thread_vc.start()
|
448 |
+
|
449 |
+
def soundinput(self):
|
450 |
+
'''
|
451 |
+
Receive audio input
|
452 |
+
'''
|
453 |
+
with sd.Stream(callback=self.audio_callback, blocksize=self.block_frame, samplerate=self.config.samplerate,
|
454 |
+
dtype='float32'):
|
455 |
+
while self.flag_vc:
|
456 |
+
time.sleep(self.config.block_time)
|
457 |
+
print('Audio block passed.')
|
458 |
+
print('ENDing VC')
|
459 |
+
|
460 |
+
def audio_callback(self, indata: np.ndarray, outdata: np.ndarray, frames, times, status):
|
461 |
+
'''
|
462 |
+
Audio processing
|
463 |
+
'''
|
464 |
+
start_time = time.perf_counter()
|
465 |
+
print("\nStarting callback")
|
466 |
+
self.input_wav[:] = np.roll(self.input_wav, -self.block_frame)
|
467 |
+
self.input_wav[-self.block_frame:] = librosa.to_mono(indata.T)
|
468 |
+
|
469 |
+
# infer
|
470 |
+
if self.config.diff_use:
|
471 |
+
_diff_model = self.diff_model
|
472 |
+
else:
|
473 |
+
_diff_model = None
|
474 |
+
_audio, _model_sr = self.svc_model.infer(
|
475 |
+
self.input_wav,
|
476 |
+
self.config.samplerate,
|
477 |
+
spk_id=self.config.spk_id,
|
478 |
+
threhold=self.config.threhold,
|
479 |
+
pitch_adjust=self.config.f_pitch_change,
|
480 |
+
use_spk_mix=self.config.use_spk_mix,
|
481 |
+
spk_mix_dict=self.config.spk_mix_dict,
|
482 |
+
use_enhancer=self.config.use_vocoder_based_enhancer,
|
483 |
+
pitch_extractor_type=self.config.select_pitch_extractor,
|
484 |
+
safe_prefix_pad_length=self.f_safe_prefix_pad_length,
|
485 |
+
diff_model=_diff_model,
|
486 |
+
diff_acc=self.config.diff_acc,
|
487 |
+
diff_spk_id=self.config.diff_spk_id,
|
488 |
+
diff_use=self.config.diff_use,
|
489 |
+
diff_use_dpm=self.config.diff_use_dpm,
|
490 |
+
k_step=self.config.k_step,
|
491 |
+
diff_silence=self.config.diff_silence
|
492 |
+
)
|
493 |
+
|
494 |
+
# debug sola
|
495 |
+
'''
|
496 |
+
_audio, _model_sr = self.input_wav, self.config.samplerate
|
497 |
+
rs = int(np.random.uniform(-200,200))
|
498 |
+
print('debug_random_shift: ' + str(rs))
|
499 |
+
_audio = np.roll(_audio, rs)
|
500 |
+
_audio = torch.from_numpy(_audio).to(self.device)
|
501 |
+
'''
|
502 |
+
|
503 |
+
if _model_sr != self.config.samplerate:
|
504 |
+
key_str = str(_model_sr) + '_' + str(self.config.samplerate)
|
505 |
+
if key_str not in self.resample_kernel:
|
506 |
+
self.resample_kernel[key_str] = Resample(_model_sr, self.config.samplerate,
|
507 |
+
lowpass_filter_width=128).to(self.device)
|
508 |
+
_audio = self.resample_kernel[key_str](_audio)
|
509 |
+
temp_wav = _audio[
|
510 |
+
- self.block_frame - self.crossfade_frame - self.sola_search_frame - self.last_delay_frame: - self.last_delay_frame]
|
511 |
+
|
512 |
+
# sola shift
|
513 |
+
conv_input = temp_wav[None, None, : self.crossfade_frame + self.sola_search_frame]
|
514 |
+
cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
|
515 |
+
cor_den = torch.sqrt(
|
516 |
+
F.conv1d(conv_input ** 2, torch.ones(1, 1, self.crossfade_frame, device=self.device)) + 1e-8)
|
517 |
+
sola_shift = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
518 |
+
temp_wav = temp_wav[sola_shift: sola_shift + self.block_frame + self.crossfade_frame]
|
519 |
+
print('sola_shift: ' + str(int(sola_shift)))
|
520 |
+
|
521 |
+
# phase vocoder
|
522 |
+
if self.config.use_phase_vocoder:
|
523 |
+
temp_wav[: self.crossfade_frame] = phase_vocoder(
|
524 |
+
self.sola_buffer,
|
525 |
+
temp_wav[: self.crossfade_frame],
|
526 |
+
self.fade_out_window,
|
527 |
+
self.fade_in_window)
|
528 |
+
else:
|
529 |
+
temp_wav[: self.crossfade_frame] *= self.fade_in_window
|
530 |
+
temp_wav[: self.crossfade_frame] += self.sola_buffer * self.fade_out_window
|
531 |
+
|
532 |
+
self.sola_buffer = temp_wav[- self.crossfade_frame:]
|
533 |
+
|
534 |
+
outdata[:] = temp_wav[: - self.crossfade_frame, None].repeat(1, 2).cpu().numpy()
|
535 |
+
end_time = time.perf_counter()
|
536 |
+
print('infer_time: ' + str(end_time - start_time))
|
537 |
+
self.window['infer_time'].update(int((end_time - start_time) * 1000))
|
538 |
+
|
539 |
+
def get_devices(self, update: bool = True):
|
540 |
+
'''Get the device list'''
|
541 |
+
if update:
|
542 |
+
sd._terminate()
|
543 |
+
sd._initialize()
|
544 |
+
devices = sd.query_devices()
|
545 |
+
hostapis = sd.query_hostapis()
|
546 |
+
for hostapi in hostapis:
|
547 |
+
for device_idx in hostapi["devices"]:
|
548 |
+
devices[device_idx]["hostapi_name"] = hostapi["name"]
|
549 |
+
input_devices = [
|
550 |
+
f"{d['name']} ({d['hostapi_name']})"
|
551 |
+
for d in devices
|
552 |
+
if d["max_input_channels"] > 0
|
553 |
+
]
|
554 |
+
output_devices = [
|
555 |
+
f"{d['name']} ({d['hostapi_name']})"
|
556 |
+
for d in devices
|
557 |
+
if d["max_output_channels"] > 0
|
558 |
+
]
|
559 |
+
input_devices_indices = [d["index"] for d in devices if d["max_input_channels"] > 0]
|
560 |
+
output_devices_indices = [
|
561 |
+
d["index"] for d in devices if d["max_output_channels"] > 0
|
562 |
+
]
|
563 |
+
return input_devices, output_devices, input_devices_indices, output_devices_indices
|
564 |
+
|
565 |
+
def set_devices(self, input_device, output_device):
|
566 |
+
'''Set the input and output devices'''
|
567 |
+
input_devices, output_devices, input_device_indices, output_device_indices = self.get_devices()
|
568 |
+
sd.default.device[0] = input_device_indices[input_devices.index(input_device)]
|
569 |
+
sd.default.device[1] = output_device_indices[output_devices.index(output_device)]
|
570 |
+
print("input device:" + str(sd.default.device[0]) + ":" + str(input_device))
|
571 |
+
print("output device:" + str(sd.default.device[1]) + ":" + str(output_device))
|
572 |
+
|
573 |
+
|
574 |
+
if __name__ == "__main__":
|
575 |
+
i18n = I18nAuto()
|
576 |
+
gui = GUI()
|
DDSP-SVC/gui_diff_locale.py
ADDED
@@ -0,0 +1,154 @@
1 |
+
import locale
|
2 |
+
'''
|
3 |
+
Localization is defined as shown below
|
4 |
+
'''
|
5 |
+
|
6 |
+
LANGUAGE_LIST = ['zh_CN', 'en_US', 'ja_JP']
|
7 |
+
LANGUAGE_ALL = {
|
8 |
+
'zh_CN': {
|
9 |
+
'SUPER': 'END',
|
10 |
+
'LANGUAGE': 'zh_CN',
|
11 |
+
'选择模型文件': '选择模型文件',
|
12 |
+
'模型:.pt格式(自动识别同目录下config.yaml)': '模型:.pt格式(自动识别同目录下config.yaml)',
|
13 |
+
'选择配置文件所在目录': '选择配置文件所在目录',
|
14 |
+
'打开文件夹': '打开文件夹',
|
15 |
+
'读取配置文件': '读取配置文件',
|
16 |
+
'保存配置文件': '保存配置文件',
|
17 |
+
'快速配置文件': '快速配置文件',
|
18 |
+
'输入设备': '输入设备',
|
19 |
+
'输出设备': '输出设备',
|
20 |
+
'音频设备': '音频设备',
|
21 |
+
'说话人id': '说话人id',
|
22 |
+
'响应阈值': '响应阈值',
|
23 |
+
'变调': '变调',
|
24 |
+
'采样率': '采样率',
|
25 |
+
'启用捏音色功能': '启用捏音色功能',
|
26 |
+
'设置混合音色': '设置混合音色',
|
27 |
+
'普通设置': '普通设置',
|
28 |
+
'音频切分大小': '音频切分大小',
|
29 |
+
'交叉淡化时长': '交叉淡化时长',
|
30 |
+
'使用历史区块数量': '使用历史区块数量',
|
31 |
+
'f0预测模式': 'f0预测模式',
|
32 |
+
'启用增强器': '启用增强器',
|
33 |
+
'启用相位声码器': '启用相位声码器',
|
34 |
+
'性能设置': '性能设置',
|
35 |
+
'开始音频转换': '开始音频转换',
|
36 |
+
'停止音频转换': '停止音频转换',
|
37 |
+
'推理所用时间(ms):': '推理所用时间(ms):',
|
38 |
+
'扩散设置': '扩散设置',
|
39 |
+
'启用扩散': '启用扩散',
|
40 |
+
'扩散加速': '扩散加速',
|
41 |
+
'扩散深度': '扩散深度',
|
42 |
+
'扩散说话人id': '扩散说话人id',
|
43 |
+
'扩散模型文件': '扩散模型文件',
|
44 |
+
'不扩散安全区(加速但损失效果)': '不扩散安全区(加速但损失效果)',
|
45 |
+
'启用DPMs(推荐)': '启用DPMs(推荐)'
|
46 |
+
},
|
47 |
+
'en_US': {
|
48 |
+
'SUPER': 'zh_CN',
|
49 |
+
'LANGUAGE': 'en_US',
|
50 |
+
'选择模型文件': 'Select Model File',
|
51 |
+
'模型:.pt格式(自动识别同目录下config.yaml)': 'Model: .pt format (automatically detects config.yaml in the same directory)',
|
52 |
+
'选择配置文件所在目录': 'Select the configuration file directory',
|
53 |
+
'打开文件夹': 'Open folder',
|
54 |
+
'读取配置文件': 'Read config file',
|
55 |
+
'保存配置文件': 'Save config file',
|
56 |
+
'快速配置文件': 'Fast config file',
|
57 |
+
'输入设备': 'Input device',
|
58 |
+
'输出设备': 'Output device',
|
59 |
+
'音频设备': 'Audio devices',
|
60 |
+
'说话人id': 'Speaker ID',
|
61 |
+
'响应阈值': 'Response threshold',
|
62 |
+
'变调': 'Pitch',
|
63 |
+
'采样率': 'Sampling rate',
|
64 |
+
'启用捏音色功能': 'Enable Mix Speaker',
|
65 |
+
'设置混合音色': 'Mix Speaker',
|
66 |
+
'普通设置': 'Normal Settings',
|
67 |
+
'音频切分大小': 'Segmentation size',
|
68 |
+
'交叉淡化时长': 'Cross fade duration',
|
69 |
+
'使用历史区块数量': 'Historical blocks used',
|
70 |
+
'f0预测模式': 'f0Extractor',
|
71 |
+
'启用增强器': 'Enable Enhancer',
|
72 |
+
'启用相位声码器': 'Enable Phase Vocoder',
|
73 |
+
'性能设置': 'Performance settings',
|
74 |
+
'开始音频转换': 'Start conversion',
|
75 |
+
'停止音频转换': 'Stop conversion',
|
76 |
+
'推理所用时间(ms):': 'Inference time(ms):',
|
77 |
+
'扩散设置': 'Diffusion settings',
|
78 |
+
'启用扩散': 'Enable diffusion',
|
79 |
+
'扩散加速': 'Diffusion acceleration',
|
80 |
+
'扩散深度': 'Diffusion depth',
|
81 |
+
'扩散说话人id': 'Diffusion speaker ID',
|
82 |
+
'扩散模型文件': 'Diffusion model file',
|
83 |
+
'不扩散安全区(加速但损失效果)': 'Skip diffusion in the safe zone (faster but lower quality)',
|
84 |
+
'启用DPMs(推荐)': 'Enable DPMs (recommended)'
|
85 |
+
},
|
86 |
+
'ja_JP': {
|
87 |
+
'SUPER': 'zh_CN',
|
88 |
+
'LANGUAGE': 'ja_JP',
|
89 |
+
'选择模型文件': 'モデルを選択',
|
90 |
+
'模型:.pt格式(自动识别同目录下config.yaml)': 'モデル:.pt形式(同じディレクトリにあるconfig.yamlを自動認識します)',
|
91 |
+
'选择配置文件所在目录': '設定ファイルを選択',
|
92 |
+
'打开文件夹': 'フォルダを開く',
|
93 |
+
'读取配置文件': '設定ファイルを読み込む',
|
94 |
+
'保存配置文件': '設定ファイルを保存',
|
95 |
+
'快速配置文件': '設定プロファイル',
|
96 |
+
'输入设备': '入力デバイス',
|
97 |
+
'输出设备': '出力デバイス',
|
98 |
+
'音频设备': '音声デバイス',
|
99 |
+
'说话人id': '話者ID',
|
100 |
+
'响应阈值': '応答時の閾値',
|
101 |
+
'变调': '音程',
|
102 |
+
'采样率': 'サンプリングレート',
|
103 |
+
'启用捏音色功能': 'ミキシングを有効化',
|
104 |
+
'设置混合音色': 'ミキシング',
|
105 |
+
'普通设置': '通常設定',
|
106 |
+
'音频切分大小': 'セグメンテーションのサイズ',
|
107 |
+
'交叉淡化时长': 'クロスフェードの間隔',
|
108 |
+
'使用历史区块数量': '使用するヒストリカルブロック数',
|
109 |
+
'f0预测模式': 'f0予測モデル',
|
110 |
+
'启用增强器': 'Enhancerを有効化',
|
111 |
+
'启用相位声码器': 'フェーズボコーダを有効化',
|
112 |
+
'性能设置': 'パフォーマンスの設定',
|
113 |
+
'开始音频转换': '変換開始',
|
114 |
+
'停止音频转换': '変換停止',
|
115 |
+
'推理所用时间(ms):': '推論時間(ms):',
|
116 |
+
'扩散设置': '扩散设置',
|
117 |
+
'启用扩散': '启用扩散',
|
118 |
+
'扩散加速': '扩散加速',
|
119 |
+
'扩散深度': '扩散深度',
|
120 |
+
'扩散说话人id': '扩散说话人id',
|
121 |
+
'扩散模型文件': '扩散模型文件',
|
122 |
+
'不扩散安全区(加速但损失效果)': '不扩散安全区(加速但损失效果)',
|
123 |
+
'启用DPMs(推荐)': '启用DPMs(推荐)'
|
124 |
+
}
|
125 |
+
}
|
126 |
+
|
127 |
+
|
128 |
+
class I18nAuto:
|
129 |
+
def __init__(self, language=None):
|
130 |
+
self.language_list = LANGUAGE_LIST
|
131 |
+
self.language_all = LANGUAGE_ALL
|
132 |
+
self.language_map = {}
|
133 |
+
if language is None:
|
134 |
+
language = 'auto'
|
135 |
+
if language == 'auto':
|
136 |
+
language = locale.getdefaultlocale()[0]
|
137 |
+
if language not in self.language_list:
|
138 |
+
language = 'zh_CN'
|
139 |
+
self.language = language
|
140 |
+
super_language_list = []
|
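# Walk the SUPER chain from the selected language up to zh_CN, then apply the
# dictionaries from base to most specific so that untranslated keys fall back
# to the parent language.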
141 |
+
while self.language_all[language]['SUPER'] != 'END':
|
142 |
+
super_language_list.append(language)
|
143 |
+
language = self.language_all[language]['SUPER']
|
144 |
+
super_language_list.append('zh_CN')
|
145 |
+
super_language_list.reverse()
|
146 |
+
for _lang in super_language_list:
|
147 |
+
self.read_language(self.language_all[_lang])
|
148 |
+
|
149 |
+
def read_language(self, lang_dict: dict):
|
150 |
+
for _key in lang_dict.keys():
|
151 |
+
self.language_map[_key] = lang_dict[_key]
|
152 |
+
|
153 |
+
def __call__(self, key):
|
154 |
+
return self.language_map[key]
|
DDSP-SVC/gui_locale.py
ADDED
@@ -0,0 +1,130 @@
1 |
+
import locale
|
2 |
+
'''
|
3 |
+
Localization is defined as shown below
|
4 |
+
'''
|
5 |
+
|
6 |
+
LANGUAGE_LIST = ['zh_CN', 'en_US', 'ja_JP']
|
7 |
+
LANGUAGE_ALL = {
|
8 |
+
'zh_CN': {
|
9 |
+
'SUPER': 'END',
|
10 |
+
'LANGUAGE': 'zh_CN',
|
11 |
+
'选择模型文件': '选择模型文件',
|
12 |
+
'模型:.pt格式(自动识别同目录下config.yaml)': '模型:.pt格式(自动识别同目录下config.yaml)',
|
13 |
+
'选择配置文件所在目录': '选择配置文件所在目录',
|
14 |
+
'打开文件夹': '打开文件夹',
|
15 |
+
'读取配置文件': '读取配置文件',
|
16 |
+
'保存配置文件': '保存配置文件',
|
17 |
+
'快速配置文件': '快速配置文件',
|
18 |
+
'输入设备': '输入设备',
|
19 |
+
'输出设备': '输出设备',
|
20 |
+
'音频设备': '音频设备',
|
21 |
+
'说话人id': '说话人id',
|
22 |
+
'响应阈值': '响应阈值',
|
23 |
+
'变调': '变调',
|
24 |
+
'采样率': '采样率',
|
25 |
+
'启用捏音色功能': '启用捏音色功能',
|
26 |
+
'设置混合音色': '设置混合音色',
|
27 |
+
'普通设置': '普通设置',
|
28 |
+
'音频切分大小': '音频切分大小',
|
29 |
+
'交叉淡化时长': '交叉淡化时长',
|
30 |
+
'使用历史区块数量': '使用历史区块数量',
|
31 |
+
'f0预测模式': 'f0预测模式',
|
32 |
+
'启用增强器': '启用增强器',
|
33 |
+
'启用相位声码器': '启用相位声码器',
|
34 |
+
'性能设置': '性能设置',
|
35 |
+
'开始音频转换': '开始音频转换',
|
36 |
+
'停止音频转换': '停止音频转换',
|
37 |
+
'推理所用时间(ms):': '推理所用时间(ms):'
|
38 |
+
},
|
39 |
+
'en_US': {
|
40 |
+
'SUPER': 'zh_CN',
|
41 |
+
'LANGUAGE': 'en_US',
|
42 |
+
'选择模型文件': 'Select Model File',
|
43 |
+
'模型:.pt格式(自动识别同目录下config.yaml)': 'Model:.pt format(Auto ust config.yaml in here)',
|
44 |
+
'选择配置文件所在目录': 'Select the configuration file directory',
|
45 |
+
'打开文件夹': 'Open folder',
|
46 |
+
'读取配置文件': 'Read config file',
|
47 |
+
'保存配置文件': 'Save config file',
|
48 |
+
'快速配置文件': 'Fast config file',
|
49 |
+
'输入设备': 'Input device',
|
50 |
+
'输出设备': 'Output device',
|
51 |
+
'音频设备': 'Audio devices',
|
52 |
+
'说话人id': 'Speaker ID',
|
53 |
+
'响应阈值': 'Response threshold',
|
54 |
+
'变调': 'Pitch',
|
55 |
+
'采样率': 'Sampling rate',
|
56 |
+
'启用捏音色功能': 'Enable Mix Speaker',
|
57 |
+
'设置混合音色': 'Mix Speaker',
|
58 |
+
'普通设置': 'Normal Settings',
|
59 |
+
'音频切分大小': 'Segmentation size',
|
60 |
+
'交叉淡化时长': 'Cross fade duration',
|
61 |
+
'使用历史区块数量': 'Historical blocks used',
|
62 |
+
'f0预测模式': 'f0Extractor',
|
63 |
+
'启用增强器': 'Enable Enhancer',
|
64 |
+
'启用相位声码器': 'Enable Phase Vocoder',
|
65 |
+
'性能设置': 'Performance settings',
|
66 |
+
'开始音频转换': 'Start conversion',
|
67 |
+
'停止音频转换': 'Stop conversion',
|
68 |
+
'推理所用时间(ms):': 'Inference time(ms):'
|
69 |
+
},
|
70 |
+
'ja_JP': {
|
71 |
+
'SUPER': 'zh_CN',
|
72 |
+
'LANGUAGE': 'ja_JP',
|
73 |
+
'选择模型文件': 'モデルを選択',
|
74 |
+
'模型:.pt格式(自动识别同目录下config.yaml)': 'モデル:.pt形式(同じディレクトリにあるconfig.yamlを自動認識します)',
|
75 |
+
'选择配置文件所在目录': '設定ファイルを選択',
|
76 |
+
'打开文件夹': 'フォルダを開く',
|
77 |
+
'读取配置文件': '設定ファイルを読み込む',
|
78 |
+
'保存配置文件': '設定ファイルを保存',
|
79 |
+
'快速配置文件': '設定プロファイル',
|
80 |
+
'输入设备': '入力デバイス',
|
81 |
+
'输出设备': '出力デバイス',
|
82 |
+
'音频设备': '音声デバイス',
|
83 |
+
'说话人id': '話者ID',
|
84 |
+
'响应阈值': '応答時の閾値',
|
85 |
+
'变调': '音程',
|
86 |
+
'采样率': 'サンプリングレート',
|
87 |
+
'启用捏音色功能': 'ミキシングを有効化',
|
88 |
+
'设置混合音色': 'ミキシング',
|
89 |
+
'普通设置': '通常設定',
|
90 |
+
'音频切分大小': 'セグメンテーションのサイズ',
|
91 |
+
'交叉淡化时长': 'クロスフェードの間隔',
|
92 |
+
'使用历史区块数量': '使用するヒストリカルブロック数',
|
93 |
+
'f0预测模式': 'f0予測モデル',
|
94 |
+
'启用增强器': 'Enhancerを有効化',
|
95 |
+
'启用相位声码器': 'フェーズボコーダを有効化',
|
96 |
+
'性能设置': 'パフォーマンスの設定',
|
97 |
+
'开始音频转换': '変換開始',
|
98 |
+
'停止音频转换': '変換停止',
|
99 |
+
'推理所用时间(ms):': '推論時間(ms):'
|
100 |
+
}
|
101 |
+
}
|
102 |
+
|
103 |
+
|
104 |
+
class I18nAuto:
|
105 |
+
def __init__(self, language=None):
|
106 |
+
self.language_list = LANGUAGE_LIST
|
107 |
+
self.language_all = LANGUAGE_ALL
|
108 |
+
self.language_map = {}
|
109 |
+
if language is None:
|
110 |
+
language = 'auto'
|
111 |
+
if language == 'auto':
|
112 |
+
language = locale.getdefaultlocale()[0]
|
113 |
+
if language not in self.language_list:
|
114 |
+
language = 'zh_CN'
|
115 |
+
self.language = language
|
116 |
+
super_language_list = []
|
117 |
+
while self.language_all[language]['SUPER'] != 'END':
|
118 |
+
super_language_list.append(language)
|
119 |
+
language = self.language_all[language]['SUPER']
|
120 |
+
super_language_list.append('zh_CN')
|
121 |
+
super_language_list.reverse()
|
122 |
+
for _lang in super_language_list:
|
123 |
+
self.read_language(self.language_all[_lang])
|
124 |
+
|
125 |
+
def read_language(self, lang_dict: dict):
|
126 |
+
for _key in lang_dict.keys():
|
127 |
+
self.language_map[_key] = lang_dict[_key]
|
128 |
+
|
129 |
+
def __call__(self, key):
|
130 |
+
return self.language_map[key]
|
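The I18nAuto class above walks the 'SUPER' chain (which always terminates at 'zh_CN') and flattens the inherited dictionaries into one lookup table, so the object can be called like a translation function. A minimal usage sketch, assuming the module is importable as gui_locale:

from gui_locale import I18nAuto

i18n = I18nAuto(language='en_US')   # None or 'auto' follows the system locale, falling back to 'zh_CN'
print(i18n('说话人id'))              # -> 'Speaker ID'
print(i18n('开始音频转换'))           # -> 'Start conversion'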
DDSP-SVC/logger/__init__.py
ADDED
File without changes
|
DDSP-SVC/logger/saver.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
'''
|
2 |
+
author: wayn391@mastertones
|
3 |
+
'''
|
4 |
+
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
import yaml
|
9 |
+
import datetime
|
10 |
+
import torch
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
from . import utils
|
13 |
+
from torch.utils.tensorboard import SummaryWriter
|
14 |
+
|
15 |
+
class Saver(object):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
args,
|
19 |
+
initial_global_step=-1):
|
20 |
+
|
21 |
+
self.expdir = args.env.expdir
|
22 |
+
self.sample_rate = args.data.sampling_rate
|
23 |
+
|
24 |
+
# cold start
|
25 |
+
self.global_step = initial_global_step
|
26 |
+
self.init_time = time.time()
|
27 |
+
self.last_time = time.time()
|
28 |
+
|
29 |
+
# makedirs
|
30 |
+
os.makedirs(self.expdir, exist_ok=True)
|
31 |
+
|
32 |
+
# path
|
33 |
+
self.path_log_info = os.path.join(self.expdir, 'log_info.txt')
|
34 |
+
|
35 |
+
# ckpt
|
36 |
+
os.makedirs(self.expdir, exist_ok=True)
|
37 |
+
|
38 |
+
# writer
|
39 |
+
self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))
|
40 |
+
|
41 |
+
# save config
|
42 |
+
path_config = os.path.join(self.expdir, 'config.yaml')
|
43 |
+
with open(path_config, "w") as out_config:
|
44 |
+
yaml.dump(dict(args), out_config)
|
45 |
+
|
46 |
+
|
47 |
+
def log_info(self, msg):
|
48 |
+
'''log method'''
|
49 |
+
if isinstance(msg, dict):
|
50 |
+
msg_list = []
|
51 |
+
for k, v in msg.items():
|
52 |
+
tmp_str = ''
|
53 |
+
if isinstance(v, int):
|
54 |
+
tmp_str = '{}: {:,}'.format(k, v)
|
55 |
+
else:
|
56 |
+
tmp_str = '{}: {}'.format(k, v)
|
57 |
+
|
58 |
+
msg_list.append(tmp_str)
|
59 |
+
msg_str = '\n'.join(msg_list)
|
60 |
+
else:
|
61 |
+
msg_str = msg
|
62 |
+
|
63 |
+
# display
|
64 |
+
print(msg_str)
|
65 |
+
|
66 |
+
# save
|
67 |
+
with open(self.path_log_info, 'a') as fp:
|
68 |
+
fp.write(msg_str+'\n')
|
69 |
+
|
70 |
+
def log_value(self, dict):
|
71 |
+
for k, v in dict.items():
|
72 |
+
self.writer.add_scalar(k, v, self.global_step)
|
73 |
+
|
74 |
+
def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
|
75 |
+
spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
|
76 |
+
spec = spec_cat[0]
|
77 |
+
if isinstance(spec, torch.Tensor):
|
78 |
+
spec = spec.cpu().numpy()
|
79 |
+
fig = plt.figure(figsize=(12, 9))
|
80 |
+
plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
|
81 |
+
plt.tight_layout()
|
82 |
+
self.writer.add_figure(name, fig, self.global_step)
|
83 |
+
|
84 |
+
def log_audio(self, dict):
|
85 |
+
for k, v in dict.items():
|
86 |
+
self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)
|
87 |
+
|
88 |
+
def get_interval_time(self, update=True):
|
89 |
+
cur_time = time.time()
|
90 |
+
time_interval = cur_time - self.last_time
|
91 |
+
if update:
|
92 |
+
self.last_time = cur_time
|
93 |
+
return time_interval
|
94 |
+
|
95 |
+
def get_total_time(self, to_str=True):
|
96 |
+
total_time = time.time() - self.init_time
|
97 |
+
if to_str:
|
98 |
+
total_time = str(datetime.timedelta(
|
99 |
+
seconds=total_time))[:-5]
|
100 |
+
return total_time
|
101 |
+
|
102 |
+
def save_model(
|
103 |
+
self,
|
104 |
+
model,
|
105 |
+
optimizer,
|
106 |
+
name='model',
|
107 |
+
postfix='',
|
108 |
+
to_json=False):
|
109 |
+
# path
|
110 |
+
if postfix:
|
111 |
+
postfix = '_' + postfix
|
112 |
+
path_pt = os.path.join(
|
113 |
+
self.expdir , name+postfix+'.pt')
|
114 |
+
|
115 |
+
# check
|
116 |
+
print(' [*] model checkpoint saved: {}'.format(path_pt))
|
117 |
+
|
118 |
+
# save
|
119 |
+
torch.save({
|
120 |
+
'global_step': self.global_step,
|
121 |
+
'model': model.state_dict(),
|
122 |
+
'optimizer': optimizer.state_dict()}, path_pt)
|
123 |
+
|
124 |
+
# to json
|
125 |
+
if to_json:
|
126 |
+
path_json = os.path.join(
|
127 |
+
self.expdir , name+'.json')
|
128 |
+
utils.to_json(path_params, path_json)
|
129 |
+
|
130 |
+
def delete_model(self, name='model', postfix=''):
|
131 |
+
# path
|
132 |
+
if postfix:
|
133 |
+
postfix = '_' + postfix
|
134 |
+
path_pt = os.path.join(
|
135 |
+
self.expdir , name+postfix+'.pt')
|
136 |
+
|
137 |
+
# delete
|
138 |
+
if os.path.exists(path_pt):
|
139 |
+
os.remove(path_pt)
|
140 |
+
print(' [*] model checkpoint deleted: {}'.format(path_pt))
|
141 |
+
|
142 |
+
def global_step_increment(self):
|
143 |
+
self.global_step += 1
|
144 |
+
|
145 |
+
|
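A rough sketch of how Saver is typically driven during training (the config path and the logged values are illustrative; the constructor only reads args.env.expdir and args.data.sampling_rate):

from logger.utils import load_config
from logger.saver import Saver

args = load_config('configs/combsub.yaml')       # assumed config file
saver = Saver(args, initial_global_step=0)
saver.log_info({'expdir': args.env.expdir, 'step': saver.global_step})
saver.log_value({'train/loss': 0.123})           # scalar goes to TensorBoard at the current global step
saver.global_step_increment()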
DDSP-SVC/logger/utils.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import json
|
4 |
+
import pickle
|
5 |
+
import torch
|
6 |
+
|
7 |
+
def traverse_dir(
|
8 |
+
root_dir,
|
9 |
+
extension,
|
10 |
+
amount=None,
|
11 |
+
str_include=None,
|
12 |
+
str_exclude=None,
|
13 |
+
is_pure=False,
|
14 |
+
is_sort=False,
|
15 |
+
is_ext=True):
|
16 |
+
|
17 |
+
file_list = []
|
18 |
+
cnt = 0
|
19 |
+
for root, _, files in os.walk(root_dir):
|
20 |
+
for file in files:
|
21 |
+
if file.endswith(extension):
|
22 |
+
# path
|
23 |
+
mix_path = os.path.join(root, file)
|
24 |
+
pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
|
25 |
+
|
26 |
+
# amount
|
27 |
+
if (amount is not None) and (cnt == amount):
|
28 |
+
if is_sort:
|
29 |
+
file_list.sort()
|
30 |
+
return file_list
|
31 |
+
|
32 |
+
# check string
|
33 |
+
if (str_include is not None) and (str_include not in pure_path):
|
34 |
+
continue
|
35 |
+
if (str_exclude is not None) and (str_exclude in pure_path):
|
36 |
+
continue
|
37 |
+
|
38 |
+
if not is_ext:
|
39 |
+
ext = pure_path.split('.')[-1]
|
40 |
+
pure_path = pure_path[:-(len(ext)+1)]
|
41 |
+
file_list.append(pure_path)
|
42 |
+
cnt += 1
|
43 |
+
if is_sort:
|
44 |
+
file_list.sort()
|
45 |
+
return file_list
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
class DotDict(dict):
|
50 |
+
def __getattr__(*args):
|
51 |
+
val = dict.get(*args)
|
52 |
+
return DotDict(val) if type(val) is dict else val
|
53 |
+
|
54 |
+
__setattr__ = dict.__setitem__
|
55 |
+
__delattr__ = dict.__delitem__
|
56 |
+
|
57 |
+
|
58 |
+
def get_network_paras_amount(model_dict):
|
59 |
+
info = dict()
|
60 |
+
for model_name, model in model_dict.items():
|
61 |
+
# all_params = sum(p.numel() for p in model.parameters())
|
62 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
63 |
+
|
64 |
+
info[model_name] = trainable_params
|
65 |
+
return info
|
66 |
+
|
67 |
+
|
68 |
+
def load_config(path_config):
|
69 |
+
with open(path_config, "r") as config:
|
70 |
+
args = yaml.safe_load(config)
|
71 |
+
args = DotDict(args)
|
72 |
+
# print(args)
|
73 |
+
return args
|
74 |
+
|
75 |
+
|
76 |
+
def to_json(path_params, path_json):
|
77 |
+
params = torch.load(path_params, map_location=torch.device('cpu'))
|
78 |
+
raw_state_dict = {}
|
79 |
+
for k, v in params.items():
|
80 |
+
val = v.flatten().numpy().tolist()
|
81 |
+
raw_state_dict[k] = val
|
82 |
+
|
83 |
+
with open(path_json, 'w') as outfile:
|
84 |
+
json.dump(raw_state_dict, outfile,indent= "\t")
|
85 |
+
|
86 |
+
|
87 |
+
def convert_tensor_to_numpy(tensor, is_squeeze=True):
|
88 |
+
if is_squeeze:
|
89 |
+
tensor = tensor.squeeze()
|
90 |
+
if tensor.requires_grad:
|
91 |
+
tensor = tensor.detach()
|
92 |
+
if tensor.is_cuda:
|
93 |
+
tensor = tensor.cpu()
|
94 |
+
return tensor.numpy()
|
95 |
+
|
96 |
+
|
97 |
+
def load_model(
|
98 |
+
expdir,
|
99 |
+
model,
|
100 |
+
optimizer,
|
101 |
+
name='model',
|
102 |
+
postfix='',
|
103 |
+
device='cpu'):
|
104 |
+
if postfix == '':
|
105 |
+
postfix = '_' + postfix
|
106 |
+
path = os.path.join(expdir, name+postfix)
|
107 |
+
path_pt = traverse_dir(expdir, '.pt', is_ext=False)
|
108 |
+
global_step = 0
|
109 |
+
if len(path_pt) > 0:
|
110 |
+
steps = [s[len(path):] for s in path_pt]
|
111 |
+
maxstep = max([int(s) if s.isdigit() else 0 for s in steps])
|
112 |
+
if maxstep >= 0:
|
113 |
+
path_pt = path+str(maxstep)+'.pt'
|
114 |
+
else:
|
115 |
+
path_pt = path+'best.pt'
|
116 |
+
print(' [*] restoring model from', path_pt)
|
117 |
+
ckpt = torch.load(path_pt, map_location=torch.device(device))
|
118 |
+
global_step = ckpt['global_step']
|
119 |
+
model.load_state_dict(ckpt['model'], strict=False)
|
120 |
+
if maxstep != 0:
|
121 |
+
optimizer.load_state_dict(ckpt['optimizer'])
|
122 |
+
return global_step, model, optimizer
|
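A short sketch of the two helpers the rest of the code base leans on most (paths are hypothetical): load_config turns nested YAML keys into attribute chains via DotDict, and traverse_dir collects files by extension:

from logger.utils import load_config, traverse_dir

args = load_config('exp/combsub-test/config.yaml')            # hypothetical experiment config
print(args.data.sampling_rate, args.data.block_size)          # nested keys read as attributes

wav_names = traverse_dir('data/train/audio', '.wav',
                         is_pure=True, is_sort=True, is_ext=False)   # sorted relative names without '.wav'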
DDSP-SVC/main.py
ADDED
@@ -0,0 +1,282 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import librosa
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
import soundfile as sf
|
7 |
+
import pyworld as pw
|
8 |
+
import parselmouth
|
9 |
+
import hashlib
|
10 |
+
from ast import literal_eval
|
11 |
+
from slicer import Slicer
|
12 |
+
from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
|
13 |
+
from ddsp.core import upsample
|
14 |
+
from enhancer import Enhancer
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
def parse_args(args=None, namespace=None):
|
18 |
+
"""Parse command-line arguments."""
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument(
|
21 |
+
"-m",
|
22 |
+
"--model_path",
|
23 |
+
type=str,
|
24 |
+
required=True,
|
25 |
+
help="path to the model file",
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
"-d",
|
29 |
+
"--device",
|
30 |
+
type=str,
|
31 |
+
default=None,
|
32 |
+
required=False,
|
33 |
+
help="cpu or cuda, auto if not set")
|
34 |
+
parser.add_argument(
|
35 |
+
"-i",
|
36 |
+
"--input",
|
37 |
+
type=str,
|
38 |
+
required=True,
|
39 |
+
help="path to the input audio file",
|
40 |
+
)
|
41 |
+
parser.add_argument(
|
42 |
+
"-o",
|
43 |
+
"--output",
|
44 |
+
type=str,
|
45 |
+
required=True,
|
46 |
+
help="path to the output audio file",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"-id",
|
50 |
+
"--spk_id",
|
51 |
+
type=str,
|
52 |
+
required=False,
|
53 |
+
default=1,
|
54 |
+
help="speaker id (for multi-speaker model) | default: 1",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"-mix",
|
58 |
+
"--spk_mix_dict",
|
59 |
+
type=str,
|
60 |
+
required=False,
|
61 |
+
default="None",
|
62 |
+
help="mix-speaker dictionary (for multi-speaker model) | default: None",
|
63 |
+
)
|
64 |
+
parser.add_argument(
|
65 |
+
"-k",
|
66 |
+
"--key",
|
67 |
+
type=str,
|
68 |
+
required=False,
|
69 |
+
default=0,
|
70 |
+
help="key changed (number of semitones) | default: 0",
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"-e",
|
74 |
+
"--enhance",
|
75 |
+
type=str,
|
76 |
+
required=False,
|
77 |
+
default='true',
|
78 |
+
help="true or false | default: true",
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"-pe",
|
82 |
+
"--pitch_extractor",
|
83 |
+
type=str,
|
84 |
+
required=False,
|
85 |
+
default='crepe',
|
86 |
+
help="pitch extrator type: parselmouth, dio, harvest, crepe (default)",
|
87 |
+
)
|
88 |
+
parser.add_argument(
|
89 |
+
"-fmin",
|
90 |
+
"--f0_min",
|
91 |
+
type=str,
|
92 |
+
required=False,
|
93 |
+
default=50,
|
94 |
+
help="min f0 (Hz) | default: 50",
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"-fmax",
|
98 |
+
"--f0_max",
|
99 |
+
type=str,
|
100 |
+
required=False,
|
101 |
+
default=1100,
|
102 |
+
help="max f0 (Hz) | default: 1100",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"-th",
|
106 |
+
"--threhold",
|
107 |
+
type=str,
|
108 |
+
required=False,
|
109 |
+
default=-60,
|
110 |
+
help="response threhold (dB) | default: -60",
|
111 |
+
)
|
112 |
+
parser.add_argument(
|
113 |
+
"-eak",
|
114 |
+
"--enhancer_adaptive_key",
|
115 |
+
type=str,
|
116 |
+
required=False,
|
117 |
+
default=0,
|
118 |
+
help="adapt the enhancer to a higher vocal range (number of semitones) | default: 0",
|
119 |
+
)
|
120 |
+
return parser.parse_args(args=args, namespace=namespace)
|
121 |
+
|
122 |
+
|
123 |
+
def split(audio, sample_rate, hop_size, db_thresh = -40, min_len = 5000):
|
124 |
+
slicer = Slicer(
|
125 |
+
sr=sample_rate,
|
126 |
+
threshold=db_thresh,
|
127 |
+
min_length=min_len)
|
128 |
+
chunks = dict(slicer.slice(audio))
|
129 |
+
result = []
|
130 |
+
for k, v in chunks.items():
|
131 |
+
tag = v["split_time"].split(",")
|
132 |
+
if tag[0] != tag[1]:
|
133 |
+
start_frame = int(int(tag[0]) // hop_size)
|
134 |
+
end_frame = int(int(tag[1]) // hop_size)
|
135 |
+
if end_frame > start_frame:
|
136 |
+
result.append((
|
137 |
+
start_frame,
|
138 |
+
audio[int(start_frame * hop_size) : int(end_frame * hop_size)]))
|
139 |
+
return result
|
140 |
+
|
141 |
+
|
142 |
+
def cross_fade(a: np.ndarray, b: np.ndarray, idx: int):
|
143 |
+
result = np.zeros(idx + b.shape[0])
|
144 |
+
fade_len = a.shape[0] - idx
|
145 |
+
np.copyto(dst=result[:idx], src=a[:idx])
|
146 |
+
k = np.linspace(0, 1.0, num=fade_len, endpoint=True)
|
147 |
+
result[idx: a.shape[0]] = (1 - k) * a[idx:] + k * b[: fade_len]
|
148 |
+
np.copyto(dst=result[a.shape[0]:], src=b[fade_len:])
|
149 |
+
return result
|
150 |
+
|
151 |
+
|
152 |
+
if __name__ == '__main__':
|
153 |
+
# parse commands
|
154 |
+
cmd = parse_args()
|
155 |
+
|
156 |
+
#device = 'cpu'
|
157 |
+
device = cmd.device
|
158 |
+
if device is None:
|
159 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
160 |
+
|
161 |
+
# load ddsp model
|
162 |
+
model, args = load_model(cmd.model_path, device=device)
|
163 |
+
|
164 |
+
# load input
|
165 |
+
audio, sample_rate = librosa.load(cmd.input, sr=None)
|
166 |
+
if len(audio.shape) > 1:
|
167 |
+
audio = librosa.to_mono(audio)
|
168 |
+
hop_size = args.data.block_size * sample_rate / args.data.sampling_rate
|
169 |
+
|
170 |
+
# get MD5 hash from wav file
|
171 |
+
md5_hash = ""
|
172 |
+
with open(cmd.input, 'rb') as f:
|
173 |
+
data = f.read()
|
174 |
+
md5_hash = hashlib.md5(data).hexdigest()
|
175 |
+
print("MD5: " + md5_hash)
|
176 |
+
|
177 |
+
cache_dir_path = os.path.join(os.path.dirname(__file__), "cache")
|
178 |
+
cache_file_path = os.path.join(cache_dir_path, f"{cmd.pitch_extractor}_{hop_size}_{cmd.f0_min}_{cmd.f0_max}_{md5_hash}.npy")
|
179 |
+
|
180 |
+
is_cache_available = os.path.exists(cache_file_path)
|
181 |
+
if is_cache_available:
|
182 |
+
# f0 cache load
|
183 |
+
print('Loading pitch curves for input audio from cache directory...')
|
184 |
+
f0 = np.load(cache_file_path, allow_pickle=False)
|
185 |
+
else:
|
186 |
+
# extract f0
|
187 |
+
print('Pitch extractor type: ' + cmd.pitch_extractor)
|
188 |
+
pitch_extractor = F0_Extractor(
|
189 |
+
cmd.pitch_extractor,
|
190 |
+
sample_rate,
|
191 |
+
hop_size,
|
192 |
+
float(cmd.f0_min),
|
193 |
+
float(cmd.f0_max))
|
194 |
+
print('Extracting the pitch curve of the input audio...')
|
195 |
+
f0 = pitch_extractor.extract(audio, uv_interp = True, device = device)
|
196 |
+
|
197 |
+
# f0 cache save
|
198 |
+
os.makedirs(cache_dir_path, exist_ok=True)
|
199 |
+
np.save(cache_file_path, f0, allow_pickle=False)
|
200 |
+
|
201 |
+
f0 = torch.from_numpy(f0).float().to(device).unsqueeze(-1).unsqueeze(0)
|
202 |
+
|
203 |
+
# key change
|
204 |
+
f0 = f0 * 2 ** (float(cmd.key) / 12)
|
205 |
+
|
206 |
+
# extract volume
|
207 |
+
print('Extracting the volume envelope of the input audio...')
|
208 |
+
volume_extractor = Volume_Extractor(hop_size)
|
209 |
+
volume = volume_extractor.extract(audio)
|
210 |
+
mask = (volume > 10 ** (float(cmd.threhold) / 20)).astype('float')
|
211 |
+
mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
|
212 |
+
mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)])
|
213 |
+
mask = torch.from_numpy(mask).float().to(device).unsqueeze(-1).unsqueeze(0)
|
214 |
+
mask = upsample(mask, args.data.block_size).squeeze(-1)
|
215 |
+
volume = torch.from_numpy(volume).float().to(device).unsqueeze(-1).unsqueeze(0)
|
216 |
+
|
217 |
+
# load units encoder
|
218 |
+
if args.data.encoder == 'cnhubertsoftfish':
|
219 |
+
cnhubertsoft_gate = args.data.cnhubertsoft_gate
|
220 |
+
else:
|
221 |
+
cnhubertsoft_gate = 10
|
222 |
+
units_encoder = Units_Encoder(
|
223 |
+
args.data.encoder,
|
224 |
+
args.data.encoder_ckpt,
|
225 |
+
args.data.encoder_sample_rate,
|
226 |
+
args.data.encoder_hop_size,
|
227 |
+
cnhubertsoft_gate=cnhubertsoft_gate,
|
228 |
+
device = device)
|
229 |
+
|
230 |
+
# load enhancer
|
231 |
+
if cmd.enhance == 'true':
|
232 |
+
print('Enhancer type: ' + args.enhancer.type)
|
233 |
+
enhancer = Enhancer(args.enhancer.type, args.enhancer.ckpt, device=device)
|
234 |
+
else:
|
235 |
+
print('Enhancer type: none (using raw output of ddsp)')
|
236 |
+
|
237 |
+
# speaker id or mix-speaker dictionary
|
238 |
+
spk_mix_dict = literal_eval(cmd.spk_mix_dict)
|
239 |
+
if spk_mix_dict is not None:
|
240 |
+
print('Mix-speaker mode')
|
241 |
+
else:
|
242 |
+
print('Speaker ID: '+ str(int(cmd.spk_id)))
|
243 |
+
spk_id = torch.LongTensor(np.array([[int(cmd.spk_id)]])).to(device)
|
244 |
+
|
245 |
+
# forward and save the output
|
246 |
+
result = np.zeros(0)
|
247 |
+
current_length = 0
|
248 |
+
segments = split(audio, sample_rate, hop_size)
|
249 |
+
print('Cut the input audio into ' + str(len(segments)) + ' slices')
|
250 |
+
with torch.no_grad():
|
251 |
+
for segment in tqdm(segments):
|
252 |
+
start_frame = segment[0]
|
253 |
+
seg_input = torch.from_numpy(segment[1]).float().unsqueeze(0).to(device)
|
254 |
+
seg_units = units_encoder.encode(seg_input, sample_rate, hop_size)
|
255 |
+
|
256 |
+
seg_f0 = f0[:, start_frame : start_frame + seg_units.size(1), :]
|
257 |
+
seg_volume = volume[:, start_frame : start_frame + seg_units.size(1), :]
|
258 |
+
|
259 |
+
seg_output, _, (s_h, s_n) = model(seg_units, seg_f0, seg_volume, spk_id = spk_id, spk_mix_dict = spk_mix_dict)
|
260 |
+
seg_output *= mask[:, start_frame * args.data.block_size : (start_frame + seg_units.size(1)) * args.data.block_size]
|
261 |
+
|
262 |
+
if cmd.enhance == 'true':
|
263 |
+
seg_output, output_sample_rate = enhancer.enhance(
|
264 |
+
seg_output,
|
265 |
+
args.data.sampling_rate,
|
266 |
+
seg_f0,
|
267 |
+
args.data.block_size,
|
268 |
+
adaptive_key = cmd.enhancer_adaptive_key)
|
269 |
+
else:
|
270 |
+
output_sample_rate = args.data.sampling_rate
|
271 |
+
|
272 |
+
seg_output = seg_output.squeeze().cpu().numpy()
|
273 |
+
|
274 |
+
silent_length = round(start_frame * args.data.block_size * output_sample_rate / args.data.sampling_rate) - current_length
|
275 |
+
if silent_length >= 0:
|
276 |
+
result = np.append(result, np.zeros(silent_length))
|
277 |
+
result = np.append(result, seg_output)
|
278 |
+
else:
|
279 |
+
result = cross_fade(result, seg_output, current_length + silent_length)
|
280 |
+
current_length = current_length + silent_length + len(seg_output)
|
281 |
+
sf.write(cmd.output, result, output_sample_rate)
|
282 |
+
|
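For reference, the argument parser above can also be driven programmatically, which makes the expected flags explicit (the model and audio paths below are illustrative):

from main import parse_args

cmd = parse_args(['-m', 'exp/model_best.pt',
                  '-i', 'raw/input.wav',
                  '-o', 'results/output.wav',
                  '-k', '0', '-id', '1'])
print(cmd.pitch_extractor, cmd.enhance)   # defaults: 'crepe', 'true'
# equivalent CLI: python main.py -m exp/model_best.pt -i raw/input.wav -o results/output.wav -k 0 -id 1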
DDSP-SVC/main_diff.py
ADDED
@@ -0,0 +1,372 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import librosa
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
import soundfile as sf
|
7 |
+
import pyworld as pw
|
8 |
+
import parselmouth
|
9 |
+
import hashlib
|
10 |
+
from ast import literal_eval
|
11 |
+
from slicer import Slicer
|
12 |
+
from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
|
13 |
+
from ddsp.core import upsample
|
14 |
+
from diffusion.unit2mel import load_model_vocoder
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
def check_args(ddsp_args, diff_args):
|
18 |
+
if ddsp_args.data.sampling_rate != diff_args.data.sampling_rate:
|
19 |
+
print("Unmatch data.sampling_rate!")
|
20 |
+
return False
|
21 |
+
if ddsp_args.data.block_size != diff_args.data.block_size:
|
22 |
+
print("Unmatch data.block_size!")
|
23 |
+
return False
|
24 |
+
if ddsp_args.data.encoder != diff_args.data.encoder:
|
25 |
+
print("Unmatch data.encoder!")
|
26 |
+
return False
|
27 |
+
return True
|
28 |
+
|
29 |
+
def parse_args(args=None, namespace=None):
|
30 |
+
"""Parse command-line arguments."""
|
31 |
+
parser = argparse.ArgumentParser()
|
32 |
+
parser.add_argument(
|
33 |
+
"-diff",
|
34 |
+
"--diff_ckpt",
|
35 |
+
type=str,
|
36 |
+
required=True,
|
37 |
+
help="path to the diffusion model checkpoint",
|
38 |
+
)
|
39 |
+
parser.add_argument(
|
40 |
+
"-ddsp",
|
41 |
+
"--ddsp_ckpt",
|
42 |
+
type=str,
|
43 |
+
required=False,
|
44 |
+
default="None",
|
45 |
+
help="path to the DDSP model checkpoint (for shallow diffusion)",
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
"-d",
|
49 |
+
"--device",
|
50 |
+
type=str,
|
51 |
+
default=None,
|
52 |
+
required=False,
|
53 |
+
help="cpu or cuda, auto if not set")
|
54 |
+
parser.add_argument(
|
55 |
+
"-i",
|
56 |
+
"--input",
|
57 |
+
type=str,
|
58 |
+
required=True,
|
59 |
+
help="path to the input audio file",
|
60 |
+
)
|
61 |
+
parser.add_argument(
|
62 |
+
"-o",
|
63 |
+
"--output",
|
64 |
+
type=str,
|
65 |
+
required=True,
|
66 |
+
help="path to the output audio file",
|
67 |
+
)
|
68 |
+
parser.add_argument(
|
69 |
+
"-id",
|
70 |
+
"--spk_id",
|
71 |
+
type=str,
|
72 |
+
required=False,
|
73 |
+
default=1,
|
74 |
+
help="speaker id (for multi-speaker model) | default: 1",
|
75 |
+
)
|
76 |
+
parser.add_argument(
|
77 |
+
"-mix",
|
78 |
+
"--spk_mix_dict",
|
79 |
+
type=str,
|
80 |
+
required=False,
|
81 |
+
default="None",
|
82 |
+
help="mix-speaker dictionary (for multi-speaker model) | default: None",
|
83 |
+
)
|
84 |
+
parser.add_argument(
|
85 |
+
"-k",
|
86 |
+
"--key",
|
87 |
+
type=str,
|
88 |
+
required=False,
|
89 |
+
default=0,
|
90 |
+
help="key changed (number of semitones) | default: 0",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"-f",
|
94 |
+
"--formant_shift_key",
|
95 |
+
type=str,
|
96 |
+
required=False,
|
97 |
+
default=0,
|
98 |
+
help="formant changed (number of semitones) , only for pitch-augmented model| default: 0",
|
99 |
+
)
|
100 |
+
parser.add_argument(
|
101 |
+
"-pe",
|
102 |
+
"--pitch_extractor",
|
103 |
+
type=str,
|
104 |
+
required=False,
|
105 |
+
default='crepe',
|
106 |
+
help="pitch extrator type: parselmouth, dio, harvest, crepe (default)",
|
107 |
+
)
|
108 |
+
parser.add_argument(
|
109 |
+
"-fmin",
|
110 |
+
"--f0_min",
|
111 |
+
type=str,
|
112 |
+
required=False,
|
113 |
+
default=50,
|
114 |
+
help="min f0 (Hz) | default: 50",
|
115 |
+
)
|
116 |
+
parser.add_argument(
|
117 |
+
"-fmax",
|
118 |
+
"--f0_max",
|
119 |
+
type=str,
|
120 |
+
required=False,
|
121 |
+
default=1100,
|
122 |
+
help="max f0 (Hz) | default: 1100",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"-th",
|
126 |
+
"--threhold",
|
127 |
+
type=str,
|
128 |
+
required=False,
|
129 |
+
default=-60,
|
130 |
+
help="response threhold (dB) | default: -60",
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"-diffid",
|
134 |
+
"--diff_spk_id",
|
135 |
+
type=str,
|
136 |
+
required=False,
|
137 |
+
default='auto',
|
138 |
+
help="diffusion speaker id (for multi-speaker model) | default: auto",
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"-speedup",
|
142 |
+
"--speedup",
|
143 |
+
type=str,
|
144 |
+
required=False,
|
145 |
+
default='auto',
|
146 |
+
help="speed up | default: auto",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"-method",
|
150 |
+
"--method",
|
151 |
+
type=str,
|
152 |
+
required=False,
|
153 |
+
default='auto',
|
154 |
+
help="pndm or dpm-solver | default: auto",
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"-kstep",
|
158 |
+
"--k_step",
|
159 |
+
type=str,
|
160 |
+
required=False,
|
161 |
+
default=None,
|
162 |
+
help="shallow diffusion steps | default: None",
|
163 |
+
)
|
164 |
+
return parser.parse_args(args=args, namespace=namespace)
|
165 |
+
|
166 |
+
|
167 |
+
def split(audio, sample_rate, hop_size, db_thresh = -40, min_len = 5000):
|
168 |
+
slicer = Slicer(
|
169 |
+
sr=sample_rate,
|
170 |
+
threshold=db_thresh,
|
171 |
+
min_length=min_len)
|
172 |
+
chunks = dict(slicer.slice(audio))
|
173 |
+
result = []
|
174 |
+
for k, v in chunks.items():
|
175 |
+
tag = v["split_time"].split(",")
|
176 |
+
if tag[0] != tag[1]:
|
177 |
+
start_frame = int(int(tag[0]) // hop_size)
|
178 |
+
end_frame = int(int(tag[1]) // hop_size)
|
179 |
+
if end_frame > start_frame:
|
180 |
+
result.append((
|
181 |
+
start_frame,
|
182 |
+
audio[int(start_frame * hop_size) : int(end_frame * hop_size)]))
|
183 |
+
return result
|
184 |
+
|
185 |
+
|
186 |
+
def cross_fade(a: np.ndarray, b: np.ndarray, idx: int):
|
187 |
+
result = np.zeros(idx + b.shape[0])
|
188 |
+
fade_len = a.shape[0] - idx
|
189 |
+
np.copyto(dst=result[:idx], src=a[:idx])
|
190 |
+
k = np.linspace(0, 1.0, num=fade_len, endpoint=True)
|
191 |
+
result[idx: a.shape[0]] = (1 - k) * a[idx:] + k * b[: fade_len]
|
192 |
+
np.copyto(dst=result[a.shape[0]:], src=b[fade_len:])
|
193 |
+
return result
|
194 |
+
|
195 |
+
|
196 |
+
if __name__ == '__main__':
|
197 |
+
# parse commands
|
198 |
+
cmd = parse_args()
|
199 |
+
|
200 |
+
#device = 'cpu'
|
201 |
+
device = cmd.device
|
202 |
+
if device is None:
|
203 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
204 |
+
|
205 |
+
# load diffusion model
|
206 |
+
model, vocoder, args = load_model_vocoder(cmd.diff_ckpt, device=device)
|
207 |
+
|
208 |
+
# load input
|
209 |
+
audio, sample_rate = librosa.load(cmd.input, sr=None)
|
210 |
+
if len(audio.shape) > 1:
|
211 |
+
audio = librosa.to_mono(audio)
|
212 |
+
hop_size = args.data.block_size * sample_rate / args.data.sampling_rate
|
213 |
+
|
214 |
+
# get MD5 hash from wav file
|
215 |
+
md5_hash = ""
|
216 |
+
with open(cmd.input, 'rb') as f:
|
217 |
+
data = f.read()
|
218 |
+
md5_hash = hashlib.md5(data).hexdigest()
|
219 |
+
print("MD5: " + md5_hash)
|
220 |
+
|
221 |
+
cache_dir_path = os.path.join(os.path.dirname(__file__), "cache")
|
222 |
+
cache_file_path = os.path.join(cache_dir_path, f"{cmd.pitch_extractor}_{hop_size}_{cmd.f0_min}_{cmd.f0_max}_{md5_hash}.npy")
|
223 |
+
|
224 |
+
is_cache_available = os.path.exists(cache_file_path)
|
225 |
+
if is_cache_available:
|
226 |
+
# f0 cache load
|
227 |
+
print('Loading pitch curves for input audio from cache directory...')
|
228 |
+
f0 = np.load(cache_file_path, allow_pickle=False)
|
229 |
+
else:
|
230 |
+
# extract f0
|
231 |
+
print('Pitch extractor type: ' + cmd.pitch_extractor)
|
232 |
+
pitch_extractor = F0_Extractor(
|
233 |
+
cmd.pitch_extractor,
|
234 |
+
sample_rate,
|
235 |
+
hop_size,
|
236 |
+
float(cmd.f0_min),
|
237 |
+
float(cmd.f0_max))
|
238 |
+
print('Extracting the pitch curve of the input audio...')
|
239 |
+
f0 = pitch_extractor.extract(audio, uv_interp = True, device = device)
|
240 |
+
|
241 |
+
# f0 cache save
|
242 |
+
os.makedirs(cache_dir_path, exist_ok=True)
|
243 |
+
np.save(cache_file_path, f0, allow_pickle=False)
|
244 |
+
|
245 |
+
f0 = torch.from_numpy(f0).float().to(device).unsqueeze(-1).unsqueeze(0)
|
246 |
+
|
247 |
+
# key change
|
248 |
+
f0 = f0 * 2 ** (float(cmd.key) / 12)
|
249 |
+
|
250 |
+
# formant change
|
251 |
+
formant_shift_key = torch.LongTensor(np.array([[float(cmd.formant_shift_key)]])).to(device)
|
252 |
+
|
253 |
+
# extract volume
|
254 |
+
print('Extracting the volume envelope of the input audio...')
|
255 |
+
volume_extractor = Volume_Extractor(hop_size)
|
256 |
+
volume = volume_extractor.extract(audio)
|
257 |
+
mask = (volume > 10 ** (float(cmd.threhold) / 20)).astype('float')
|
258 |
+
mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
|
259 |
+
mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)])
|
260 |
+
mask = torch.from_numpy(mask).float().to(device).unsqueeze(-1).unsqueeze(0)
|
261 |
+
mask = upsample(mask, args.data.block_size).squeeze(-1)
|
262 |
+
volume = torch.from_numpy(volume).float().to(device).unsqueeze(-1).unsqueeze(0)
|
263 |
+
|
264 |
+
# load units encoder
|
265 |
+
if args.data.encoder == 'cnhubertsoftfish':
|
266 |
+
cnhubertsoft_gate = args.data.cnhubertsoft_gate
|
267 |
+
else:
|
268 |
+
cnhubertsoft_gate = 10
|
269 |
+
units_encoder = Units_Encoder(
|
270 |
+
args.data.encoder,
|
271 |
+
args.data.encoder_ckpt,
|
272 |
+
args.data.encoder_sample_rate,
|
273 |
+
args.data.encoder_hop_size,
|
274 |
+
cnhubertsoft_gate=cnhubertsoft_gate,
|
275 |
+
device = device)
|
276 |
+
|
277 |
+
# speaker id or mix-speaker dictionary
|
278 |
+
spk_mix_dict = literal_eval(cmd.spk_mix_dict)
|
279 |
+
spk_id = torch.LongTensor(np.array([[int(cmd.spk_id)]])).to(device)
|
280 |
+
if cmd.diff_spk_id == 'auto':
|
281 |
+
diff_spk_id = spk_id
|
282 |
+
else:
|
283 |
+
diff_spk_id = torch.LongTensor(np.array([[int(cmd.diff_spk_id)]])).to(device)
|
284 |
+
if spk_mix_dict is not None:
|
285 |
+
print('Mix-speaker mode')
|
286 |
+
else:
|
287 |
+
print('DDSP Speaker ID: '+ str(int(cmd.spk_id)))
|
288 |
+
print('Diffusion Speaker ID: '+ str(cmd.diff_spk_id))
|
289 |
+
|
290 |
+
# speed up
|
291 |
+
if cmd.speedup == 'auto':
|
292 |
+
infer_speedup = args.infer.speedup
|
293 |
+
else:
|
294 |
+
infer_speedup = int(cmd.speedup)
|
295 |
+
if cmd.method == 'auto':
|
296 |
+
method = args.infer.method
|
297 |
+
else:
|
298 |
+
method = cmd.method
|
299 |
+
if infer_speedup > 1:
|
300 |
+
print('Sampling method: '+ method)
|
301 |
+
print('Speed up: '+ str(infer_speedup))
|
302 |
+
else:
|
303 |
+
print('Sampling method: DDPM')
|
304 |
+
|
305 |
+
ddsp = None
|
306 |
+
input_mel = None
|
307 |
+
k_step = None
|
308 |
+
if cmd.k_step is not None:
|
309 |
+
k_step = int(cmd.k_step)
|
310 |
+
print('Shallow diffusion step: ' + str(k_step))
|
311 |
+
if cmd.ddsp_ckpt != "None":
|
312 |
+
# load ddsp model
|
313 |
+
ddsp, ddsp_args = load_model(cmd.ddsp_ckpt, device=device)
|
314 |
+
if not check_args(ddsp_args, args):
|
315 |
+
print("Cannot use this DDSP model for shallow diffusion, gaussian diffusion will be used!")
|
316 |
+
ddsp = None
|
317 |
+
else:
|
318 |
+
print('DDSP model is not identified!')
|
319 |
+
print('Extracting the mel spectrum of the input audio for shallow diffusion...')
|
320 |
+
audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(device)
|
321 |
+
input_mel = vocoder.extract(audio_t, sample_rate)
|
322 |
+
input_mel = torch.cat((input_mel, input_mel[:,-1:,:]), 1)
|
323 |
+
else:
|
324 |
+
print('Shallow diffusion step is not identified, gaussian diffusion will be used!')
|
325 |
+
|
326 |
+
# forward and save the output
|
327 |
+
result = np.zeros(0)
|
328 |
+
current_length = 0
|
329 |
+
segments = split(audio, sample_rate, hop_size)
|
330 |
+
print('Cut the input audio into ' + str(len(segments)) + ' slices')
|
331 |
+
with torch.no_grad():
|
332 |
+
for segment in tqdm(segments):
|
333 |
+
start_frame = segment[0]
|
334 |
+
seg_input = torch.from_numpy(segment[1]).float().unsqueeze(0).to(device)
|
335 |
+
seg_units = units_encoder.encode(seg_input, sample_rate, hop_size)
|
336 |
+
|
337 |
+
seg_f0 = f0[:, start_frame : start_frame + seg_units.size(1), :]
|
338 |
+
seg_volume = volume[:, start_frame : start_frame + seg_units.size(1), :]
|
339 |
+
if ddsp is not None:
|
340 |
+
seg_ddsp_f0 = 2 ** (-float(cmd.formant_shift_key) / 12) * seg_f0
|
341 |
+
seg_ddsp_output, _ , (_, _) = ddsp(seg_units, seg_ddsp_f0, seg_volume, spk_id = spk_id, spk_mix_dict = spk_mix_dict)
|
342 |
+
seg_input_mel = vocoder.extract(seg_ddsp_output, args.data.sampling_rate, keyshift=float(cmd.formant_shift_key))
|
343 |
+
elif input_mel != None:
|
344 |
+
seg_input_mel = input_mel[:, start_frame : start_frame + seg_units.size(1), :]
|
345 |
+
else:
|
346 |
+
seg_input_mel = None
|
347 |
+
|
348 |
+
seg_mel = model(
|
349 |
+
seg_units,
|
350 |
+
seg_f0,
|
351 |
+
seg_volume,
|
352 |
+
spk_id = diff_spk_id,
|
353 |
+
spk_mix_dict = spk_mix_dict,
|
354 |
+
aug_shift = formant_shift_key,
|
355 |
+
gt_spec=seg_input_mel,
|
356 |
+
infer=True,
|
357 |
+
infer_speedup=infer_speedup,
|
358 |
+
method=method,
|
359 |
+
k_step=k_step)
|
360 |
+
seg_output = vocoder.infer(seg_mel, seg_f0)
|
361 |
+
seg_output *= mask[:, start_frame * args.data.block_size : (start_frame + seg_units.size(1)) * args.data.block_size]
|
362 |
+
seg_output = seg_output.squeeze().cpu().numpy()
|
363 |
+
|
364 |
+
silent_length = round(start_frame * args.data.block_size) - current_length
|
365 |
+
if silent_length >= 0:
|
366 |
+
result = np.append(result, np.zeros(silent_length))
|
367 |
+
result = np.append(result, seg_output)
|
368 |
+
else:
|
369 |
+
result = cross_fade(result, seg_output, current_length + silent_length)
|
370 |
+
current_length = current_length + silent_length + len(seg_output)
|
371 |
+
sf.write(cmd.output, result, args.data.sampling_rate)
|
372 |
+
|
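cross_fade, defined identically here and in main.py, stitches consecutive rendered segments by linearly blending the region where they overlap (everything from index idx to the end of the first array). A tiny numeric sketch:

import numpy as np
from main_diff import cross_fade   # same helper as in main.py

a = np.ones(6)            # tail of the audio rendered so far
b = np.zeros(5)           # next segment, overlapping the last 2 samples of a
out = cross_fade(a, b, idx=4)
print(out.shape)          # (9,) -> a[:4], a 1->0 ramp over the 2 overlapping samples, then b[2:]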
DDSP-SVC/nsf_hifigan/env.py
ADDED
@@ -0,0 +1,15 @@
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
|
4 |
+
|
5 |
+
class AttrDict(dict):
|
6 |
+
def __init__(self, *args, **kwargs):
|
7 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
8 |
+
self.__dict__ = self
|
9 |
+
|
10 |
+
|
11 |
+
def build_env(config, config_name, path):
|
12 |
+
t_path = os.path.join(path, config_name)
|
13 |
+
if config != t_path:
|
14 |
+
os.makedirs(path, exist_ok=True)
|
15 |
+
shutil.copyfile(config, os.path.join(path, config_name))
|
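AttrDict simply aliases the instance __dict__ to the dict itself, so the HiFi-GAN config loaded from config.json can be read either as keys or as attributes (the values below are illustrative):

from nsf_hifigan.env import AttrDict

h = AttrDict({'sampling_rate': 44100, 'num_mels': 128})
print(h.sampling_rate, h['num_mels'])     # 44100 128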
DDSP-SVC/nsf_hifigan/models.py
ADDED
@@ -0,0 +1,430 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from .env import AttrDict
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
9 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
10 |
+
from .utils import init_weights, get_padding
|
11 |
+
|
12 |
+
LRELU_SLOPE = 0.1
|
13 |
+
|
14 |
+
|
15 |
+
def load_model(model_path, device='cuda'):
|
16 |
+
config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
|
17 |
+
with open(config_file) as f:
|
18 |
+
data = f.read()
|
19 |
+
|
20 |
+
json_config = json.loads(data)
|
21 |
+
h = AttrDict(json_config)
|
22 |
+
|
23 |
+
generator = Generator(h).to(device)
|
24 |
+
|
25 |
+
cp_dict = torch.load(model_path, map_location=device)
|
26 |
+
generator.load_state_dict(cp_dict['generator'])
|
27 |
+
generator.eval()
|
28 |
+
generator.remove_weight_norm()
|
29 |
+
del cp_dict
|
30 |
+
return generator, h
|
31 |
+
|
32 |
+
|
33 |
+
class ResBlock1(torch.nn.Module):
|
34 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
35 |
+
super(ResBlock1, self).__init__()
|
36 |
+
self.h = h
|
37 |
+
self.convs1 = nn.ModuleList([
|
38 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
39 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
40 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
41 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
42 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
43 |
+
padding=get_padding(kernel_size, dilation[2])))
|
44 |
+
])
|
45 |
+
self.convs1.apply(init_weights)
|
46 |
+
|
47 |
+
self.convs2 = nn.ModuleList([
|
48 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
49 |
+
padding=get_padding(kernel_size, 1))),
|
50 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
51 |
+
padding=get_padding(kernel_size, 1))),
|
52 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
53 |
+
padding=get_padding(kernel_size, 1)))
|
54 |
+
])
|
55 |
+
self.convs2.apply(init_weights)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
59 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
60 |
+
xt = c1(xt)
|
61 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
62 |
+
xt = c2(xt)
|
63 |
+
x = xt + x
|
64 |
+
return x
|
65 |
+
|
66 |
+
def remove_weight_norm(self):
|
67 |
+
for l in self.convs1:
|
68 |
+
remove_weight_norm(l)
|
69 |
+
for l in self.convs2:
|
70 |
+
remove_weight_norm(l)
|
71 |
+
|
72 |
+
|
73 |
+
class ResBlock2(torch.nn.Module):
|
74 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
75 |
+
super(ResBlock2, self).__init__()
|
76 |
+
self.h = h
|
77 |
+
self.convs = nn.ModuleList([
|
78 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
79 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
80 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
81 |
+
padding=get_padding(kernel_size, dilation[1])))
|
82 |
+
])
|
83 |
+
self.convs.apply(init_weights)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
for c in self.convs:
|
87 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
88 |
+
xt = c(xt)
|
89 |
+
x = xt + x
|
90 |
+
return x
|
91 |
+
|
92 |
+
def remove_weight_norm(self):
|
93 |
+
for l in self.convs:
|
94 |
+
remove_weight_norm(l)
|
95 |
+
|
96 |
+
|
97 |
+
class SineGen(torch.nn.Module):
|
98 |
+
""" Definition of sine generator
|
99 |
+
SineGen(samp_rate, harmonic_num = 0,
|
100 |
+
sine_amp = 0.1, noise_std = 0.003,
|
101 |
+
voiced_threshold = 0,
|
102 |
+
flag_for_pulse=False)
|
103 |
+
samp_rate: sampling rate in Hz
|
104 |
+
harmonic_num: number of harmonic overtones (default 0)
|
105 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
106 |
+
noise_std: std of Gaussian noise (default 0.003)
|
107 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
108 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
109 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
110 |
+
segment is always sin(np.pi) or cos(0)
|
111 |
+
"""
|
112 |
+
|
113 |
+
def __init__(self, samp_rate, harmonic_num=0,
|
114 |
+
sine_amp=0.1, noise_std=0.003,
|
115 |
+
voiced_threshold=0):
|
116 |
+
super(SineGen, self).__init__()
|
117 |
+
self.sine_amp = sine_amp
|
118 |
+
self.noise_std = noise_std
|
119 |
+
self.harmonic_num = harmonic_num
|
120 |
+
self.dim = self.harmonic_num + 1
|
121 |
+
self.sampling_rate = samp_rate
|
122 |
+
self.voiced_threshold = voiced_threshold
|
123 |
+
|
124 |
+
def _f02uv(self, f0):
|
125 |
+
# generate uv signal
|
126 |
+
uv = torch.ones_like(f0)
|
127 |
+
uv = uv * (f0 > self.voiced_threshold)
|
128 |
+
return uv
|
129 |
+
|
130 |
+
@torch.no_grad()
|
131 |
+
def forward(self, f0, upp):
|
132 |
+
""" sine_tensor, uv = forward(f0)
|
133 |
+
input F0: tensor(batchsize=1, length, dim=1)
|
134 |
+
f0 for unvoiced steps should be 0
|
135 |
+
output sine_tensor: tensor(batchsize=1, length, dim)
|
136 |
+
output uv: tensor(batchsize=1, length, 1)
|
137 |
+
"""
|
138 |
+
f0 = f0.unsqueeze(-1)
|
139 |
+
fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1)))
|
140 |
+
rad_values = (fn / self.sampling_rate) % 1  # % 1 here means the n_har products cannot be optimized afterwards
|
141 |
+
rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device)
|
142 |
+
rand_ini[:, 0] = 0
|
143 |
+
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
144 |
+
is_half = rad_values.dtype is not torch.float32
|
145 |
+
tmp_over_one = torch.cumsum(rad_values.double(), 1)  # % 1  # taking % 1 here would mean the later cumsum can no longer be optimized
|
146 |
+
if is_half:
|
147 |
+
tmp_over_one = tmp_over_one.half()
|
148 |
+
else:
|
149 |
+
tmp_over_one = tmp_over_one.float()
|
150 |
+
tmp_over_one *= upp
|
151 |
+
tmp_over_one = F.interpolate(
|
152 |
+
tmp_over_one.transpose(2, 1), scale_factor=upp,
|
153 |
+
mode='linear', align_corners=True
|
154 |
+
).transpose(2, 1)
|
155 |
+
rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
|
156 |
+
tmp_over_one %= 1
|
157 |
+
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
|
158 |
+
cumsum_shift = torch.zeros_like(rad_values)
|
159 |
+
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
160 |
+
rad_values = rad_values.double()
|
161 |
+
cumsum_shift = cumsum_shift.double()
|
162 |
+
sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
|
163 |
+
if is_half:
|
164 |
+
sine_waves = sine_waves.half()
|
165 |
+
else:
|
166 |
+
sine_waves = sine_waves.float()
|
167 |
+
sine_waves = sine_waves * self.sine_amp
|
168 |
+
return sine_waves
|
169 |
+
|
170 |
+
|
171 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
172 |
+
""" SourceModule for hn-nsf
|
173 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
174 |
+
add_noise_std=0.003, voiced_threshod=0)
|
175 |
+
sampling_rate: sampling_rate in Hz
|
176 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
177 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
178 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
179 |
+
note that amplitude of noise in unvoiced is decided
|
180 |
+
by sine_amp
|
181 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
182 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
183 |
+
F0_sampled (batchsize, length, 1)
|
184 |
+
Sine_source (batchsize, length, 1)
|
185 |
+
noise_source (batchsize, length 1)
|
186 |
+
uv (batchsize, length, 1)
|
187 |
+
"""
|
188 |
+
|
189 |
+
def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
|
190 |
+
add_noise_std=0.003, voiced_threshod=0):
|
191 |
+
super(SourceModuleHnNSF, self).__init__()
|
192 |
+
|
193 |
+
self.sine_amp = sine_amp
|
194 |
+
self.noise_std = add_noise_std
|
195 |
+
|
196 |
+
# to produce sine waveforms
|
197 |
+
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
198 |
+
sine_amp, add_noise_std, voiced_threshod)
|
199 |
+
|
200 |
+
# to merge source harmonics into a single excitation
|
201 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
202 |
+
self.l_tanh = torch.nn.Tanh()
|
203 |
+
|
204 |
+
def forward(self, x, upp):
|
205 |
+
sine_wavs = self.l_sin_gen(x, upp)
|
206 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
207 |
+
return sine_merge
|
208 |
+
|
209 |
+
|
210 |
+
class Generator(torch.nn.Module):
|
211 |
+
def __init__(self, h):
|
212 |
+
super(Generator, self).__init__()
|
213 |
+
self.h = h
|
214 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
215 |
+
self.num_upsamples = len(h.upsample_rates)
|
216 |
+
self.m_source = SourceModuleHnNSF(
|
217 |
+
sampling_rate=h.sampling_rate,
|
218 |
+
harmonic_num=8
|
219 |
+
)
|
220 |
+
self.noise_convs = nn.ModuleList()
|
221 |
+
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
222 |
+
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
223 |
+
|
224 |
+
self.ups = nn.ModuleList()
|
225 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
226 |
+
c_cur = h.upsample_initial_channel // (2 ** (i + 1))
|
227 |
+
self.ups.append(weight_norm(
|
228 |
+
ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
|
229 |
+
k, u, padding=(k - u) // 2)))
|
230 |
+
if i + 1 < len(h.upsample_rates): #
|
231 |
+
stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
|
232 |
+
self.noise_convs.append(Conv1d(
|
233 |
+
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
|
234 |
+
else:
|
235 |
+
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
236 |
+
self.resblocks = nn.ModuleList()
|
237 |
+
ch = h.upsample_initial_channel
|
238 |
+
for i in range(len(self.ups)):
|
239 |
+
ch //= 2
|
240 |
+
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
241 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
242 |
+
|
243 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
244 |
+
self.ups.apply(init_weights)
|
245 |
+
self.conv_post.apply(init_weights)
|
246 |
+
self.upp = int(np.prod(h.upsample_rates))
|
247 |
+
|
248 |
+
def forward(self, x, f0):
|
249 |
+
har_source = self.m_source(f0, self.upp).transpose(1, 2)
|
250 |
+
x = self.conv_pre(x)
|
251 |
+
for i in range(self.num_upsamples):
|
252 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
253 |
+
x = self.ups[i](x)
|
254 |
+
x_source = self.noise_convs[i](har_source)
|
255 |
+
x = x + x_source
|
256 |
+
xs = None
|
257 |
+
for j in range(self.num_kernels):
|
258 |
+
if xs is None:
|
259 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
260 |
+
else:
|
261 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
262 |
+
x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, periods=None):
        super(MultiPeriodDiscriminator, self).__init__()
        self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
        self.discriminators = nn.ModuleList()
        for period in self.periods:
            self.discriminators.append(DiscriminatorP(period))

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
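For orientation, the sketch below shows how the discriminators and loss helpers defined above are typically wired together in a HiFi-GAN style adversarial step. It is not part of the repository files: the random tensors stand in for real and generated waveforms of shape (batch, 1, samples), and it assumes the classes above are in scope (e.g. imported from `nsf_hifigan.models`).

```python
import torch

# stand-ins for a real clip and a generator output (batch=2, 1 channel, 8192 samples)
real = torch.randn(2, 1, 8192)
fake = torch.randn(2, 1, 8192)

mpd = MultiPeriodDiscriminator()
msd = MultiScaleDiscriminator()

# discriminator update: score real vs. detached fake
y_d_r, y_d_g, _, _ = mpd(real, fake.detach())
d_loss_mpd, _, _ = discriminator_loss(y_d_r, y_d_g)

# generator update: adversarial loss plus feature-matching loss
y_d_r, y_d_g, fmap_r, fmap_g = mpd(real, fake)
g_loss_mpd, _ = generator_loss(y_d_g)
fm_loss_mpd = feature_loss(fmap_r, fmap_g)
```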
DDSP-SVC/nsf_hifigan/nvSTFT.py
ADDED
@@ -0,0 +1,129 @@
import math
import os
os.environ["LRU_CACHE_CAPACITY"] = "3"
import random
import torch
import torch.utils.data
import numpy as np
import librosa
from librosa.util import normalize
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read
import soundfile as sf
import torch.nn.functional as F

def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
    sampling_rate = None
    try:
        data, sampling_rate = sf.read(full_path, always_2d=True)  # than soundfile.
    except Exception as ex:
        print(f"'{full_path}' failed to load.\nException:")
        print(ex)
        if return_empty_on_exception:
            return [], sampling_rate or target_sr or 48000
        else:
            raise Exception(ex)

    if len(data.shape) > 1:
        data = data[:, 0]
        assert len(data) > 2  # check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)

    if np.issubdtype(data.dtype, np.integer):  # if audio data is type int
        max_mag = -np.iinfo(data.dtype).min  # maximum magnitude = min possible value of intXX
    else:  # if audio data is type fp32
        max_mag = max(np.amax(data), -np.amin(data))
        max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0)  # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32

    data = torch.FloatTensor(data.astype(np.float32))/max_mag

    if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:  # resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
        return [], sampling_rate or target_sr or 48000
    if target_sr is not None and sampling_rate != target_sr:
        data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
        sampling_rate = target_sr

    return data, sampling_rate

def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)

def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C

def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)

def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C

class STFT():
    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
        self.target_sr = sr

        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        self.mel_basis = {}
        self.hann_window = {}

    def get_mel(self, y, keyshift=0, speed=1, center=False):
        sampling_rate = self.target_sr
        n_mels = self.n_mels
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        fmin = self.fmin
        fmax = self.fmax
        clip_val = self.clip_val

        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(n_fft * factor))
        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))

        if torch.min(y) < -1.:
            print('min value is ', torch.min(y))
        if torch.max(y) > 1.:
            print('max value is ', torch.max(y))

        mel_basis_key = str(fmax)+'_'+str(y.device)
        if mel_basis_key not in self.mel_basis:
            mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
            self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)

        keyshift_key = str(keyshift)+'_'+str(y.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)

        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
        if pad_right < y.size(-1):
            mode = 'reflect'
        else:
            mode = 'constant'
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
        y = y.squeeze(1)

        spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key],
                          center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
        if keyshift != 0:
            size = n_fft // 2 + 1
            resize = spec.size(1)
            if resize < size:
                spec = F.pad(spec, (0, 0, 0, size-resize))
            spec = spec[:, :size, :] * win_size / win_size_new
        spec = torch.matmul(self.mel_basis[mel_basis_key], spec)
        spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
        return spec

    def __call__(self, audiopath):
        audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
        spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
        return spect

stft = STFT()
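A minimal usage sketch for the STFT helper above, not part of the repository files. 'input.wav' is a placeholder path, and the 44.1 kHz / 128-mel / 512-hop parameters are an assumption meant to mirror a typical community NSF-HiFiGAN configuration; in the project itself these values come from the vocoder checkpoint's config.

```python
# load a wav as a 1-D FloatTensor resampled to the target rate
wav, sr = load_wav_to_torch('input.wav', target_sr=44100)

# parameters assumed here; normally read from the vocoder config
stft_44k = STFT(sr=44100, n_mels=128, n_fft=2048, win_size=2048,
                hop_length=512, fmin=40, fmax=16000)

mel = stft_44k.get_mel(wav.unsqueeze(0))                      # (1, n_mels, frames)
mel_up2 = stft_44k.get_mel(wav.unsqueeze(0), keyshift=2)      # same audio shifted +2 semitones
print(mel.shape, mel_up2.shape)
```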
DDSP-SVC/nsf_hifigan/utils.py
ADDED
@@ -0,0 +1,68 @@
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt


def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def del_old_checkpoints(cp_dir, prefix, n_models=2):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)  # get checkpoint paths
    cp_list = sorted(cp_list)  # sort by iter
    if len(cp_list) > n_models:  # if more than n_models models are found
        for cp in cp_list[:-n_models]:  # delete the oldest models other than the latest n_models
            open(cp, 'w').close()  # empty file contents
            os.unlink(cp)  # delete file (move to trash when using Colab)


def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
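A short sketch (not part of the repository files) of how the checkpoint helpers above can be used together: scan a directory for the newest checkpoint matching a prefix and load it. The directory name and the 'g_' prefix are illustrative assumptions; the actual values depend on the vocoder release in use.

```python
import torch

cp_path = scan_checkpoint('pretrain/nsf_hifigan', 'g_')  # returns None if nothing matches
if cp_path is not None:
    state = load_checkpoint(cp_path, torch.device('cpu'))
    print(list(state.keys()))
```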
DDSP-SVC/preprocess.py
ADDED
@@ -0,0 +1,197 @@
import os
import numpy as np
import random
import librosa
import torch
import pyworld as pw
import parselmouth
import argparse
import shutil
from logger import utils
from tqdm import tqdm
from ddsp.vocoder import F0_Extractor, Volume_Extractor, Units_Encoder
from diffusion.vocoder import Vocoder
from logger.utils import traverse_dir
import concurrent.futures

def parse_args(args=None, namespace=None):
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="path to the config file")
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default=None,
        required=False,
        help="cpu or cuda, auto if not set")
    return parser.parse_args(args=args, namespace=namespace)

def preprocess(path, f0_extractor, volume_extractor, mel_extractor, units_encoder, sample_rate, hop_size, device='cuda', use_pitch_aug=False):

    path_srcdir = os.path.join(path, 'audio')
    path_unitsdir = os.path.join(path, 'units')
    path_f0dir = os.path.join(path, 'f0')
    path_volumedir = os.path.join(path, 'volume')
    path_augvoldir = os.path.join(path, 'aug_vol')
    path_meldir = os.path.join(path, 'mel')
    path_augmeldir = os.path.join(path, 'aug_mel')
    path_skipdir = os.path.join(path, 'skip')

    # list files
    filelist = traverse_dir(
        path_srcdir,
        extension='wav',
        is_pure=True,
        is_sort=True,
        is_ext=True)

    # pitch augmentation dictionary
    pitch_aug_dict = {}

    # run
    def process(file):
        ext = file.split('.')[-1]
        binfile = file[:-(len(ext)+1)]+'.npy'
        path_srcfile = os.path.join(path_srcdir, file)
        path_unitsfile = os.path.join(path_unitsdir, binfile)
        path_f0file = os.path.join(path_f0dir, binfile)
        path_volumefile = os.path.join(path_volumedir, binfile)
        path_augvolfile = os.path.join(path_augvoldir, binfile)
        path_melfile = os.path.join(path_meldir, binfile)
        path_augmelfile = os.path.join(path_augmeldir, binfile)
        path_skipfile = os.path.join(path_skipdir, file)

        # load audio
        audio, _ = librosa.load(path_srcfile, sr=sample_rate)
        if len(audio.shape) > 1:
            audio = librosa.to_mono(audio)
        audio_t = torch.from_numpy(audio).float().to(device)
        audio_t = audio_t.unsqueeze(0)

        # extract volume
        volume = volume_extractor.extract(audio)

        # extract mel and volume augmentation
        if mel_extractor is not None:
            mel_t = mel_extractor.extract(audio_t, sample_rate)
            mel = mel_t.squeeze().to('cpu').numpy()

            max_amp = float(torch.max(torch.abs(audio_t))) + 1e-5
            max_shift = min(1, np.log10(1/max_amp))
            log10_vol_shift = random.uniform(-1, max_shift)
            if use_pitch_aug:
                keyshift = random.uniform(-5, 5)
            else:
                keyshift = 0

            aug_mel_t = mel_extractor.extract(audio_t * (10 ** log10_vol_shift), sample_rate, keyshift=keyshift)
            aug_mel = aug_mel_t.squeeze().to('cpu').numpy()
            aug_vol = volume_extractor.extract(audio * (10 ** log10_vol_shift))

        # units encode
        units_t = units_encoder.encode(audio_t, sample_rate, hop_size)
        units = units_t.squeeze().to('cpu').numpy()

        # extract f0
        f0 = f0_extractor.extract(audio, uv_interp=False)

        uv = f0 == 0
        if len(f0[~uv]) > 0:
            # interpolate the unvoiced f0
            f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])

            # save npy
            os.makedirs(os.path.dirname(path_unitsfile), exist_ok=True)
            np.save(path_unitsfile, units)
            os.makedirs(os.path.dirname(path_f0file), exist_ok=True)
            np.save(path_f0file, f0)
            os.makedirs(os.path.dirname(path_volumefile), exist_ok=True)
            np.save(path_volumefile, volume)
            if mel_extractor is not None:
                pitch_aug_dict[file[:-(len(ext)+1)]] = keyshift
                os.makedirs(os.path.dirname(path_melfile), exist_ok=True)
                np.save(path_melfile, mel)
                os.makedirs(os.path.dirname(path_augmelfile), exist_ok=True)
                np.save(path_augmelfile, aug_mel)
                os.makedirs(os.path.dirname(path_augvolfile), exist_ok=True)
                np.save(path_augvolfile, aug_vol)
        else:
            print('\n[Error] F0 extraction failed: ' + path_srcfile)
            os.makedirs(os.path.dirname(path_skipfile), exist_ok=True)
            shutil.move(path_srcfile, os.path.dirname(path_skipfile))
            print('This file has been moved to ' + path_skipfile)
    print('Preprocess the audio clips in :', path_srcdir)

    # single process
    for file in tqdm(filelist, total=len(filelist)):
        process(file)

    if mel_extractor is not None:
        path_pitchaugdict = os.path.join(path, 'pitch_aug_dict.npy')
        np.save(path_pitchaugdict, pitch_aug_dict)
    # multi-process (have bugs)
    '''
    with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
        list(tqdm(executor.map(process, filelist), total=len(filelist)))
    '''

if __name__ == '__main__':
    # parse commands
    cmd = parse_args()

    device = cmd.device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load config
    args = utils.load_config(cmd.config)
    sample_rate = args.data.sampling_rate
    hop_size = args.data.block_size

    # initialize f0 extractor
    f0_extractor = F0_Extractor(
        args.data.f0_extractor,
        args.data.sampling_rate,
        args.data.block_size,
        args.data.f0_min,
        args.data.f0_max)

    # initialize volume extractor
    volume_extractor = Volume_Extractor(args.data.block_size)

    # initialize mel extractor
    mel_extractor = None
    use_pitch_aug = False
    if args.model.type == 'Diffusion':
        mel_extractor = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)
        if mel_extractor.vocoder_sample_rate != sample_rate or mel_extractor.vocoder_hop_size != hop_size:
            mel_extractor = None
            print('Unmatch vocoder parameters, mel extraction is ignored!')
        elif args.model.use_pitch_aug:
            use_pitch_aug = True

    # initialize units encoder
    if args.data.encoder == 'cnhubertsoftfish':
        cnhubertsoft_gate = args.data.cnhubertsoft_gate
    else:
        cnhubertsoft_gate = 10
    units_encoder = Units_Encoder(
        args.data.encoder,
        args.data.encoder_ckpt,
        args.data.encoder_sample_rate,
        args.data.encoder_hop_size,
        cnhubertsoft_gate=cnhubertsoft_gate,
        device=device)

    # preprocess training set
    preprocess(args.data.train_path, f0_extractor, volume_extractor, mel_extractor, units_encoder, sample_rate, hop_size, device=device, use_pitch_aug=use_pitch_aug)

    # preprocess validation set
    preprocess(args.data.valid_path, f0_extractor, volume_extractor, mel_extractor, units_encoder, sample_rate, hop_size, device=device, use_pitch_aug=False)
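A brief sketch (not part of the repository files) of what preprocessing leaves behind for each clip. After running `python preprocess.py -c configs/combsub.yaml`, every training wav gets parallel per-frame feature arrays; `aaa` below is a placeholder clip name and the exact array shapes are an assumption based on how the features are saved above.

```python
import numpy as np

units  = np.load('data/train/units/aaa.npy')   # roughly (frames, encoder_out_channels)
f0     = np.load('data/train/f0/aaa.npy')      # roughly (frames,), interpolated f0 in Hz
volume = np.load('data/train/volume/aaa.npy')  # roughly (frames,), frame-level volume
print(units.shape, f0.shape, volume.shape)
```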
DDSP-SVC/pretrain/hubert/.gitignore
ADDED
@@ -0,0 +1,2 @@
*
!.gitignore
DDSP-SVC/pretrain/nsf_hifigan/.gitignore
ADDED
@@ -0,0 +1,2 @@
*
!.gitignore
DDSP-SVC/requirements.txt
ADDED
@@ -0,0 +1,25 @@
einops
fairseq
flask
flask_cors
gin
gin_config
librosa
local_attention
matplotlib
numpy
praat-parselmouth
pyworld
PyYAML
resampy
scikit_learn
scipy
SoundFile
tensorboard
torchcrepe
tqdm
transformers
wave
pysimplegui
sounddevice
gradio
DDSP-SVC/slicer.py
ADDED
@@ -0,0 +1,146 @@
import librosa
import torch
import torchaudio


class Slicer:
    def __init__(self,
                 sr: int,
                 threshold: float = -40.,
                 min_length: int = 5000,
                 min_interval: int = 300,
                 hop_size: int = 20,
                 max_sil_kept: int = 5000):
        if not min_length >= min_interval >= hop_size:
            raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
        if not max_sil_kept >= hop_size:
            raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
        min_interval = sr * min_interval / 1000
        self.threshold = 10 ** (threshold / 20.)
        self.hop_size = round(sr * hop_size / 1000)
        self.win_size = min(round(min_interval), 4 * self.hop_size)
        self.min_length = round(sr * min_length / 1000 / self.hop_size)
        self.min_interval = round(min_interval / self.hop_size)
        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

    def _apply_slice(self, waveform, begin, end):
        if len(waveform.shape) > 1:
            return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
        else:
            return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]

    # @timeit
    def slice(self, waveform):
        if len(waveform.shape) > 1:
            samples = librosa.to_mono(waveform)
        else:
            samples = waveform
        if samples.shape[0] <= self.min_length:
            return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
        rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
        sil_tags = []
        silence_start = None
        clip_start = 0
        for i, rms in enumerate(rms_list):
            # Keep looping while frame is silent.
            if rms < self.threshold:
                # Record start of silent frames.
                if silence_start is None:
                    silence_start = i
                continue
            # Keep looping while frame is not silent and silence start has not been recorded.
            if silence_start is None:
                continue
            # Clear recorded silence start if interval is not enough or clip is too short
            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue
            # Need slicing. Record the range of silent frames to be removed.
            if i - silence_start <= self.max_sil_kept:
                pos = rms_list[silence_start: i + 1].argmin() + silence_start
                if silence_start == 0:
                    sil_tags.append((0, pos))
                else:
                    sil_tags.append((pos, pos))
                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
                pos += i - self.max_sil_kept
                pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
                pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
                pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                else:
                    sil_tags.append((pos_l, pos_r))
                clip_start = pos_r
            silence_start = None
        # Deal with trailing silence.
        total_frames = rms_list.shape[0]
        if silence_start is not None and total_frames - silence_start >= self.min_interval:
            silence_end = min(total_frames, silence_start + self.max_sil_kept)
            pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
            sil_tags.append((pos, total_frames + 1))
        # Apply and return slices.
        if len(sil_tags) == 0:
            return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
        else:
            chunks = []
            # The first silent segment does not start at the beginning, so prepend the leading voiced segment
            if sil_tags[0][0]:
                chunks.append(
                    {"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
            for i in range(0, len(sil_tags)):
                # Mark the voiced segments (skip the first one)
                if i:
                    chunks.append({"slice": False,
                                   "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
                # Mark all silent segments
                chunks.append({"slice": True,
                               "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
            # The last silent segment does not reach the end, so append the trailing segment
            if sil_tags[-1][1] * self.hop_size < len(waveform):
                chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
            chunk_dict = {}
            for i in range(len(chunks)):
                chunk_dict[str(i)] = chunks[i]
            return chunk_dict


def cut(audio_path, db_thresh=-30, min_len=5000, flask_mode=False, flask_sr=None):
    if not flask_mode:
        audio, sr = librosa.load(audio_path, sr=None)
    else:
        audio = audio_path
        sr = flask_sr
    slicer = Slicer(
        sr=sr,
        threshold=db_thresh,
        min_length=min_len
    )
    chunks = slicer.slice(audio)
    return chunks


def chunks2audio(audio_path, chunks):
    chunks = dict(chunks)
    audio, sr = torchaudio.load(audio_path)
    if len(audio.shape) == 2 and audio.shape[1] >= 2:
        audio = torch.mean(audio, dim=0).unsqueeze(0)
    audio = audio.cpu().numpy()[0]
    result = []
    for k, v in chunks.items():
        tag = v["split_time"].split(",")
        if tag[0] != tag[1]:
            result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
    return result, sr
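A minimal usage sketch (not part of the repository files) for the slicing helpers above: detect silent regions in a wav and iterate over the resulting voiced/silent chunks. 'input.wav' is a placeholder path.

```python
chunks = cut('input.wav', db_thresh=-30, min_len=5000)
segments, sr = chunks2audio('input.wav', chunks)
for is_silence, seg in segments:
    # chunks flagged with "slice": True are the silent spans to be dropped or kept short
    print('silent' if is_silence else 'voiced', len(seg) / sr, 'seconds')
```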
DDSP-SVC/solver.py
ADDED
@@ -0,0 +1,151 @@
import os
import time
import numpy as np
import torch

from logger.saver import Saver
from logger import utils

def test(args, model, loss_func, loader_test, saver):
    print(' [*] testing...')
    model.eval()

    # losses
    test_loss = 0.
    test_loss_rss = 0.
    test_loss_uv = 0.

    # initialization
    num_batches = len(loader_test)
    rtf_all = []

    # run
    with torch.no_grad():
        for bidx, data in enumerate(loader_test):
            fn = data['name'][0]
            print('--------')
            print('{}/{} - {}'.format(bidx, num_batches, fn))

            # unpack data
            for k in data.keys():
                if k != 'name':
                    data[k] = data[k].to(args.device)
            print('>>', data['name'][0])

            # forward
            st_time = time.time()
            signal, _, (s_h, s_n) = model(data['units'], data['f0'], data['volume'], data['spk_id'])
            ed_time = time.time()

            # crop
            min_len = np.min([signal.shape[1], data['audio'].shape[1]])
            signal = signal[:, :min_len]
            data['audio'] = data['audio'][:, :min_len]

            # RTF
            run_time = ed_time - st_time
            song_time = data['audio'].shape[-1] / args.data.sampling_rate
            rtf = run_time / song_time
            print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
            rtf_all.append(rtf)

            # loss
            loss = loss_func(signal, data['audio'])

            test_loss += loss.item()

            # log
            saver.log_audio({fn+'/gt.wav': data['audio'], fn+'/pred.wav': signal})

    # report
    test_loss /= num_batches

    # check
    print(' [test_loss] test_loss:', test_loss)
    print(' Real Time Factor', np.mean(rtf_all))
    return test_loss


def train(args, initial_global_step, model, optimizer, loss_func, loader_train, loader_test):
    # saver
    saver = Saver(args, initial_global_step=initial_global_step)

    # model size
    params_count = utils.get_network_paras_amount({'model': model})
    saver.log_info('--- model size ---')
    saver.log_info(params_count)

    # run
    best_loss = np.inf
    num_batches = len(loader_train)
    model.train()
    saver.log_info('======= start training =======')
    for epoch in range(args.train.epochs):
        for batch_idx, data in enumerate(loader_train):
            saver.global_step_increment()
            optimizer.zero_grad()

            # unpack data
            for k in data.keys():
                if k != 'name':
                    data[k] = data[k].to(args.device)

            # forward
            signal, _, (s_h, s_n) = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'], infer=False)

            # loss
            loss = loss_func(signal, data['audio'])

            # handle nan loss
            if torch.isnan(loss):
                raise ValueError(' [x] nan loss ')
            else:
                # backpropagate
                loss.backward()
                optimizer.step()

            # log loss
            if saver.global_step % args.train.interval_log == 0:
                saver.log_info(
                    'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | loss: {:.3f} | time: {} | step: {}'.format(
                        epoch,
                        batch_idx,
                        num_batches,
                        args.env.expdir,
                        args.train.interval_log/saver.get_interval_time(),
                        loss.item(),
                        saver.get_total_time(),
                        saver.global_step
                    )
                )

                saver.log_value({
                    'train/loss': loss.item()
                })

            # validation
            if saver.global_step % args.train.interval_val == 0:
                # save latest
                saver.save_model(model, optimizer, postfix=f'{saver.global_step}')

                # run testing set
                test_loss = test(args, model, loss_func, loader_test, saver)

                saver.log_info(
                    ' --- <validation> --- \nloss: {:.3f}. '.format(
                        test_loss,
                    )
                )

                saver.log_value({
                    'validation/loss': test_loss
                })
                model.train()

                # save best model
                if test_loss < best_loss:
                    saver.log_info(' [V] best model updated.')
                    saver.save_model(model, optimizer, postfix='best')
                    best_loss = test_loss
DDSP-SVC/train.py
ADDED
@@ -0,0 +1,93 @@
import os
import argparse
import torch

from logger import utils
from data_loaders import get_data_loaders
from solver import train
from ddsp.vocoder import Sins, CombSub, CombSubFast
from ddsp.loss import RSSLoss


def parse_args(args=None, namespace=None):
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="path to the config file")
    return parser.parse_args(args=args, namespace=namespace)


if __name__ == '__main__':
    # parse commands
    cmd = parse_args()

    # load config
    args = utils.load_config(cmd.config)
    print(' > config:', cmd.config)
    print(' > exp:', args.env.expdir)

    # load model
    model = None

    if args.model.type == 'Sins':
        model = Sins(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_harmonics=args.model.n_harmonics,
            n_mag_allpass=args.model.n_mag_allpass,
            n_mag_noise=args.model.n_mag_noise,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)

    elif args.model.type == 'CombSub':
        model = CombSub(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_mag_allpass=args.model.n_mag_allpass,
            n_mag_harmonic=args.model.n_mag_harmonic,
            n_mag_noise=args.model.n_mag_noise,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)

    elif args.model.type == 'CombSubFast':
        model = CombSubFast(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)

    else:
        raise ValueError(f" [x] Unknown Model: {args.model.type}")

    # load parameters
    optimizer = torch.optim.AdamW(model.parameters())
    initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device)
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.train.lr
        param_group['weight_decay'] = args.train.weight_decay

    # loss
    loss_func = RSSLoss(args.loss.fft_min, args.loss.fft_max, args.loss.n_scale, device=args.device)

    # device
    if args.device == 'cuda':
        torch.cuda.set_device(args.env.gpu_id)
    model.to(args.device)

    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(args.device)

    loss_func.to(args.device)

    # data
    loader_train, loader_valid = get_data_loaders(args, whole_audio=False)

    # run
    train(args, initial_global_step, model, optimizer, loss_func, loader_train, loader_valid)
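As a quick orientation aid (not part of the repository files), the sketch below lists, as a plain Python dict, the config fields that train.py and solver.py above actually read from the YAML file loaded via `utils.load_config`; the authoritative schema and the concrete values live in `configs/combsub.yaml` and friends.

```python
# Fields referenced by train.py / solver.py; grouped by config section.
cfg_fields_read = {
    'device': [],  # top-level: 'cuda' or 'cpu'
    'env':    ['expdir', 'gpu_id'],
    'data':   ['sampling_rate', 'block_size', 'encoder_out_channels'],
    'model':  ['type', 'n_spk', 'n_harmonics', 'n_mag_allpass', 'n_mag_harmonic', 'n_mag_noise'],
    'train':  ['lr', 'weight_decay', 'epochs', 'interval_log', 'interval_val'],
    'loss':   ['fft_min', 'fft_max', 'n_scale'],
}
print(cfg_fields_read)
```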
DDSP-SVC/train_diff.py
ADDED
@@ -0,0 +1,70 @@
import os
import argparse
import torch
from torch.optim import lr_scheduler
from logger import utils
from diffusion.data_loaders import get_data_loaders
from diffusion.solver import train
from diffusion.unit2mel import Unit2Mel
from diffusion.vocoder import Vocoder


def parse_args(args=None, namespace=None):
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="path to the config file")
    return parser.parse_args(args=args, namespace=namespace)


if __name__ == '__main__':
    # parse commands
    cmd = parse_args()

    # load config
    args = utils.load_config(cmd.config)
    print(' > config:', cmd.config)
    print(' > exp:', args.env.expdir)

    # load vocoder
    vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)

    # load model
    model = Unit2Mel(
        args.data.encoder_out_channels,
        args.model.n_spk,
        args.model.use_pitch_aug,
        vocoder.dimension,
        args.model.n_layers,
        args.model.n_chans,
        args.model.n_hidden)

    # load parameters
    optimizer = torch.optim.AdamW(model.parameters())
    initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device)
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.train.lr
        param_group['weight_decay'] = args.train.weight_decay
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma)

    # device
    if args.device == 'cuda':
        torch.cuda.set_device(args.env.gpu_id)
    model.to(args.device)

    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(args.device)

    # data
    loader_train, loader_valid = get_data_loaders(args, whole_audio=False)

    # run
    train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid)
DDSP-SVC/webui.py
ADDED
@@ -0,0 +1,267 @@
import gradio as gr
import os,subprocess,yaml

class WebUI:
    def __init__(self) -> None:
        self.info=Info()
        self.opt_cfg_pth='configs/opt.yaml'
        self.main_ui()

    def main_ui(self):
        with gr.Blocks() as ui:
            gr.Markdown('## 一个便于训练和推理的DDSP-webui,每一步的说明在下面,可以自己展开看。')
            with gr.Tab("训练/Training"):
                gr.Markdown(self.info.general)
                with gr.Accordion('预训练模型说明',open=False):
                    gr.Markdown(self.info.pretrain_model)
                with gr.Accordion('数据集说明',open=False):
                    gr.Markdown(self.info.dataset)

                gr.Markdown('## 生成配置文件')
                with gr.Row():
                    self.batch_size=gr.Slider(minimum=2,maximum=60,value=24,label='Batch_size',interactive=True)
                    self.learning_rate=gr.Number(value=0.0005,label='学习率',info='和batch_size关系大概是0.0001:6')
                    self.f0_extractor=gr.Dropdown(['parselmouth', 'dio', 'harvest', 'crepe'],type='value',value='crepe',label='f0提取器种类',interactive=True)
                    self.sampling_rate=gr.Number(value=44100,label='采样率',info='数据集音频的采样率',interactive=True)
                    self.n_spk=gr.Number(value=1,label='说话人数量',interactive=True)
                with gr.Row():
                    self.device=gr.Dropdown(['cuda','cpu'],value='cuda',label='使用设备',interactive=True)
                    self.num_workers=gr.Number(value=2,label='读取数据进程数',info='如果你的设备性能很好,可以设置为0',interactive=True)
                    self.cache_all_data=gr.Checkbox(value=True,label='启用缓存',info='将数据全部加载以加速训练',interactive=True)
                    self.cache_device=gr.Dropdown(['cuda','cpu'],value='cuda',type='value',label='缓存设备',info='如果你的显存比较大,设置为cuda',interactive=True)
                    self.bt_create_config=gr.Button(value='创建配置文件')

                gr.Markdown('## 预处理')
                with gr.Accordion('预训练说明',open=False):
                    gr.Markdown(self.info.preprocess)
                with gr.Row():
                    self.bt_open_data_folder=gr.Button('打开数据集文件夹')
                    self.bt_preprocess=gr.Button('开始预处理')
                gr.Markdown('## 训练')
                with gr.Accordion('训练说明',open=False):
                    gr.Markdown(self.info.train)
                with gr.Row():
                    self.bt_train=gr.Button('开始训练')
                    self.bt_visual=gr.Button('启动可视化')
                    gr.Markdown('启动可视化后[点击打开](http://127.0.0.1:6006)')

            with gr.Tab('推理/Inference'):
                with gr.Accordion('推理说明',open=False):
                    gr.Markdown(self.info.infer)
                with gr.Row():
                    self.input_wav=gr.Audio(type='filepath',label='选择待转换音频')
                    self.choose_model=gr.Textbox('exp/model_chino.pt',label='模型路径')
                with gr.Row():
                    self.keychange=gr.Slider(-24,24,value=0,step=1,label='变调')
                    self.id=gr.Number(value=1,label='说话人id')
                    self.enhancer_adaptive_key=gr.Number(value=0,label='增强器音区偏移',info='调高可以防止超高音(比如大于G5) 破音,但是低音效果可能会下降')
                with gr.Row():
                    self.bt_infer=gr.Button(value='开始转换')
                    self.output_wav=gr.Audio(type='filepath',label='输出音频')

            self.bt_create_config.click(fn=self.create_config)
            self.bt_open_data_folder.click(fn=self.openfolder)
            self.bt_preprocess.click(fn=self.preprocess)
            self.bt_train.click(fn=self.training)
            self.bt_visual.click(fn=self.visualize)
            self.bt_infer.click(fn=self.inference,inputs=[self.input_wav,self.choose_model,self.keychange,self.id,self.enhancer_adaptive_key],outputs=self.output_wav)
        ui.launch(inbrowser=True,server_port=7858)

    def openfolder(self):
        try:
            os.startfile('data')
        except:
            print('Fail to open folder!')

    def create_config(self):
        with open('configs/combsub.yaml','r',encoding='utf-8') as f:
            cfg=yaml.load(f.read(),Loader=yaml.FullLoader)
        cfg['data']['f0_extractor']=str(self.f0_extractor.value)
        cfg['data']['sampling_rate']=int(self.sampling_rate.value)
        cfg['train']['batch_size']=int(self.batch_size.value)
        cfg['device']=str(self.device.value)
        cfg['train']['num_workers']=int(self.num_workers.value)
        cfg['train']['cache_all_data']=str(self.cache_all_data.value)
        cfg['train']['cache_device']=str(self.cache_device.value)
        cfg['train']['lr']=int(self.learning_rate.value)
        print('配置文件信息:'+str(cfg))
        with open(self.opt_cfg_pth,'w',encoding='utf-8') as f:
            yaml.dump(cfg,f)
        print('成功生成配置文件')

    def preprocess(self):
        preprocessing_process=subprocess.Popen('python -u preprocess.py -c '+self.opt_cfg_pth,stdout=subprocess.PIPE)
        while preprocessing_process.poll() is None:
            output=preprocessing_process.stdout.readline().decode('utf-8')
            print(output)
        print('预处理完成')

    def training(self):
        train_process=subprocess.Popen('python -u train.py -c '+self.opt_cfg_pth,stdout=subprocess.PIPE)
        while train_process.poll() is None:
            output=train_process.stdout.readline().decode('utf-8')
            print(output)

    def visualize(self):
        tb_process=subprocess.Popen('tensorboard --logdir=exp --port=6006',stdout=subprocess.PIPE)
        while tb_process.poll() is None:
            output=tb_process.stdout.readline().decode('utf-8')
            print(output)

    def inference(self,input_wav:str,model:str,keychange,id,enhancer_adaptive_key):
        print(input_wav,model)
        output_wav='samples/'+ input_wav.replace('\\','/').split('/')[-1]
        cmd='python -u main.py -i '+input_wav+' -m '+model+' -o '+output_wav+' -k '+str(int(keychange))+' -id '+str(int(id))+' -e true -eak '+str(int(enhancer_adaptive_key))
        infer_process=subprocess.Popen(cmd,stdout=subprocess.PIPE)
        while infer_process.poll() is None:
            output=infer_process.stdout.readline().decode('utf-8')
            print(output)
        print('推理完成')
        return output_wav


class Info:
    def __init__(self) -> None:
        self.general='''
### 不看也没事,大致就是
1.设置好配置之后点击创建配置文件
2.点击‘打开数据集文件夹’,把数据集选个十个塞到data\\train\\val目录下面,剩下的音频全塞到data\\train\\audio下面
3.点击‘开始预处理’等待执行完毕
4.点击‘开始训练’和‘启动可视化’然后点击右侧链接
'''
        self.pretrain_model="""
- **(必要操作)** 下载预训练 [**HubertSoft**](https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt) 编码器并将其放到 `pretrain/hubert` 文件夹。
- 更新:现在支持 ContentVec 编码器了。你可以下载预训练 [ContentVec](https://ibm.ent.box.com/s/z1wgl1stco8ffooyatzdwsqn2psd9lrr) 编码器替代 HubertSoft 编码器并修改配置文件以使用它。
- 从 [DiffSinger 社区声码器项目](https://openvpi.github.io/vocoders) 下载基于预训练声码器的增强器,并解压至 `pretrain/` 文件夹。
- 注意:你应当下载名称中带有`nsf_hifigan`的压缩文件,而非`nsf_hifigan_finetune`。
"""
        self.dataset="""
### 1. 配置训练数据集和验证数据集

#### 1.1 手动配置:

将所有的训练集数据 (.wav 格式音频切片) 放到 `data/train/audio`。

将所有的验证集数据 (.wav 格式音频切片) 放到 `data/val/audio`。

#### 1.2 程序随机选择(**多人物时不可使用**):

运行`python draw.py`,程序将帮助你挑选验证集数据(可以调整 `draw.py` 中的参数修改抽取文件的数量等参数)。

#### 1.3 文件夹结构目录展示:
- 单人物目录结构:

```
data
├─ train
│ ├─ audio
│ │ ├─ aaa.wav
│ │ ├─ bbb.wav
│ │ └─ ....wav
│ └─ val
│ │ ├─ eee.wav
│ │ ├─ fff.wav
│ │ └─ ....wav
```
- 多人物目录结构:

```
data
├─ train
│ ├─ audio
│ │ ├─ 1
│ │ │ ├─ aaa.wav
│ │ │ ├─ bbb.wav
│ │ │ └─ ....wav
│ │ ├─ 2
│ │ │ ├─ ccc.wav
│ │ │ ├─ ddd.wav
│ │ │ └─ ....wav
│ │ └─ ...
│ └─ val
│ │ ├─ 1
│ │ │ ├─ eee.wav
│ │ │ ├─ fff.wav
│ │ │ └─ ....wav
│ │ ├─ 2
│ │ │ ├─ ggg.wav
│ │ │ ├─ hhh.wav
│ │ │ └─ ....wav
│ │ └─ ...
```
"""
        self.preprocess='''
您可以在预处理之前修改配置文件 `config/<model_name>.yaml`,默认配置适用于GTX-1660 显卡训练 44.1khz 高采样率合成器。
### 备注:
1. 请保持所有音频切片的采样率与 yaml 配置文件中的采样率一致!如果不一致,程序可以跑,但训练过程中的重新采样将非常缓慢。(可选:使用Adobe Audition™的响度匹配功能可以一次性完成重采样修改声道和响度匹配。)

2. 训练数据集的音频切片总数建议为约 1000 个,另外长音频切成小段可以加快训练速度,但所有音频切片的时长不应少于 2 秒。如果音频切片太多,则需要较大的内存,配置文件中将 `cache_all_data` 选项设置为 false 可以解决此问题。

3. 验证集的音频切片总数建议为 10 个左右,不要放太多,不然验证过程会很慢。

4. 如果您的数据集质量不是很高,请在配置文件中将 'f0_extractor' 设为 'crepe'。crepe 算法的抗噪性最好,但代价是会极大增加数据预处理所需的时间。

5. 配置文件中的 ‘n_spk’ 参数将控制是否训练多说话人模型。如果您要训练**多说话人**模型,为了对说话人进行编号,所有音频文件夹的名称必须是**不大于 ‘n_spk’ 的正整数**。
'''
        self.train='''
## 训练

### 1. 不使用预训练数据进行训练:
```bash
# 以训练 combsub 模型为例
python train.py -c configs/combsub.yaml
```
1. 训练其他模型方法类似。

2. 可以随时中止训练,然后运行相同的命令来继续训练。

3. 微调 (finetune):在中止训练后,重新预处理新数据集或更改训练参数(batchsize、lr等),然后运行相同的命令。
### 2. 使用预训练数据(底模)进行训练:
1. **使用预训练模型请修改配置文件中的 'n_spk' 参数为 '2' ,同时配置`train`目录结构为多人物目录,不论你是否训练多说话人模型。**
2. **如果你要训练一个更多说话人的模型,就不要下载预训练模型了。**
3. 欢迎PR训练的多人底模 (请使用授权同意开源的数据集进行训练)。
4. 从[**这里**](https://github.com/yxlllc/DDSP-SVC/releases/download/2.0/opencpop+kiritan.zip)下载预训练模型,并将`model_300000.pt`解压到`.\exp\combsub-test\`中
5. 同不使用预训练数据进行训练一样,启动训练。
'''
        self.visualize='''
## 可视化
```bash
# 使用tensorboard检查训练状态
tensorboard --logdir=exp
```
第一次验证 (validation) 后,在 TensorBoard 中可以看到合成后的测试音频。

注:TensorBoard 中的测试音频是 DDSP-SVC 模型的原始输出,并未通过增强器增强。
'''
        self.infer='''
## 非实时变声
1. (**推荐**)使用预训练声码器增强 DDSP 的输出结果:
```bash
# 默认 enhancer_adaptive_key = 0 正常音域范围内将有更高的音质
# 设置 enhancer_adaptive_key > 0 可将增强器适配于更高的音域
python main.py -i <input.wav> -m <model_file.pt> -o <output.wav> -k <keychange (semitones)> -id <speaker_id> -e true -eak <enhancer_adaptive_key (semitones)>
```
2. DDSP 的原始输出结果:
```bash
# 速度快,但音质相对较低(像您在tensorboard里听到的那样)
python main.py -i <input.wav> -m <model_file.pt> -o <output.wav> -k <keychange (semitones)> -e false -id <speaker_id>
```
3. 关于 f0 提取器、响应阈值及其他参数,参见:

```bash
python main.py -h
```
4. 如果要使用混合说话人(捏音色)功能,增添 “-mix” 选项来设计音色,下面是个例子:
```bash
# 将1号说话人和2号说话人的音色按照0.5:0.5的比例混合
python main.py -i <input.wav> -m <model_file.pt> -o <output.wav> -k <keychange (semitones)> -mix "{1:0.5, 2:0.5}" -e true -eak 0
```
'''


webui=WebUI()