"""Run channel matching (AET) on a test set and dump waveforms and mel plots.

A trained source-side model extracts channel features from source recordings
and applies them to the target corpus' auxiliary waveforms. Results are written
under ``<output_path>/<run_name>/test_wavs`` and ``.../test_mels``.
"""
import argparse
import os
import pathlib
import pickle
import random

import yaml
import torch
import torchaudio
from torch.utils.data import DataLoader
import numpy as np
import librosa
import tqdm

from dataset import Dataset
from lightning_module import (
    SSLStepLightningModule,
    SSLDualLightningModule,
)
from utils import plot_and_save_mels

# Preprocessing parameters that must be identical in the source and target configs.
_SHARED_PREPROCESS_KEYS = (
    "sampling_rate",
    "fft_length",
    "frame_length",
    "frame_shift",
    "fmin",
    "fmax",
    "n_mels",
)

# Waveform entries read from every preprocessed pickle.
_WAV_ITEMS = ("wavs", "wavsaux")


def _load_yaml(path):
    """Parse the YAML file at ``path`` and return its content.

    The file is opened in a ``with`` block so the handle is always closed.
    NOTE: ``yaml.FullLoader`` is not meant for untrusted input; only load
    configuration files you control.
    """
    with open(path, "r") as fr:
        return yaml.load(fr, Loader=yaml.FullLoader)


class AETDataset(Dataset):
    """Paired source/target dataset assembled from preprocessed pickles.

    For both corpora the file list ``filetxt`` is read from the corpus'
    preprocessed directory, the matching ``*.pickle`` files are loaded, and the
    ``"wavs"`` / ``"wavsaux"`` entries are concatenated. Both corpora are then
    truncated to the same number of items.

    ``normalize_waveform`` and ``calc_spectrogram`` are not defined in this
    class; they are expected to come from the ``Dataset`` base class
    (TODO confirm against ``dataset.py``).
    """

    def __init__(self, filetxt, src_config, tar_config):
        """Load and align the source and target data.

        Args:
            filetxt: Name of the file list located in each preprocessed directory.
            src_config: Parsed config of the source corpus.
            tar_config: Parsed config of the target corpus.

        Raises:
            AssertionError: If a shared preprocessing parameter differs between
                the two configs.
        """
        # Note: the base-class __init__ is not called; attributes are set here.
        self.config = src_config
        self.preprocessed_dir_src = pathlib.Path(
            src_config["general"]["preprocessed_path"]
        )
        self.preprocessed_dir_tar = pathlib.Path(
            tar_config["general"]["preprocessed_path"]
        )

        src_pre = src_config["preprocess"]
        tar_pre = tar_config["preprocess"]
        for item in _SHARED_PREPROCESS_KEYS:
            if src_pre[item] != tar_pre[item]:
                # Raised explicitly (instead of `assert`) so the check survives
                # `python -O` while keeping the original exception type.
                raise AssertionError(
                    "preprocess.{} differs between source ({}) and target ({})".format(
                        item, src_pre[item], tar_pre[item]
                    )
                )

        self.spec_module = torchaudio.transforms.MelSpectrogram(
            sample_rate=src_pre["sampling_rate"],
            n_fft=src_pre["fft_length"],
            win_length=src_pre["frame_length"],
            hop_length=src_pre["frame_shift"],
            f_min=src_pre["fmin"],
            f_max=src_pre["fmax"],
            n_mels=src_pre["n_mels"],
            power=1,
            center=True,
            norm="slaney",
            mel_scale="slaney",
        )

        self.filelist_src = self._read_filelist(self.preprocessed_dir_src / filetxt)
        self.filelist_tar = self._read_filelist(self.preprocessed_dir_tar / filetxt)

        self.d_out = {
            "src": self._load_preprocessed(
                self.filelist_src,
                self.preprocessed_dir_src,
                src_config["general"]["corpus_type"],
            ),
            "tar": self._load_preprocessed(
                self.filelist_tar,
                self.preprocessed_dir_tar,
                tar_config["general"]["corpus_type"],
            ),
        }

        # Keep the same number of items on both sides so indices pair up.
        min_len = min(len(self.d_out["src"]["wavs"]), len(self.d_out["tar"]["wavs"]))
        for spk in ["src", "tar"]:
            for item in _WAV_ITEMS:
                self.d_out[spk][item] = np.asarray(self.d_out[spk][item][:min_len])

    @staticmethod
    def _read_filelist(list_path):
        """Return the paths listed one per line in ``list_path`` (blank lines skipped)."""
        with open(list_path, "r") as fr:
            stripped = (line.strip("\n") for line in fr)
            return [pathlib.Path(line) for line in stripped if line]

    @staticmethod
    def _load_preprocessed(filelist, preprocessed_dir, corpus_type):
        """Concatenate the waveform entries of every pickle named by ``filelist``.

        The pickle name is ``<stem>.pickle`` for a ``"single"`` corpus and
        ``<parent dir name>-<stem>.pickle`` otherwise. Entries that are missing
        from a pickle are skipped, so the corresponding list may stay empty.
        """
        collected = {item: [] for item in _WAV_ITEMS}
        for wav_path in filelist:
            if corpus_type == "single":
                basename = str(wav_path.stem)
            else:
                basename = "{}-{}".format(wav_path.parent.name, wav_path.stem)
            # NOTE: pickle.load can execute arbitrary code; only read pickles
            # produced by this project's own preprocessing.
            with open(preprocessed_dir / "{}.pickle".format(basename), "rb") as fr:
                d_preprocessed = pickle.load(fr)
            for item in _WAV_ITEMS:
                try:
                    collected[item].extend(d_preprocessed[item])
                except (KeyError, TypeError):
                    # Best effort: the entry is absent or not a sequence.
                    continue
        return collected

    def __len__(self):
        """Return the number of paired items."""
        return len(self.d_out["src"]["wavs"])

    def __getitem__(self, idx):
        """Return a dict of normalized waveforms plus the source mel spectrogram.

        Keys are ``"<item>_<spk>"`` for every non-empty waveform array
        (e.g. ``"wavs_src"``, ``"wavsaux_tar"``) and ``"melspecs_src"``.
        """
        d_batch = {}
        for spk in ["src", "tar"]:
            for item in _WAV_ITEMS:
                if self.d_out[spk][item].size > 0:
                    key = "{}_{}".format(item, spk)
                    d_batch[key] = self.normalize_waveform(
                        torch.from_numpy(self.d_out[spk][item][idx]), db=-3
                    )
        d_batch["melspecs_src"] = self.calc_spectrogram(d_batch["wavs_src"])
        return d_batch


class AETModule(torch.nn.Module):
    """
    src: Dataset from which we extract the channel features
    tar: Dataset to which the src channel features are added
    """

    def __init__(self, args, chmatch_config, src_config, tar_config):
        """Load the source model checkpoint and keep the sub-modules needed.

        Args:
            args: Parsed CLI arguments; ``args.stage`` selects the module class.
            chmatch_config: Config holding ``general.source.ckpt_path``.
            src_config: Parsed config of the source corpus.
            tar_config: Parsed config of the target corpus (not used here).

        Raises:
            NotImplementedError: If ``args.stage`` is not a supported stage.
        """
        super().__init__()
        if args.stage == "ssl-step":
            LModule = SSLStepLightningModule
        elif args.stage == "ssl-dual":
            LModule = SSLDualLightningModule
        else:
            raise NotImplementedError(
                "stage should be one of {ssl-step, ssl-dual}"
            )
        # load_from_checkpoint is invoked on the class: building a throwaway
        # instance first is unnecessary, and recent Lightning versions reject
        # calling it on an instance.
        src_model = LModule.load_from_checkpoint(
            checkpoint_path=chmatch_config["general"]["source"]["ckpt_path"],
            config=src_config,
        )
        # NOTE(review): the model is not switched to eval() here — confirm
        # whether train-mode layers are intended during inference.
        self.src_config = src_config
        self.encoder_src = src_model.encoder
        if src_config["general"]["use_gst"]:
            self.gst_src = src_model.gst
        else:
            self.channelfeats_src = src_model.channelfeats
        self.channel_src = src_model.channel

    def forward(self, melspecs_src, wavsaux_tar):
        """Apply channel features extracted from ``melspecs_src`` to ``wavsaux_tar``.

        Channel features come from the GST module when ``general.use_gst`` is
        set, otherwise from the encoder's hidden output passed through the
        channel-feature module.
        """
        if self.src_config["general"]["use_gst"]:
            chfeats_src = self.gst_src(melspecs_src.transpose(1, 2))
        else:
            _, enc_hidden_src = self.encoder_src(
                melspecs_src.unsqueeze(1).transpose(2, 3)
            )
            chfeats_src = self.channelfeats_src(enc_hidden_src)
        wavschmatch_tar = self.channel_src(wavsaux_tar, chfeats_src)
        return wavschmatch_tar


def get_arg():
    """Parse and return the command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--stage", required=True, type=str)
    parser.add_argument("--config_path", required=True, type=pathlib.Path)
    parser.add_argument("--exist_src_aux", action="store_true")
    parser.add_argument("--run_name", required=True, type=str)
    return parser.parse_args()


def main(args, chmatch_config, device):
    """Run channel matching over the test set and save waveforms and mel plots.

    Args:
        args: Parsed CLI arguments (``stage``, ``exist_src_aux``, ``run_name``).
        chmatch_config: Parsed channel-matching config.
        device: Torch device used for the model and the input tensors.
    """
    src_config = _load_yaml(chmatch_config["general"]["source"]["config_path"])
    tar_config = _load_yaml(chmatch_config["general"]["target"]["config_path"])
    output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
    wav_dir = output_path / "test_wavs"
    mel_dir = output_path / "test_mels"
    # Created here as well so main() works when called without the CLI guard.
    wav_dir.mkdir(parents=True, exist_ok=True)
    mel_dir.mkdir(parents=True, exist_ok=True)

    dataset = AETDataset("test.txt", src_config, tar_config)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    chmatch_module = AETModule(args, chmatch_config, src_config, tar_config).to(device)

    if args.exist_src_aux:
        char_vector = calc_deg_charactaristics(chmatch_config)

    for idx, batch in enumerate(tqdm.tqdm(loader)):
        melspecs_src = batch["melspecs_src"].to(device)
        wavsdeg_src = batch["wavs_src"].to(device)
        wavsaux_tar = batch["wavsaux_tar"].to(device)
        if args.exist_src_aux:
            wavsdegbaseline_tar = calc_deg_baseline(
                batch["wavsaux_tar"], char_vector, tar_config
            )
            wavsdegbaseline_tar = normalize_waveform(wavsdegbaseline_tar, tar_config)
            wavsdeg_tar = batch["wavs_tar"].to(device)
        wavsmatch_tar = normalize_waveform(
            chmatch_module(melspecs_src, wavsaux_tar).cpu().detach(), tar_config
        )

        # (speaker tag, kind suffix, waveform, config) in the order they are written.
        outputs = [
            ("src", "deg", wavsdeg_src, src_config),
            ("tar", "aux", wavsaux_tar, tar_config),
        ]
        if args.exist_src_aux:
            outputs.append(("tar", "degbaseline", wavsdegbaseline_tar, tar_config))
            outputs.append(("tar", "deg", wavsdeg_tar, tar_config))
        outputs.append(("tar", "match", wavsmatch_tar, tar_config))

        for spk, kind, wav, config in outputs:
            torchaudio.save(
                wav_dir / "{}-{}_wavs{}.wav".format(idx, spk, kind),
                wav.cpu(),
                config["preprocess"]["sampling_rate"],
            )
        for spk, kind, wav, config in outputs:
            plot_and_save_mels(
                wav[0, ...].cpu().detach(),
                mel_dir / "{}-{}_mels{}.png".format(idx, spk, kind),
                config,
            )


def calc_deg_baseline(wav, char_vector, tar_config):
    """Scale the STFT of ``wav`` per frequency bin and resynthesize it.

    Args:
        wav: Tensor whose first element along dim 0 is the waveform to convert.
        char_vector: Per-frequency-bin gains; reshaped to a column and
            multiplied onto every STFT frame.
        tar_config: Config providing the STFT parameters under ``preprocess``.

    Returns:
        A float32 tensor holding the converted waveform with a leading
        dimension of size 1.
    """
    pre = tar_config["preprocess"]
    wav = wav[0, ...].cpu().detach().numpy()
    spec = librosa.stft(
        wav,
        n_fft=pre["fft_length"],
        hop_length=pre["frame_shift"],
        win_length=pre["frame_length"],
    )
    spec_converted = spec * char_vector.reshape(-1, 1)
    wav_converted = librosa.istft(
        spec_converted,
        hop_length=pre["frame_shift"],
        win_length=pre["frame_length"],
    )
    wav_converted = torch.from_numpy(wav_converted).to(torch.float32).unsqueeze(0)
    return wav_converted


def calc_deg_charactaristics(chmatch_config):
    """Estimate a normalized per-bin magnitude ratio between source and aux audio.

    For each training file the magnitude STFT of the source recording is
    divided by that of the file with the same name under ``general.aux_path``,
    averaged over frames, then averaged over files and divided by its sum.

    Args:
        chmatch_config: Config holding the source and target config paths.

    Returns:
        A 1-D array with ``fft_length // 2 + 1`` entries.

    Raises:
        NotImplementedError: If ``general.corpus_type`` is not one of
            ``single``, ``multi-seen`` or ``multi-unseen``.
    """
    src_config = _load_yaml(chmatch_config["general"]["source"]["config_path"])
    tar_config = _load_yaml(chmatch_config["general"]["target"]["config_path"])

    # configs
    preprocessed_dir = pathlib.Path(src_config["general"]["preprocessed_path"])
    n_train = src_config["preprocess"]["n_train"]
    SR = src_config["preprocess"]["sampling_rate"]
    corpus_type = src_config["general"]["corpus_type"]

    os.makedirs(preprocessed_dir, exist_ok=True)

    sourcepath = pathlib.Path(src_config["general"]["source_path"])

    if corpus_type == "single":
        fulllist = list(sourcepath.glob("*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
    elif corpus_type == "multi-seen":
        fulllist = list(sourcepath.glob("*/*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
    elif corpus_type == "multi-unseen":
        spk_list = list(set([x.parent for x in sourcepath.glob("*/*.wav")]))
        train_filelist = []
        random.seed(0)
        random.shuffle(spk_list)
        # The globbed parents already include `sourcepath`; joining them onto
        # `sourcepath` again would duplicate the prefix for relative paths.
        for spk_dir in spk_list[:n_train]:
            train_filelist.extend(spk_dir.glob("*.wav"))
    else:
        raise NotImplementedError(
            "corpus_type specified in config.yaml should be {single, multi-seen, multi-unseen}"
        )

    auxpath = pathlib.Path(src_config["general"]["aux_path"])
    stft_kwargs = {
        "n_fft": src_config["preprocess"]["fft_length"],
        "hop_length": src_config["preprocess"]["frame_shift"],
        "win_length": src_config["preprocess"]["frame_length"],
    }

    # Seeded with an all-zero column to keep the numerics backward compatible:
    # it only rescales the mean, and that scale cancels in the final
    # sum-normalization (up to the 1e-10 epsilon). Columns are stacked once at
    # the end instead of re-allocating the matrix for every file.
    columns = [np.zeros((tar_config["preprocess"]["fft_length"] // 2 + 1, 1))]

    for wp in tqdm.tqdm(train_filelist):
        wav, _ = librosa.load(wp, sr=SR)
        spec = np.abs(librosa.stft(wav, **stft_kwargs))

        if corpus_type == "single":
            aux_file = auxpath / wp.name
        else:
            aux_file = auxpath / wp.parent.name / wp.name
        wav_aux, _ = librosa.load(aux_file, sr=SR)
        spec_aux = np.abs(librosa.stft(wav_aux, **stft_kwargs))

        # The two recordings may differ in frame count; compare the overlap only.
        min_len = min(spec.shape[1], spec_aux.shape[1])
        spec_diff = spec[:, :min_len] / (spec_aux[:, :min_len] + 1e-10)
        columns.append(np.mean(spec_diff, axis=1).reshape(-1, 1))

    specs_all = np.hstack(columns)
    char_vector = np.mean(specs_all, axis=1)
    char_vector = char_vector / (np.sum(char_vector) + 1e-10)
    return char_vector


def normalize_waveform(wav, tar_config, db=-3):
    """Apply the sox ``norm`` effect at ``db`` to ``wav`` and return the result.

    The sampling rate passed to sox is ``tar_config["preprocess"]["sampling_rate"]``.
    """
    wav, _ = torchaudio.sox_effects.apply_effects_tensor(
        wav,
        tar_config["preprocess"]["sampling_rate"],
        [["norm", "{}".format(db)]],
    )
    return wav


if __name__ == "__main__":
    args = get_arg()
    chmatch_config = _load_yaml(args.config_path)
    output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(output_path / "test_wavs", exist_ok=True)
    os.makedirs(output_path / "test_mels", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    main(args, chmatch_config, device)