# ddsp-demo / DDSP-SVC / main_diff.py
# Hosting-page metadata: pdjdev · "add ddsp-svc" · 85a7d2c · 12.9 kB
import os
import torch
import librosa
import argparse
import numpy as np
import soundfile as sf
import pyworld as pw
import parselmouth
import hashlib
from ast import literal_eval
from slicer import Slicer
from ddsp.vocoder import load_model, F0_Extractor, Volume_Extractor, Units_Encoder
from ddsp.core import upsample
from diffusion.unit2mel import load_model_vocoder
from tqdm import tqdm
def check_args(ddsp_args, diff_args):
    """Check that a DDSP config and a diffusion config can be used together.

    Compares ``data.sampling_rate``, ``data.block_size`` and ``data.encoder``
    of the two configs. Every mismatching field is reported on stdout (not
    only the first one), so all incompatibilities are visible in one run.

    Args:
        ddsp_args: Config of the DDSP model; must expose a ``data`` attribute
            carrying the three fields above.
        diff_args: Config of the diffusion model, same layout.

    Returns:
        True when all three fields are equal, False otherwise.
    """
    compatible = True
    for field in ("sampling_rate", "block_size", "encoder"):
        if getattr(ddsp_args.data, field) != getattr(diff_args.data, field):
            print(f"Unmatch data.{field}!")
            compatible = False
    return compatible
def parse_args(args=None, namespace=None):
    """Parse the command-line arguments of the inference script.

    Every option is declared with ``type=str``. argparse applies ``type`` only
    to values given on the command line and to string defaults, so an option
    left at its default keeps the literal default below (e.g. the int ``0``
    for ``--key``), while a user-supplied value arrives as ``str``. Callers
    therefore convert with ``int()`` / ``float()`` before use.

    Note that the destination of the threshold option is spelled ``threhold``;
    the spelling is kept because callers read ``cmd.threhold``.

    Args:
        args: Argument list to parse; ``None`` means ``sys.argv[1:]``.
        namespace: Optional object to populate instead of a new Namespace.

    Returns:
        The populated ``argparse.Namespace`` (or ``namespace`` if given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-diff",
        "--diff_ckpt",
        type=str,
        required=True,
        help="path to the diffusion model checkpoint",
    )
    parser.add_argument(
        "-ddsp",
        "--ddsp_ckpt",
        type=str,
        required=False,
        # The string "None" (not the None object) marks "no DDSP model".
        default="None",
        help="path to the DDSP model checkpoint (for shallow diffusion)",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default=None,
        required=False,
        help="cpu or cuda, auto if not set",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="path to the input audio file",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="path to the output audio file",
    )
    parser.add_argument(
        "-id",
        "--spk_id",
        type=str,
        required=False,
        default=1,
        help="speaker id (for multi-speaker model) | default: 1",
    )
    parser.add_argument(
        "-mix",
        "--spk_mix_dict",
        type=str,
        required=False,
        # Kept as the string "None" so that a literal-evaluating caller gets None.
        default="None",
        help="mix-speaker dictionary (for multi-speaker model) | default: None",
    )
    parser.add_argument(
        "-k",
        "--key",
        type=str,
        required=False,
        default=0,
        help="key changed (number of semitones) | default: 0",
    )
    parser.add_argument(
        "-f",
        "--formant_shift_key",
        type=str,
        required=False,
        default=0,
        help="formant changed (number of semitones), only for pitch-augmented model | default: 0",
    )
    parser.add_argument(
        "-pe",
        "--pitch_extractor",
        type=str,
        required=False,
        default='crepe',
        help="pitch extractor type: parselmouth, dio, harvest, crepe (default)",
    )
    parser.add_argument(
        "-fmin",
        "--f0_min",
        type=str,
        required=False,
        default=50,
        help="min f0 (Hz) | default: 50",
    )
    parser.add_argument(
        "-fmax",
        "--f0_max",
        type=str,
        required=False,
        default=1100,
        help="max f0 (Hz) | default: 1100",
    )
    parser.add_argument(
        "-th",
        "--threhold",
        type=str,
        required=False,
        default=-60,
        help="response threshold (dB) | default: -60",
    )
    parser.add_argument(
        "-diffid",
        "--diff_spk_id",
        type=str,
        required=False,
        default='auto',
        help="diffusion speaker id (for multi-speaker model) | default: auto",
    )
    parser.add_argument(
        "-speedup",
        "--speedup",
        type=str,
        required=False,
        default='auto',
        help="speed up | default: auto",
    )
    parser.add_argument(
        "-method",
        "--method",
        type=str,
        required=False,
        default='auto',
        help="pndm or dpm-solver | default: auto",
    )
    parser.add_argument(
        "-kstep",
        "--k_step",
        type=str,
        required=False,
        default=None,
        help="shallow diffusion steps | default: None",
    )
    return parser.parse_args(args=args, namespace=namespace)
def split(audio, sample_rate, hop_size, db_thresh = -40, min_len = 5000):
    """Cut ``audio`` into frame-aligned segments using ``Slicer``.

    Each chunk returned by ``Slicer.slice`` carries a ``"split_time"`` string
    of the form ``"start,end"``. Both values are floor-divided by ``hop_size``
    to obtain frame indices (presumably they are sample offsets -- confirm
    against ``Slicer``). Chunks whose bounds are equal, or that collapse to
    zero frames after the division, are dropped.

    Args:
        audio: 1-D sample array to cut.
        sample_rate: Forwarded to ``Slicer`` as ``sr``.
        hop_size: Samples per frame; may be fractional.
        db_thresh: Forwarded to ``Slicer`` as ``threshold``.
        min_len: Forwarded to ``Slicer`` as ``min_length``.

    Returns:
        List of ``(start_frame, samples)`` tuples, where ``samples`` is the
        slice of ``audio`` covering ``start_frame`` up to the chunk's end frame.
    """
    cutter = Slicer(sr=sample_rate, threshold=db_thresh, min_length=min_len)
    pieces = []
    for chunk in dict(cutter.slice(audio)).values():
        bounds = chunk["split_time"].split(",")
        if bounds[0] == bounds[1]:
            continue
        first = int(int(bounds[0]) // hop_size)
        last = int(int(bounds[1]) // hop_size)
        if last <= first:
            continue
        # Re-derive sample offsets from the frame indices so every segment
        # starts and ends on a frame boundary.
        lo = int(first * hop_size)
        hi = int(last * hop_size)
        pieces.append((first, audio[lo:hi]))
    return pieces
def cross_fade(a: np.ndarray, b: np.ndarray, idx: int):
    """Join ``b`` onto ``a`` starting at ``idx``, linearly blending the overlap.

    ``a[:idx]`` is copied unchanged. Over the overlap the output moves
    linearly from ``a`` (weight 1 -> 0) to ``b`` (weight 0 -> 1); the rest of
    ``b`` follows. The result always has length ``idx + len(b)``.

    If ``b`` is shorter than the overlap ``len(a) - idx``, the fade is limited
    to ``len(b)`` samples and the part of ``a`` beyond ``idx + len(b)`` is
    dropped, so the length contract above still holds.

    Args:
        a: 1-D array holding the audio assembled so far.
        b: 1-D array to append.
        idx: Position in ``a`` where ``b`` starts; ``0 <= idx <= len(a)``.

    Returns:
        A new float array of length ``idx + len(b)``.

    Raises:
        ValueError: If ``idx`` lies outside ``[0, len(a)]``.
    """
    if not 0 <= idx <= a.shape[0]:
        raise ValueError(
            f"idx must be within [0, {a.shape[0]}], got {idx}")
    # The blend can never be longer than b itself.
    fade_len = min(a.shape[0] - idx, b.shape[0])
    fade_end = idx + fade_len
    result = np.zeros(idx + b.shape[0])
    result[:idx] = a[:idx]
    # k is the weight of b: 0 at the first overlapping sample, 1 at the last.
    k = np.linspace(0, 1.0, num=fade_len, endpoint=True)
    result[idx:fade_end] = (1 - k) * a[idx:fade_end] + k * b[:fade_len]
    result[fade_end:] = b[fade_len:]
    return result
if __name__ == '__main__':
    # Script entry point: convert one input audio file with the diffusion
    # model (optionally with a DDSP model or the input mel for shallow
    # diffusion) and write the result to cmd.output.

    # parse commands
    cmd = parse_args()

    # pick the device: explicit -d wins, otherwise CUDA when available
    device = cmd.device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load diffusion model
    model, vocoder, args = load_model_vocoder(cmd.diff_ckpt, device=device)

    # load input at its native sample rate
    audio, sample_rate = librosa.load(cmd.input, sr=None)
    if len(audio.shape) > 1:
        audio = librosa.to_mono(audio)
    # block size rescaled from the model's sampling rate to the input's;
    # this is a float and is not necessarily a whole number
    hop_size = args.data.block_size * sample_rate / args.data.sampling_rate

    # get MD5 hash from wav file (used only as a cache key)
    with open(cmd.input, 'rb') as f:
        md5_hash = hashlib.md5(f.read()).hexdigest()
    print("MD5: " + md5_hash)

    # the f0 cache is keyed by extractor, hop size, f0 range and file content
    cache_dir_path = os.path.join(os.path.dirname(__file__), "cache")
    cache_file_path = os.path.join(cache_dir_path, f"{cmd.pitch_extractor}_{hop_size}_{cmd.f0_min}_{cmd.f0_max}_{md5_hash}.npy")

    is_cache_available = os.path.exists(cache_file_path)
    if is_cache_available:
        # f0 cache load
        print('Loading pitch curves for input audio from cache directory...')
        f0 = np.load(cache_file_path, allow_pickle=False)
    else:
        # extract f0
        print('Pitch extractor type: ' + cmd.pitch_extractor)
        pitch_extractor = F0_Extractor(
            cmd.pitch_extractor,
            sample_rate,
            hop_size,
            float(cmd.f0_min),
            float(cmd.f0_max))
        print('Extracting the pitch curve of the input audio...')
        f0 = pitch_extractor.extract(audio, uv_interp = True, device = device)

        # f0 cache save
        os.makedirs(cache_dir_path, exist_ok=True)
        np.save(cache_file_path, f0, allow_pickle=False)

    # add a leading and a trailing axis (assumes f0 is 1-D, one value per
    # frame -- TODO confirm against F0_Extractor)
    f0 = torch.from_numpy(f0).float().to(device).unsqueeze(-1).unsqueeze(0)

    # key change: shift by cmd.key semitones
    f0 = f0 * 2 ** (float(cmd.key) / 12)

    # formant change
    # NOTE(review): the value is parsed as float but stored in a LongTensor,
    # while the vocoder call below receives the float -- confirm whether
    # fractional shifts are meant to be supported here.
    formant_shift_key = torch.LongTensor(np.array([[float(cmd.formant_shift_key)]])).to(device)

    # extract volume
    print('Extracting the volume envelope of the input audio...')
    volume_extractor = Volume_Extractor(hop_size)
    volume = volume_extractor.extract(audio)
    # 1.0 where the volume exceeds the dB threshold converted to linear scale
    mask = (volume > 10 ** (float(cmd.threhold) / 20)).astype('float')
    # pad 4 frames per side with the edge values, then take a running max
    # over 9 frames: this widens every active region by 4 frames per side
    mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1]))
    mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)])
    mask = torch.from_numpy(mask).float().to(device).unsqueeze(-1).unsqueeze(0)
    # presumably expands the per-frame mask to per-sample resolution --
    # verify against ddsp.core.upsample
    mask = upsample(mask, args.data.block_size).squeeze(-1)
    volume = torch.from_numpy(volume).float().to(device).unsqueeze(-1).unsqueeze(0)

    # load units encoder
    if args.data.encoder == 'cnhubertsoftfish':
        cnhubertsoft_gate = args.data.cnhubertsoft_gate
    else:
        cnhubertsoft_gate = 10
    units_encoder = Units_Encoder(
        args.data.encoder,
        args.data.encoder_ckpt,
        args.data.encoder_sample_rate,
        args.data.encoder_hop_size,
        cnhubertsoft_gate=cnhubertsoft_gate,
        device = device)

    # speaker id or mix-speaker dictionary
    # (the default is the string "None", which literal_eval turns into None)
    spk_mix_dict = literal_eval(cmd.spk_mix_dict)
    spk_id = torch.LongTensor(np.array([[int(cmd.spk_id)]])).to(device)
    if cmd.diff_spk_id == 'auto':
        diff_spk_id = spk_id
    else:
        diff_spk_id = torch.LongTensor(np.array([[int(cmd.diff_spk_id)]])).to(device)
    if spk_mix_dict is not None:
        print('Mix-speaker mode')
    else:
        print('DDSP Speaker ID: '+ str(int(cmd.spk_id)))
        print('Diffusion Speaker ID: '+ str(cmd.diff_spk_id))

    # speed up / sampling method: 'auto' falls back to the checkpoint config
    if cmd.speedup == 'auto':
        infer_speedup = args.infer.speedup
    else:
        infer_speedup = int(cmd.speedup)
    if cmd.method == 'auto':
        method = args.infer.method
    else:
        method = cmd.method
    if infer_speedup > 1:
        print('Sampling method: '+ method)
        print('Speed up: '+ str(infer_speedup))
    else:
        print('Sampling method: DDPM')

    # shallow diffusion setup: with a k_step, the starting mel comes either
    # from a DDSP model (if given and compatible) or from the input audio
    ddsp = None
    input_mel = None
    k_step = None
    if cmd.k_step is not None:
        k_step = int(cmd.k_step)
        print('Shallow diffusion step: ' + str(k_step))
        if cmd.ddsp_ckpt != "None":
            # load ddsp model
            ddsp, ddsp_args = load_model(cmd.ddsp_ckpt, device=device)
            if not check_args(ddsp_args, args):
                print("Cannot use this DDSP model for shallow diffusion, gaussian diffusion will be used!")
                ddsp = None
        else:
            print('DDSP model is not identified!')
            print('Extracting the mel spectrum of the input audio for shallow diffusion...')
            audio_t = torch.from_numpy(audio).float().unsqueeze(0).to(device)
            input_mel = vocoder.extract(audio_t, sample_rate)
            # repeat the last frame once so the mel is one frame longer
            input_mel = torch.cat((input_mel, input_mel[:,-1:,:]), 1)
    else:
        print('Shallow diffusion step is not identified, gaussian diffusion will be used!')

    # forward and save the output
    result = np.zeros(0)
    current_length = 0
    segments = split(audio, sample_rate, hop_size)
    print('Cut the input audio into ' + str(len(segments)) + ' slices')
    with torch.no_grad():
        for segment in tqdm(segments):
            start_frame = segment[0]
            seg_input = torch.from_numpy(segment[1]).float().unsqueeze(0).to(device)
            seg_units = units_encoder.encode(seg_input, sample_rate, hop_size)

            # take as many f0 / volume frames as the encoder produced units
            seg_f0 = f0[:, start_frame : start_frame + seg_units.size(1), :]
            seg_volume = volume[:, start_frame : start_frame + seg_units.size(1), :]
            if ddsp is not None:
                # the DDSP model gets f0 shifted by the opposite of the
                # formant shift; the same shift is passed to the vocoder
                # as keyshift when extracting the mel
                seg_ddsp_f0 = 2 ** (-float(cmd.formant_shift_key) / 12) * seg_f0
                seg_ddsp_output, _ , (_, _) = ddsp(seg_units, seg_ddsp_f0, seg_volume, spk_id = spk_id, spk_mix_dict = spk_mix_dict)
                seg_input_mel = vocoder.extract(seg_ddsp_output, args.data.sampling_rate, keyshift=float(cmd.formant_shift_key))
            elif input_mel is not None:
                # identity test: "!= None" on a tensor relies on comparison
                # fallbacks instead of stating the intent
                seg_input_mel = input_mel[:, start_frame : start_frame + seg_units.size(1), :]
            else:
                seg_input_mel = None

            seg_mel = model(
                seg_units,
                seg_f0,
                seg_volume,
                spk_id = diff_spk_id,
                spk_mix_dict = spk_mix_dict,
                aug_shift = formant_shift_key,
                gt_spec=seg_input_mel,
                infer=True,
                infer_speedup=infer_speedup,
                method=method,
                k_step=k_step)
            seg_output = vocoder.infer(seg_mel, seg_f0)
            # zero the output wherever the input was below the threshold
            seg_output *= mask[:, start_frame * args.data.block_size : (start_frame + seg_units.size(1)) * args.data.block_size]
            seg_output = seg_output.squeeze().cpu().numpy()

            # gap (>= 0) or overlap (< 0), in output samples, between the
            # audio written so far and the start of this segment
            silent_length = round(start_frame * args.data.block_size) - current_length
            if silent_length >= 0:
                result = np.append(result, np.zeros(silent_length))
                result = np.append(result, seg_output)
            else:
                result = cross_fade(result, seg_output, current_length + silent_length)
            current_length = current_length + silent_length + len(seg_output)
        sf.write(cmd.output, result, args.data.sampling_rate)