|
import argparse |
|
import pathlib |
|
import yaml |
|
import torch |
|
import torchaudio |
|
from torch.utils.data import DataLoader |
|
import numpy as np |
|
import random |
|
import librosa |
|
from dataset import Dataset |
|
import pickle |
|
from lightning_module import ( |
|
SSLStepLightningModule, |
|
SSLDualLightningModule, |
|
) |
|
from utils import plot_and_save_mels |
|
import os |
|
import tqdm |
|
|
|
|
|
class AETDataset(Dataset):
    """Paired source/target waveform dataset for the audio-effect-transfer demo.

    Pre-processed pickle files of a source corpus and a target corpus are read
    into memory, both sides are truncated to the same number of segments, and
    each item returns the normalized waveforms plus the mel-spectrogram of the
    source waveform.

    Args:
        filetxt: Name of the file-list text file; it is looked up inside both
            the source and the target preprocessed directories.
        src_config: Source corpus configuration (``general`` / ``preprocess``).
        tar_config: Target corpus configuration (``general`` / ``preprocess``).

    Raises:
        AssertionError: If the two configurations disagree on any of the
            spectral analysis settings.
    """

    def __init__(self, filetxt, src_config, tar_config):
        # NOTE(review): ``config`` and ``spec_module`` are presumably consumed
        # by helpers inherited from ``Dataset`` (``normalize_waveform``,
        # ``calc_spectrogram``) -- confirm against the base class.
        self.config = src_config

        self.preprocessed_dir_src = pathlib.Path(
            src_config["general"]["preprocessed_path"]
        )
        self.preprocessed_dir_tar = pathlib.Path(
            tar_config["general"]["preprocessed_path"]
        )
        for item in [
            "sampling_rate",
            "fft_length",
            "frame_length",
            "frame_shift",
            "fmin",
            "fmax",
            "n_mels",
        ]:
            # Raised explicitly (instead of ``assert``) so the check is not
            # stripped when Python runs with -O; the exception type is kept.
            if not (src_config["preprocess"][item] == tar_config["preprocess"][item]):
                raise AssertionError(
                    "source and target configs disagree on preprocess/{}".format(item)
                )

        self.spec_module = torchaudio.transforms.MelSpectrogram(
            sample_rate=src_config["preprocess"]["sampling_rate"],
            n_fft=src_config["preprocess"]["fft_length"],
            win_length=src_config["preprocess"]["frame_length"],
            hop_length=src_config["preprocess"]["frame_shift"],
            f_min=src_config["preprocess"]["fmin"],
            f_max=src_config["preprocess"]["fmax"],
            n_mels=src_config["preprocess"]["n_mels"],
            power=1,
            center=True,
            norm="slaney",
            mel_scale="slaney",
        )

        self.filelist_src = self._read_filelist(self.preprocessed_dir_src / filetxt)
        self.filelist_tar = self._read_filelist(self.preprocessed_dir_tar / filetxt)

        self.d_out = {
            "src": self._load_segments(
                self.preprocessed_dir_src, self.filelist_src, src_config
            ),
            "tar": self._load_segments(
                self.preprocessed_dir_tar, self.filelist_tar, tar_config
            ),
        }

        # Both sides are cut to the same number of "wavs" segments so that a
        # single index addresses one source and one target segment.
        min_len = min(len(self.d_out["src"]["wavs"]), len(self.d_out["tar"]["wavs"]))
        for spk in ["src", "tar"]:
            for item in ["wavs", "wavsaux"]:
                # An item that was never found stays an empty list and becomes
                # an empty array, which __getitem__ skips via ``.size``.
                self.d_out[spk][item] = np.asarray(self.d_out[spk][item][:min_len])

    @staticmethod
    def _read_filelist(list_path):
        """Return the paths listed one per line in ``list_path``."""
        with open(list_path, "r") as fr:
            return [pathlib.Path(path.strip("\n")) for path in fr]

    @staticmethod
    def _load_segments(preprocessed_dir, filelist, config):
        """Collect the "wavs" / "wavsaux" entries of every listed pickle file."""
        segments = {"wavs": [], "wavsaux": []}
        for wav_path in filelist:
            if config["general"]["corpus_type"] == "single":
                basename = str(wav_path.stem)
            else:
                basename = str(wav_path.parent.name) + "-" + str(wav_path.stem)
            # NOTE: pickle must only be used on trusted, locally generated
            # pre-processing output.
            with open(preprocessed_dir / "{}.pickle".format(basename), "rb") as fr:
                d_preprocessed = pickle.load(fr)
            for item in ["wavs", "wavsaux"]:
                try:
                    segments[item].extend(d_preprocessed[item])
                except (KeyError, TypeError):
                    # Best effort: the entry is absent or not iterable (for
                    # example no auxiliary audio for this corpus); skip it.
                    continue
        return segments

    def __len__(self):
        """Return the number of paired segments."""
        return len(self.d_out["src"]["wavs"])

    def __getitem__(self, idx):
        """Return the normalized waveforms of segment ``idx`` and the source mel.

        Keys are ``"<item>_<spk>"`` for every non-empty item (``wavs`` and
        ``wavsaux`` for ``src`` and ``tar``) plus ``"melspecs_src"``.
        """
        d_batch = {}

        for spk in ["src", "tar"]:
            for item in ["wavs", "wavsaux"]:
                if self.d_out[spk][item].size > 0:
                    key = "{}_{}".format(item, spk)
                    waveform = torch.from_numpy(self.d_out[spk][item][idx])
                    d_batch[key] = self.normalize_waveform(waveform, db=-3)

        d_batch["melspecs_src"] = self.calc_spectrogram(d_batch["wavs_src"])
        return d_batch
|
|
|
|
|
class AETModule(torch.nn.Module):
    """Transfer the channel (audio-effect) features of a source onto a target.

    src: Dataset from which we extract the channel features
    tar: Dataset to which the src channel features are added

    Args:
        args: Parsed command-line arguments; ``args.stage`` selects the
            Lightning module class ("ssl-step" or "ssl-dual").
        chmatch_config: Demo configuration; provides the source checkpoint path.
        src_config: Configuration of the source model.
        tar_config: Configuration of the target corpus (accepted but unused here).

    Raises:
        NotImplementedError: If ``args.stage`` is not a supported stage.
    """

    def __init__(self, args, chmatch_config, src_config, tar_config):
        super().__init__()
        if args.stage == "ssl-step":
            LModule = SSLStepLightningModule
        elif args.stage == "ssl-dual":
            LModule = SSLDualLightningModule
        else:
            raise NotImplementedError(
                "unsupported stage: {!r} (expected 'ssl-step' or 'ssl-dual')".format(
                    args.stage
                )
            )

        # ``load_from_checkpoint`` is a classmethod: call it on the class rather
        # than on a freshly built (and immediately discarded) instance. Recent
        # Lightning releases reject the instance-level call.
        src_model = LModule.load_from_checkpoint(
            checkpoint_path=chmatch_config["general"]["source"]["ckpt_path"],
            config=src_config,
        )
        self.src_config = src_config

        self.encoder_src = src_model.encoder
        if src_config["general"]["use_gst"]:
            self.gst_src = src_model.gst
        else:
            self.channelfeats_src = src_model.channelfeats
        self.channel_src = src_model.channel

    def forward(self, melspecs_src, wavsaux_tar):
        """Apply the channel features estimated from the source to the target.

        Args:
            melspecs_src: Source mel-spectrogram batch; its last two axes are
                swapped before it is handed to the feature extractor.
            wavsaux_tar: Target waveform batch passed to the channel module.

        Returns:
            The output of the source channel module for ``wavsaux_tar``
            conditioned on the extracted source channel features.
        """
        if self.src_config["general"]["use_gst"]:
            chfeats_src = self.gst_src(melspecs_src.transpose(1, 2))
        else:
            # The encoder receives an extra axis at position 1 and returns a
            # pair whose second element feeds the channel-feature module.
            _, enc_hidden_src = self.encoder_src(
                melspecs_src.unsqueeze(1).transpose(2, 3)
            )
            chfeats_src = self.channelfeats_src(enc_hidden_src)
        wavschmatch_tar = self.channel_src(wavsaux_tar, chfeats_src)
        return wavschmatch_tar
|
|
|
|
|
def get_arg():
    """Parse the command-line options of the demo script.

    Returns:
        argparse.Namespace with ``stage`` (str), ``config_path``
        (pathlib.Path), ``exist_src_aux`` (bool flag) and ``run_name`` (str).
    """
    option_table = (
        ("--stage", {"required": True, "type": str}),
        ("--config_path", {"required": True, "type": pathlib.Path}),
        ("--exist_src_aux", {"action": "store_true"}),
        ("--run_name", {"required": True, "type": str}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
def main(args, chmatch_config, device):
    """Run the audio-effect-transfer demo over the test split.

    For every test item the source channel features are applied to the target
    auxiliary waveform; the resulting waveforms and their mel plots are written
    below ``<output_path>/<run_name>/test_wavs`` and ``.../test_mels``.

    Args:
        args: Parsed command-line arguments (``stage``, ``run_name``,
            ``exist_src_aux``).
        chmatch_config: Demo configuration with the source/target config paths,
            the source checkpoint path and the output path.
        device: torch device the model and the batches are moved to.
    """
    # Context managers close the config files deterministically.
    with open(chmatch_config["general"]["source"]["config_path"], "r") as fr:
        src_config = yaml.load(fr, Loader=yaml.FullLoader)
    with open(chmatch_config["general"]["target"]["config_path"], "r") as fr:
        tar_config = yaml.load(fr, Loader=yaml.FullLoader)

    output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
    wav_dir = output_path / "test_wavs"
    mel_dir = output_path / "test_mels"
    # Make the function usable on its own, not only from the script entry point.
    wav_dir.mkdir(parents=True, exist_ok=True)
    mel_dir.mkdir(parents=True, exist_ok=True)

    dataset = AETDataset("test.txt", src_config, tar_config)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    chmatch_module = AETModule(args, chmatch_config, src_config, tar_config).to(device)

    char_vector = None
    if args.exist_src_aux:
        char_vector = calc_deg_charactaristics(chmatch_config)

    for idx, batch in enumerate(tqdm.tqdm(loader)):
        melspecs_src = batch["melspecs_src"].to(device)
        wavsdeg_src = batch["wavs_src"].to(device)
        wavsaux_tar = batch["wavsaux_tar"].to(device)

        # Inference only: no autograd graph is needed (the result is detached
        # anyway), so skip building one.
        with torch.no_grad():
            matched = chmatch_module(melspecs_src, wavsaux_tar)
        wavsmatch_tar = normalize_waveform(matched.cpu().detach(), tar_config)

        # (speaker side, kind, waveform, config) in the order the files are written.
        outputs = [
            ("src", "deg", wavsdeg_src, src_config),
            ("tar", "aux", wavsaux_tar, tar_config),
        ]
        if args.exist_src_aux:
            wavsdegbaseline_tar = calc_deg_baseline(
                batch["wavsaux_tar"], char_vector, tar_config
            )
            wavsdegbaseline_tar = normalize_waveform(wavsdegbaseline_tar, tar_config)
            # NOTE(review): the degraded target waveform is grouped with the
            # baseline outputs and therefore only written with --exist_src_aux;
            # confirm it is not meant to be written unconditionally.
            wavsdeg_tar = batch["wavs_tar"].to(device)
            outputs.append(("tar", "degbaseline", wavsdegbaseline_tar, tar_config))
            outputs.append(("tar", "deg", wavsdeg_tar, tar_config))
        outputs.append(("tar", "match", wavsmatch_tar, tar_config))

        for spk, kind, wav, config in outputs:
            torchaudio.save(
                wav_dir / "{}-{}_wavs{}.wav".format(idx, spk, kind),
                wav.cpu(),
                config["preprocess"]["sampling_rate"],
            )
        for spk, kind, wav, config in outputs:
            # Batch size is 1, so the first entry is the only waveform.
            plot_and_save_mels(
                wav[0, ...].cpu().detach(),
                mel_dir / "{}-{}_mels{}.png".format(idx, spk, kind),
                config,
            )
|
|
|
|
|
def calc_deg_baseline(wav, char_vector, tar_config):
    """Apply a per-frequency gain vector to a waveform in the STFT domain.

    Args:
        wav: Batched waveform tensor; only the first entry is processed.
        char_vector: Array of gains, reshaped to a column and multiplied onto
            every STFT frame (one value per frequency bin).
        tar_config: Target configuration providing the STFT settings under
            ``preprocess`` (``fft_length``, ``frame_shift``, ``frame_length``).

    Returns:
        float32 tensor holding the re-synthesized waveform with a leading
        axis of size 1.
    """
    signal = wav[0, ...].cpu().detach().numpy()

    settings = tar_config["preprocess"]
    n_fft = settings["fft_length"]
    hop = settings["frame_shift"]
    win = settings["frame_length"]

    spectrum = librosa.stft(signal, n_fft=n_fft, hop_length=hop, win_length=win)
    # Column vector broadcasts the gains over all frames.
    shaped = spectrum * char_vector.reshape(-1, 1)
    resynthesized = librosa.istft(shaped, hop_length=hop, win_length=win)

    return torch.from_numpy(resynthesized).to(torch.float32).unsqueeze(0)
|
|
|
|
|
def calc_deg_charactaristics(chmatch_config):
    """Estimate the average spectral degradation of the source corpus.

    For each selected training file the magnitude STFT of the source audio is
    divided by that of the same-named auxiliary audio, averaged over time, and
    the per-file vectors are averaged and normalized to sum to one.

    Args:
        chmatch_config: Demo configuration with the source and target config
            paths under ``general``.

    Returns:
        1-D numpy array with one value per frequency bin.

    Raises:
        NotImplementedError: If the source ``corpus_type`` is not one of
            ``single``, ``multi-seen`` or ``multi-unseen``.
    """
    with open(chmatch_config["general"]["source"]["config_path"], "r") as fr:
        src_config = yaml.load(fr, Loader=yaml.FullLoader)
    with open(chmatch_config["general"]["target"]["config_path"], "r") as fr:
        tar_config = yaml.load(fr, Loader=yaml.FullLoader)

    preprocessed_dir = pathlib.Path(src_config["general"]["preprocessed_path"])
    n_train = src_config["preprocess"]["n_train"]
    SR = src_config["preprocess"]["sampling_rate"]

    os.makedirs(preprocessed_dir, exist_ok=True)

    sourcepath = pathlib.Path(src_config["general"]["source_path"])
    corpus_type = src_config["general"]["corpus_type"]

    if corpus_type == "single":
        fulllist = list(sourcepath.glob("*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
    elif corpus_type == "multi-seen":
        fulllist = list(sourcepath.glob("*/*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
    elif corpus_type == "multi-unseen":
        spk_list = list(set([x.parent for x in sourcepath.glob("*/*.wav")]))
        train_filelist = []
        random.seed(0)
        random.shuffle(spk_list)
        for i, spk in enumerate(spk_list):
            if i < n_train:
                # ``spk`` already is the full speaker directory (it comes from
                # ``sourcepath.glob``). Joining it onto ``sourcepath`` again
                # duplicates the prefix whenever ``sourcepath`` is relative.
                train_filelist.extend(spk.glob("*.wav"))
    else:
        raise NotImplementedError(
            "corpus_type specified in config.yaml should be {single, multi-seen, multi-unseen}"
        )

    stft_kwargs = {
        "n_fft": src_config["preprocess"]["fft_length"],
        "hop_length": src_config["preprocess"]["frame_shift"],
        "win_length": src_config["preprocess"]["frame_length"],
    }
    auxpath = pathlib.Path(src_config["general"]["aux_path"])

    # The leading all-zero column is kept so the result is numerically
    # identical to the previous incremental stacking; its effect on the mean
    # is cancelled by the final normalization.
    columns = [np.zeros((tar_config["preprocess"]["fft_length"] // 2 + 1, 1))]

    for wp in tqdm.tqdm(train_filelist):
        wav, _ = librosa.load(wp, sr=SR)
        spec = np.abs(librosa.stft(wav, **stft_kwargs))

        if corpus_type == "single":
            aux_file = auxpath / wp.name
        else:
            aux_file = auxpath / wp.parent.name / wp.name
        wav_aux, _ = librosa.load(aux_file, sr=SR)
        spec_aux = np.abs(librosa.stft(wav_aux, **stft_kwargs))

        # Compare only the frames both signals have; epsilon avoids 0-division.
        min_len = min(spec.shape[1], spec_aux.shape[1])
        spec_diff = spec[:, :min_len] / (spec_aux[:, :min_len] + 1e-10)
        columns.append(np.mean(spec_diff, axis=1).reshape(-1, 1))

    # Stack once instead of re-allocating the whole array per file.
    specs_all = np.hstack(columns)

    char_vector = np.mean(specs_all, axis=1)
    char_vector = char_vector / (np.sum(char_vector) + 1e-10)
    return char_vector
|
|
|
|
|
def normalize_waveform(wav, tar_config, db=-3):
    """Normalize a waveform with the sox ``norm`` effect.

    Args:
        wav: Waveform tensor handed to ``apply_effects_tensor``.
        tar_config: Configuration providing ``preprocess/sampling_rate``.
        db: Level passed as the argument of the ``norm`` effect.

    Returns:
        The processed waveform tensor (the returned sample rate is dropped).
    """
    sample_rate = tar_config["preprocess"]["sampling_rate"]
    effect_chain = [["norm", f"{db}"]]
    normalized, _unused_rate = torchaudio.sox_effects.apply_effects_tensor(
        wav, sample_rate, effect_chain
    )
    return normalized
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the demo config, prepare the output folders and
    # run the demo on the GPU when one is available.
    args = get_arg()
    # Context manager closes the config file instead of leaking the handle.
    with open(args.config_path, "r") as fr:
        chmatch_config = yaml.load(fr, Loader=yaml.FullLoader)
    output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
    for out_dir in (output_path, output_path / "test_wavs", output_path / "test_mels"):
        os.makedirs(out_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    main(args, chmatch_config, device)
|
|