import torch import pytorch_lightning as pl import torchaudio import os import pathlib import tqdm from model import ( EncoderModule, ChannelFeatureModule, ChannelModule, MultiScaleSpectralLoss, GSTModule, ) class PretrainLightningModule(pl.LightningModule): def __init__(self, config): super().__init__() self.save_hyperparameters() self.config = config if config["general"]["use_gst"]: self.encoder = EncoderModule(config) self.gst = GSTModule(config) else: self.encoder = EncoderModule(config, use_channel=True) self.channelfeats = ChannelFeatureModule(config) self.channel = ChannelModule(config) self.vocoder = None self.criteria_a = MultiScaleSpectralLoss(config) if "feature_loss" in config["train"]: if config["train"]["feature_loss"]["type"] == "mae": self.criteria_b = torch.nn.L1Loss() else: self.criteria_b = torch.nn.MSELoss() else: self.criteria = torch.nn.L1Loss() self.alpha = config["train"]["alpha"] def forward(self, melspecs, wavsaux): if self.config["general"]["use_gst"]: enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3)) chfeats = self.gst(melspecs.transpose(1, 2)) else: enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3)) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) wavsdeg = self.channel(wavsaux, chfeats) return enc_out, wavsdeg def training_step(self, batch, batch_idx): if self.config["general"]["use_gst"]: enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) chfeats = self.gst(batch["melspecs"].transpose(1, 2)) else: enc_out, enc_hidden = self.encoder( batch["melspecs"].unsqueeze(1).transpose(2, 3) ) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) wavsdeg = self.channel(batch["wavsaux"], chfeats) loss_recons = self.criteria_a(wavsdeg, batch["wavs"]) if self.config["general"]["feature_type"] == "melspec": loss_encoder = self.criteria_b(enc_out, batch["melspecsaux"]) elif self.config["general"]["feature_type"] == "vocfeats": loss_encoder = self.criteria_b(enc_out, batch["melceps"]) loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder self.log( "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True ) self.log( "train_loss_recons", loss_recons, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) self.log( "train_loss_encoder", loss_encoder, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) return loss def validation_step(self, batch, batch_idx): if self.config["general"]["use_gst"]: enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) chfeats = self.gst(batch["melspecs"].transpose(1, 2)) else: enc_out, enc_hidden = self.encoder( batch["melspecs"].unsqueeze(1).transpose(2, 3) ) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) wavsdeg = self.channel(batch["wavsaux"], chfeats) loss_recons = self.criteria_a(wavsdeg, batch["wavs"]) if self.config["general"]["feature_type"] == "melspec": val_aux_feats = batch["melspecsaux"] feats_name = "melspec" loss_encoder = self.criteria_b(enc_out, val_aux_feats) elif self.config["general"]["feature_type"] == "vocfeats": val_aux_feats = batch["melceps"] feats_name = "melcep" loss_encoder = self.criteria_b(enc_out, val_aux_feats) loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder logger_img_dict = { "val_src_melspec": batch["melspecs"], "val_pred_{}".format(feats_name): enc_out, "val_aux_{}".format(feats_name): val_aux_feats, } logger_wav_dict = { "val_src_wav": batch["wavs"], "val_pred_wav": wavsdeg, "val_aux_wav": batch["wavsaux"], } return { "val_loss": loss, "val_loss_recons": loss_recons, "val_loss_encoder": loss_encoder, "logger_dict": [logger_img_dict, logger_wav_dict], } def validation_epoch_end(self, outputs): val_loss = torch.stack([out["val_loss"] for out in outputs]).mean().item() val_loss_recons = ( torch.stack([out["val_loss_recons"] for out in outputs]).mean().item() ) val_loss_encoder = ( torch.stack([out["val_loss_encoder"] for out in outputs]).mean().item() ) self.log("val_loss", val_loss, on_epoch=True, prog_bar=True, logger=True) self.log( "val_loss_recons", val_loss_recons, on_epoch=True, prog_bar=True, logger=True, ) self.log( "val_loss_encoder", val_loss_encoder, on_epoch=True, prog_bar=True, logger=True, ) def test_step(self, batch, batch_idx): if self.config["general"]["use_gst"]: enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) chfeats = self.gst(batch["melspecs"].transpose(1, 2)) else: enc_out, enc_hidden = self.encoder( batch["melspecs"].unsqueeze(1).transpose(2, 3) ) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) wavsdeg = self.channel(batch["wavsaux"], chfeats) if self.config["general"]["feature_type"] == "melspec": enc_feats = enc_out enc_feats_aux = batch["melspecsaux"] elif self.config["general"]["feature_type"] == "vocfeats": enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) enc_feats_aux = torch.cat( (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1 ) recons_wav = self.vocoder(enc_feats_aux).squeeze(1) remas = self.vocoder(enc_feats).squeeze(1) if self.config["general"]["feature_type"] == "melspec": enc_feats_input = batch["melspecs"] elif self.config["general"]["feature_type"] == "vocfeats": enc_feats_input = torch.cat( (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1 ) input_recons = self.vocoder(enc_feats_input).squeeze(1) if "wavsaux" in batch: gt_wav = batch["wavsaux"] else: gt_wav = None return { "reconstructed": recons_wav, "remastered": remas, "channeled": wavsdeg, "groundtruth": gt_wav, "input": batch["wavs"], "input_recons": input_recons, } def test_epoch_end(self, outputs): wav_dir = ( pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs" ) os.makedirs(wav_dir, exist_ok=True) mel_dir = ( pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels" ) os.makedirs(mel_dir, exist_ok=True) print("Saving mel spectrogram plots ...") for idx, out in enumerate(tqdm.tqdm(outputs)): for key in [ "reconstructed", "remastered", "channeled", "input", "input_recons", "groundtruth", ]: if out[key] != None: torchaudio.save( wav_dir / "{}-{}.wav".format(idx, key), out[key][0, ...].unsqueeze(0).cpu(), sample_rate=self.config["preprocess"]["sampling_rate"], channels_first=True, ) def configure_optimizers(self): optimizer = torch.optim.Adam( self.parameters(), lr=self.config["train"]["learning_rate"] ) lr_scheduler_config = { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True ), "interval": "epoch", "frequency": 3, "monitor": "val_loss", } return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} class SSLBaseModule(pl.LightningModule): def __init__(self, config): super().__init__() self.save_hyperparameters() self.config = config if config["general"]["use_gst"]: self.encoder = EncoderModule(config) self.gst = GSTModule(config) else: self.encoder = EncoderModule(config, use_channel=True) self.channelfeats = ChannelFeatureModule(config) self.channel = ChannelModule(config) if config["train"]["load_pretrained"]: pre_model = PretrainLightningModule.load_from_checkpoint( checkpoint_path=config["train"]["pretrained_path"] ) self.encoder.load_state_dict(pre_model.encoder.state_dict(), strict=False) self.channel.load_state_dict(pre_model.channel.state_dict(), strict=False) if config["general"]["use_gst"]: self.gst.load_state_dict(pre_model.gst.state_dict(), strict=False) else: self.channelfeats.load_state_dict( pre_model.channelfeats.state_dict(), strict=False ) self.vocoder = None self.criteria = self.get_loss_function(config) def training_step(self, batch, batch_idx): raise NotImplementedError() def validation_step(self, batch, batch_idx): raise NotImplementedError() def validation_epoch_end(self, outputs): raise NotImplementedError() def configure_optimizers(self): raise NotImplementedError() def get_loss_function(self, config): raise NotImplementedError() def forward(self, melspecs, f0s=None): if self.config["general"]["use_gst"]: enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3)) chfeats = self.gst(melspecs.transpose(1, 2)) else: enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3)) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) if self.config["general"]["feature_type"] == "melspec": enc_feats = enc_out elif self.config["general"]["feature_type"] == "vocfeats": enc_feats = torch.cat((f0s.unsqueeze(1), enc_out), dim=1) remas = self.vocoder(enc_feats).squeeze(1) wavsdeg = self.channel(remas, chfeats) return remas, wavsdeg def test_step(self, batch, batch_idx): if self.config["general"]["use_gst"]: enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) chfeats = self.gst(batch["melspecs"].transpose(1, 2)) else: enc_out, enc_hidden = self.encoder( batch["melspecs"].unsqueeze(1).transpose(2, 3) ) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) if self.config["general"]["feature_type"] == "melspec": enc_feats = enc_out elif self.config["general"]["feature_type"] == "vocfeats": enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) remas = self.vocoder(enc_feats).squeeze(1) wavsdeg = self.channel(remas, chfeats) if self.config["general"]["feature_type"] == "melspec": enc_feats_input = batch["melspecs"] elif self.config["general"]["feature_type"] == "vocfeats": enc_feats_input = torch.cat( (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1 ) input_recons = self.vocoder(enc_feats_input).squeeze(1) if "wavsaux" in batch: gt_wav = batch["wavsaux"] if self.config["general"]["feature_type"] == "melspec": enc_feats_aux = batch["melspecsaux"] elif self.config["general"]["feature_type"] == "vocfeats": enc_feats_aux = torch.cat( (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1 ) recons_wav = self.vocoder(enc_feats_aux).squeeze(1) else: gt_wav = None recons_wav = None return { "reconstructed": recons_wav, "remastered": remas, "channeled": wavsdeg, "input": batch["wavs"], "input_recons": input_recons, "groundtruth": gt_wav, } def test_epoch_end(self, outputs): wav_dir = ( pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs" ) os.makedirs(wav_dir, exist_ok=True) mel_dir = ( pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels" ) os.makedirs(mel_dir, exist_ok=True) print("Saving mel spectrogram plots ...") for idx, out in enumerate(tqdm.tqdm(outputs)): plot_keys = [] for key in [ "reconstructed", "remastered", "channeled", "input", "input_recons", "groundtruth", ]: if out[key] != None: plot_keys.append(key) torchaudio.save( wav_dir / "{}-{}.wav".format(idx, key), out[key][0, ...].unsqueeze(0).cpu(), sample_rate=self.config["preprocess"]["sampling_rate"], channels_first=True, ) class SSLStepLightningModule(SSLBaseModule): def __init__(self, config): super().__init__(config) if config["train"]["fix_channel"]: for param in self.channel.parameters(): param.requires_grad = False def training_step(self, batch, batch_idx, optimizer_idx): if self.config["general"]["use_gst"]: enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) chfeats = self.gst(batch["melspecs"].transpose(1, 2)) else: enc_out, enc_hidden = self.encoder( batch["melspecs"].unsqueeze(1).transpose(2, 3) ) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) if self.config["general"]["feature_type"] == "melspec": enc_feats = enc_out elif self.config["general"]["feature_type"] == "vocfeats": enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) remas = self.vocoder(enc_feats).squeeze(1) wavsdeg = self.channel(remas, chfeats) loss = self.criteria(wavsdeg, batch["wavs"]) self.log( "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) return loss def validation_step(self, batch, batch_idx): if self.config["general"]["use_gst"]: enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) chfeats = self.gst(batch["melspecs"].transpose(1, 2)) else: enc_out, enc_hidden = self.encoder( batch["melspecs"].unsqueeze(1).transpose(2, 3) ) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) if self.config["general"]["feature_type"] == "melspec": enc_feats = enc_out feats_name = "melspec" elif self.config["general"]["feature_type"] == "vocfeats": enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) feats_name = "melcep" remas = self.vocoder(enc_feats).squeeze(1) wavsdeg = self.channel(remas, chfeats) loss = self.criteria(wavsdeg, batch["wavs"]) logger_img_dict = { "val_src_melspec": batch["melspecs"], "val_pred_{}".format(feats_name): enc_out, } for auxfeats in ["melceps", "melspecsaux"]: if auxfeats in batch: logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats] logger_wav_dict = { "val_src_wav": batch["wavs"], "val_remastered_wav": remas, "val_pred_wav": wavsdeg, } if "wavsaux" in batch: logger_wav_dict["val_aux_wav"] = batch["wavsaux"] d_out = {"val_loss": loss, "logger_dict": [logger_img_dict, logger_wav_dict]} return d_out def validation_epoch_end(self, outputs): self.log( "val_loss", torch.stack([out["val_loss"] for out in outputs]).mean().item(), on_epoch=True, prog_bar=True, logger=True, ) def optimizer_step( self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu=False, using_native_amp=False, using_lbfgs=False, ): if epoch < self.config["train"]["epoch_channel"]: if optimizer_idx == 0: optimizer.step(closure=optimizer_closure) elif optimizer_idx == 1: optimizer_closure() else: if optimizer_idx == 0: optimizer_closure() elif optimizer_idx == 1: optimizer.step(closure=optimizer_closure) def configure_optimizers(self): if self.config["train"]["fix_channel"]: if self.config["general"]["use_gst"]: optimizer_channel = torch.optim.Adam( self.gst.parameters(), lr=self.config["train"]["learning_rate"] ) else: optimizer_channel = torch.optim.Adam( self.channelfeats.parameters(), lr=self.config["train"]["learning_rate"], ) optimizer_encoder = torch.optim.Adam( self.encoder.parameters(), lr=self.config["train"]["learning_rate"] ) else: if self.config["general"]["use_gst"]: optimizer_channel = torch.optim.Adam( [ {"params": self.channel.parameters()}, {"params": self.gst.parameters()}, ], lr=self.config["train"]["learning_rate"], ) else: optimizer_channel = torch.optim.Adam( [ {"params": self.channel.parameters()}, {"params": self.channelfeats.parameters()}, ], lr=self.config["train"]["learning_rate"], ) optimizer_encoder = torch.optim.Adam( self.encoder.parameters(), lr=self.config["train"]["learning_rate"] ) optimizers = [optimizer_channel, optimizer_encoder] schedulers = [ { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( optimizers[0], mode="min", factor=0.5, min_lr=1e-5, verbose=True ), "interval": "epoch", "frequency": 3, "monitor": "val_loss", }, { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( optimizers[1], mode="min", factor=0.5, min_lr=1e-5, verbose=True ), "interval": "epoch", "frequency": 3, "monitor": "val_loss", }, ] return optimizers, schedulers def get_loss_function(self, config): return MultiScaleSpectralLoss(config) class SSLDualLightningModule(SSLBaseModule): def __init__(self, config): super().__init__(config) if config["train"]["fix_channel"]: for param in self.channel.parameters(): param.requires_grad = False self.spec_module = torchaudio.transforms.MelSpectrogram( sample_rate=config["preprocess"]["sampling_rate"], n_fft=config["preprocess"]["fft_length"], win_length=config["preprocess"]["frame_length"], hop_length=config["preprocess"]["frame_shift"], f_min=config["preprocess"]["fmin"], f_max=config["preprocess"]["fmax"], n_mels=config["preprocess"]["n_mels"], power=1, center=True, norm="slaney", mel_scale="slaney", ) self.beta = config["train"]["beta"] self.criteria_a, self.criteria_b = self.get_loss_function(config) def training_step(self, batch, batch_idx): if self.config["general"]["use_gst"]: enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) chfeats = self.gst(batch["melspecs"].transpose(1, 2)) else: enc_out, enc_hidden = self.encoder( batch["melspecs"].unsqueeze(1).transpose(2, 3) ) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) if self.config["general"]["feature_type"] == "melspec": enc_feats = enc_out elif self.config["general"]["feature_type"] == "vocfeats": enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) remas = self.vocoder(enc_feats).squeeze(1) wavsdeg = self.channel(remas, chfeats) loss_recons = self.criteria_a(wavsdeg, batch["wavs"]) with torch.no_grad(): wavsdegtask = self.channel(batch["wavstask"], chfeats) melspecstask = self.calc_spectrogram(wavsdegtask) if self.config["general"]["use_gst"]: enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3)) else: enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3)) enc_out_task = enc_out_task.squeeze(1).transpose(1, 2) if self.config["general"]["feature_type"] == "melspec": loss_task = self.criteria_b(enc_out_task, batch["melspecstask"]) elif self.config["general"]["feature_type"] == "vocfeats": loss_task = self.criteria_b(enc_out_task, batch["melcepstask"]) loss = self.beta * loss_recons + (1 - self.beta) * loss_task self.log( "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True ) self.log( "train_loss_recons", loss_recons, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) self.log( "train_loss_task", loss_task, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) return loss def validation_step(self, batch, batch_idx): if self.config["general"]["use_gst"]: enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) chfeats = self.gst(batch["melspecs"].transpose(1, 2)) else: enc_out, enc_hidden = self.encoder( batch["melspecs"].unsqueeze(1).transpose(2, 3) ) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) if self.config["general"]["feature_type"] == "melspec": enc_feats = enc_out feats_name = "melspec" elif self.config["general"]["feature_type"] == "vocfeats": enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) feats_name = "melcep" remas = self.vocoder(enc_feats).squeeze(1) wavsdeg = self.channel(remas, chfeats) loss_recons = self.criteria_a(wavsdeg, batch["wavs"]) wavsdegtask = self.channel(batch["wavstask"], chfeats) melspecstask = self.calc_spectrogram(wavsdegtask) if self.config["general"]["use_gst"]: enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3)) else: enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3)) enc_out_task = enc_out_task.squeeze(1).transpose(1, 2) if self.config["general"]["feature_type"] == "melspec": enc_out_task_truth = batch["melspecstask"] loss_task = self.criteria_b(enc_out_task, enc_out_task_truth) elif self.config["general"]["feature_type"] == "vocfeats": enc_out_task_truth = batch["melcepstask"] loss_task = self.criteria_b(enc_out_task, enc_out_task_truth) loss = self.beta * loss_recons + (1 - self.beta) * loss_task logger_img_dict = { "val_src_melspec": batch["melspecs"], "val_pred_{}".format(feats_name): enc_out, "val_truth_{}_task".format(feats_name): enc_out_task_truth, "val_pred_{}_task".format(feats_name): enc_out_task, } for auxfeats in ["melceps", "melspecsaux"]: if auxfeats in batch: logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats] logger_wav_dict = { "val_src_wav": batch["wavs"], "val_remastered_wav": remas, "val_pred_wav": wavsdeg, "val_truth_wavtask": batch["wavstask"], "val_deg_wavtask": wavsdegtask, } if "wavsaux" in batch: logger_wav_dict["val_aux_wav"] = batch["wavsaux"] d_out = { "val_loss": loss, "val_loss_recons": loss_recons, "val_loss_task": loss_task, "logger_dict": [logger_img_dict, logger_wav_dict], } return d_out def validation_epoch_end(self, outputs): self.log( "val_loss", torch.stack([out["val_loss"] for out in outputs]).mean().item(), on_epoch=True, prog_bar=True, logger=True, ) self.log( "val_loss_recons", torch.stack([out["val_loss_recons"] for out in outputs]).mean().item(), on_epoch=True, prog_bar=True, logger=True, ) self.log( "val_loss_task", torch.stack([out["val_loss_task"] for out in outputs]).mean().item(), on_epoch=True, prog_bar=True, logger=True, ) def test_step(self, batch, batch_idx): if self.config["general"]["use_gst"]: enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3)) chfeats = self.gst(batch["melspecs"].transpose(1, 2)) else: enc_out, enc_hidden = self.encoder( batch["melspecs"].unsqueeze(1).transpose(2, 3) ) chfeats = self.channelfeats(enc_hidden) enc_out = enc_out.squeeze(1).transpose(1, 2) if self.config["general"]["feature_type"] == "melspec": enc_feats = enc_out elif self.config["general"]["feature_type"] == "vocfeats": enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1) remas = self.vocoder(enc_feats).squeeze(1) wavsdeg = self.channel(remas, chfeats) if self.config["general"]["feature_type"] == "melspec": enc_feats_input = batch["melspecs"] elif self.config["general"]["feature_type"] == "vocfeats": enc_feats_input = torch.cat( (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1 ) input_recons = self.vocoder(enc_feats_input).squeeze(1) wavsdegtask = self.channel(batch["wavstask"], chfeats) if "wavsaux" in batch: gt_wav = batch["wavsaux"] if self.config["general"]["feature_type"] == "melspec": enc_feats_aux = batch["melspecsaux"] elif self.config["general"]["feature_type"] == "vocfeats": enc_feats_aux = torch.cat( (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1 ) recons_wav = self.vocoder(enc_feats_aux).squeeze(1) else: gt_wav = None recons_wav = None return { "reconstructed": recons_wav, "remastered": remas, "channeled": wavsdeg, "channeled_task": wavsdegtask, "input": batch["wavs"], "input_recons": input_recons, "groundtruth": gt_wav, } def test_epoch_end(self, outputs): wav_dir = ( pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs" ) os.makedirs(wav_dir, exist_ok=True) mel_dir = ( pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels" ) os.makedirs(mel_dir, exist_ok=True) print("Saving mel spectrogram plots ...") for idx, out in enumerate(tqdm.tqdm(outputs)): plot_keys = [] for key in [ "reconstructed", "remastered", "channeled", "channeled_task", "input", "input_recons", "groundtruth", ]: if out[key] != None: plot_keys.append(key) torchaudio.save( wav_dir / "{}-{}.wav".format(idx, key), out[key][0, ...].unsqueeze(0).cpu(), sample_rate=self.config["preprocess"]["sampling_rate"], channels_first=True, ) def configure_optimizers(self): optimizer = torch.optim.Adam( self.parameters(), lr=self.config["train"]["learning_rate"] ) lr_scheduler_config = { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True ), "interval": "epoch", "frequency": 3, "monitor": "val_loss", } return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} def calc_spectrogram(self, wav): specs = self.spec_module(wav) log_spec = torch.log( torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"]) * self.config["preprocess"]["comp_factor"] ).to(torch.float32) return log_spec def get_loss_function(self, config): if config["train"]["feature_loss"]["type"] == "mae": feature_loss = torch.nn.L1Loss() else: feature_loss = torch.nn.MSELoss() return MultiScaleSpectralLoss(config), feature_loss