import copy
from abc import abstractmethod
from typing import Dict, Tuple

import torch
from coqpit import Coqpit
from torch import nn

from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler


class BaseTacotron(BaseTTS):
    """Base class shared by Tacotron and Tacotron2.

    It holds the layer placeholders that concrete models fill in, plus the
    checkpoint loading, test/logging, masking, embedding and callback helpers
    defined below. `forward` and `inference` are left abstract.
    """

    def __init__(
        self,
        config: "TacotronConfig",
        ap: "AudioProcessor",
        tokenizer: "TTSTokenizer",
        speaker_manager: SpeakerManager = None,
    ):
        """Copy the config fields onto the instance and declare layer placeholders.

        Args:
            config (TacotronConfig): model configuration; every key is also set as an
                attribute of the model.
            ap (AudioProcessor): audio processor forwarded to the parent class.
            tokenizer (TTSTokenizer): tokenizer forwarded to the parent class.
            speaker_manager (SpeakerManager): optional speaker manager forwarded to
                the parent class.
        """
        super().__init__(config, ap, tokenizer, speaker_manager)

        # pass all config fields as class attributes
        for key in config:
            setattr(self, key, config[key])

        # layers (left as `None` here; nothing in this class builds them)
        self.embedding = None
        self.encoder = None
        self.decoder = None
        self.postnet = None

        # init tensors
        self.embedded_speakers = None
        self.embedded_speakers_projected = None

        # global style token
        if self.gst and self.use_gst:
            self.decoder_in_features += self.gst.gst_embedding_dim  # add gst embedding dim
            self.gst_layer = None

        # Capacitron
        if self.capacitron_vae and self.use_capacitron_vae:
            self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim  # add capacitron embedding dim
            self.capacitron_vae_layer = None

        # additional layers
        self.decoder_backward = None
        self.coarse_decoder = None

    @staticmethod
    def _format_aux_input(aux_input: Dict) -> Dict:
        """Set missing fields to their default values.

        Args:
            aux_input (Dict): auxiliary inputs given by the caller.

        Returns:
            Dict: the result of `format_aux_input` called with the defaults
            `{"d_vectors": None, "speaker_ids": None}` and `aux_input`, or `None`
            when `aux_input` is empty or `None`.
        """
        if aux_input:
            return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
        return None

    #############################
    # INIT FUNCTIONS
    #############################

    def _init_backward_decoder(self):
        """Init the backward decoder for Forward-Backward decoding.

        The backward decoder is a deep copy of `self.decoder`.
        """
        self.decoder_backward = copy.deepcopy(self.decoder)

    def _init_coarse_decoder(self):
        """Init the coarse decoder for Double-Decoder Consistency.

        The coarse decoder is a deep copy of `self.decoder` whose `r_init` and
        `r` are both set from `self.ddc_r`.
        """
        self.coarse_decoder = copy.deepcopy(self.decoder)
        self.coarse_decoder.r_init = self.ddc_r
        self.coarse_decoder.set_r(self.ddc_r)

    #############################
    # CORE FUNCTIONS
    #############################

    @abstractmethod
    def forward(self):
        """Forward pass; must be implemented by the concrete model."""

    @abstractmethod
    def inference(self):
        """Inference pass; must be implemented by the concrete model."""

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load model checkpoint and set up internals.

        The checkpoint is loaded on CPU, its `"model"` entry is passed to
        `load_state_dict`, and the decoder's `r` is then restored.

        Args:
            config (Coqpit): model configuration.
            checkpoint_path (str): path to checkpoint file.
            eval (bool): whether to load model for evaluation.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        # TODO: set r in run-time by taking it from the new config
        if "r" in state:
            # set r from the state (for compatibility with older checkpoints)
            self.decoder.set_r(state["r"])
        elif "config" in state:
            # set r from config used at training time (for inference)
            self.decoder.set_r(state["config"]["r"])
        else:
            # set r from the new config (for new-models)
            self.decoder.set_r(config.r)
        if eval:
            self.eval()
            print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
            assert not self.training

    def get_criterion(self) -> nn.Module:
        """Get the model criterion used in training.

        Returns:
            nn.Module: a `TacotronLoss` built from `self.config`.
        """
        return TacotronLoss(self.config)

    @staticmethod
    def init_from_config(config: Coqpit):
        """Initialize model from config.

        The audio processor, tokenizer and speaker manager are each built with
        their own `init_from_config` and passed to the constructor.

        Args:
            config (Coqpit): model configuration.

        Returns:
            BaseTacotron: the new model instance.
        """
        # NOTE(review): this builds `BaseTacotron` itself, whose `forward` and
        # `inference` are abstract; concrete models presumably override this
        # method - confirm against the subclasses.
        # Imported locally rather than at module level (presumably to avoid an
        # import cycle - TODO confirm).
        from TTS.utils.audio import AudioProcessor  # pylint: disable=import-outside-toplevel

        ap = AudioProcessor.init_from_config(config)
        tokenizer = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config)
        return BaseTacotron(config, ap, tokenizer, speaker_manager)

    ##########################
    # TEST AND LOG FUNCTIONS #
    ##########################

    def test_run(self, assets: Dict) -> Dict:  # pylint: disable=unused-argument
        """Generic test run for `tts` models used by `Trainer`.

        You can override this for a different behaviour.

        Args:
            assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
                It is not read by this implementation.

        Returns:
            Dict: `{"figures": ..., "audios": ...}` holding the test figures and audios to be
            projected to Tensorboard, keyed by `"<idx>-prediction"`, `"<idx>-alignment"` and
            `"<idx>-audio"`.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        aux_inputs = self._get_test_aux_input()
        for idx, sen in enumerate(test_sentences):
            outputs_dict = synthesis(
                self,
                sen,
                self.config,
                # True when the model's first parameter lives on a CUDA device
                "cuda" in str(next(self.parameters()).device),
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                use_griffin_lim=True,
                do_trim_silence=False,
            )
            test_audios[f"{idx}-audio"] = outputs_dict["wav"]
            test_figures[f"{idx}-prediction"] = plot_spectrogram(
                outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
            )
            test_figures[f"{idx}-alignment"] = plot_alignment(
                outputs_dict["outputs"]["alignments"], output_fig=False
            )
        return {"figures": test_figures, "audios": test_audios}

    def test_log(
        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
    ) -> None:
        """Send the outputs of `test_run` to the logger.

        Args:
            outputs (dict): dict with the `"audios"` and `"figures"` entries.
            logger (Logger): logger exposing `test_audios` and `test_figures`.
            assets (dict): unused.
            steps (int): step value passed on to the logger.
        """
        logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
        logger.test_figures(steps, outputs["figures"])

    #############################
    # COMMON COMPUTE FUNCTIONS
    #############################

    def compute_masks(self, text_lengths, mel_lengths):
        """Compute masks against sequence paddings.

        Args:
            text_lengths: lengths used to build the input mask.
            mel_lengths: lengths used to build the output mask, or `None`.

        Returns:
            tuple: `(input_mask, output_mask)`; `output_mask` is `None` when
            `mel_lengths` is `None`.
        """
        # B x T_in_max (boolean)
        input_mask = sequence_mask(text_lengths)
        output_mask = None
        if mel_lengths is not None:
            max_len = mel_lengths.max()
            r = self.decoder.r
            # round the longest length up to the next multiple of `r`
            remainder = max_len % r
            if remainder > 0:
                max_len = max_len + (r - remainder)
            output_mask = sequence_mask(mel_lengths, max_len=max_len)
        return input_mask, output_mask

    def _backward_pass(self, mel_specs, encoder_outputs, mask):
        """Run backwards decoder.

        The backward decoder receives `mel_specs` flipped along dim 1.

        Returns:
            tuple: `(decoder_outputs_b, alignments_b)`, the decoder outputs with
            dims 1 and 2 transposed and the alignments as returned by the decoder.
        """
        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
        )
        decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
        return decoder_outputs_b, alignments_b

    def _coarse_decoder_pass(
        self, mel_specs, encoder_outputs, alignments, input_mask
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Double Decoder Consistency.

        Runs the coarse decoder on detached encoder outputs, so this pass sends
        no gradient back through `encoder_outputs`.

        Args:
            mel_specs: target spectrograms; dim 1 is read as the length `T`.
            encoder_outputs: encoder outputs (detached before use).
            alignments: alignments whose dim 1 size the coarse alignments are resized to.
            input_mask: mask passed to the coarse decoder.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: the coarse decoder outputs cut to the
            first `T` steps of dim 1, and the resized coarse alignments.
        """
        T = mel_specs.shape[1]
        if T % self.coarse_decoder.r > 0:
            # pad the end of the second-to-last dim (dim 1 of a 3-D input) so
            # that its length becomes a multiple of the coarse decoder's `r`
            padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
            mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
        coarse_outputs, coarse_alignments, _ = self.coarse_decoder(
            encoder_outputs.detach(), mel_specs, input_mask
        )
        # resize dim 1 of the coarse alignments to `alignments.shape[1]`;
        # `interpolate` works on the last dim, hence the transposes around it
        coarse_alignments = torch.nn.functional.interpolate(
            coarse_alignments.transpose(1, 2),
            size=alignments.shape[1],
            mode="nearest",
        ).transpose(1, 2)
        coarse_outputs = coarse_outputs.transpose(1, 2)
        # drop the steps that only exist because of the padding above
        coarse_outputs = coarse_outputs[:, :T, :]
        return coarse_outputs, coarse_alignments

    #############################
    # EMBEDDING FUNCTIONS
    #############################

    def compute_gst(self, inputs, style_input, speaker_embedding=None):
        """Compute global style token.

        Args:
            inputs: tensor the style embedding is concatenated to.
            style_input: a dict mapping a style token index to its weight, `None`
                to use a zero style embedding, or anything else, which is passed to
                `self.gst_layer` as is.
            speaker_embedding: optional speaker embedding.

        Returns:
            `inputs` with the style embedding concatenated on the last dim.
        """
        if isinstance(style_input, dict):
            # multiply each style token with a weight
            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
            if speaker_embedding is not None:
                query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)

            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
            for k_token, v_amplifier in style_input.items():
                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
        elif style_input is None:
            # ignore style token and return zero tensor
            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
        else:
            # compute style tokens
            gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
        inputs = self._concat_speaker_embedding(inputs, gst_outputs)
        return inputs

    def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
        """Capacitron Variational Autoencoder.

        Args:
            inputs: tensor the VAE embedding is concatenated to.
            reference_mel_info: first argument of `self.capacitron_vae_layer`.
            text_info: optional second argument of `self.capacitron_vae_layer`.
            speaker_embedding: optional third argument of `self.capacitron_vae_layer`.

        Returns:
            tuple: `(encoder_output, posterior_distribution, prior_distribution,
            capacitron_beta)`, where `encoder_output` is `inputs` with the VAE
            outputs concatenated on the last dim.
        """
        (
            VAE_outputs,
            posterior_distribution,
            prior_distribution,
            capacitron_beta,
        ) = self.capacitron_vae_layer(
            reference_mel_info,
            text_info,
            speaker_embedding,  # pylint: disable=not-callable
        )

        # make sure the VAE outputs sit on the same device as `inputs`
        VAE_outputs = VAE_outputs.to(inputs.device)
        encoder_output = self._concat_speaker_embedding(
            inputs, VAE_outputs
        )  # concatenate to the output of the basic tacotron encoder
        return (
            encoder_output,
            posterior_distribution,
            prior_distribution,
            capacitron_beta,
        )

    @staticmethod
    def _add_speaker_embedding(outputs, embedded_speakers):
        """Add `embedded_speakers` to `outputs`.

        The embedding is first expanded to the size of `outputs` on dims 0 and 1.
        """
        embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
        outputs = outputs + embedded_speakers_
        return outputs

    @staticmethod
    def _concat_speaker_embedding(outputs, embedded_speakers):
        """Concatenate `embedded_speakers` to `outputs` on the last dim.

        The embedding is first expanded to the size of `outputs` on dims 0 and 1.
        """
        embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
        outputs = torch.cat([outputs, embedded_speakers_], dim=-1)
        return outputs

    #############################
    # CALLBACKS
    #############################

    def on_epoch_start(self, trainer):
        """Callback for setting values wrt gradual training schedule.

        When `self.gradual_training` is set, `r` and the batch size are taken from
        `gradual_training_scheduler`, written to `trainer.config` and applied to the
        decoder (and to the backward decoder if `bidirectional_decoder` is enabled).

        Args:
            trainer (TrainerTTS): TTS trainer object that is used to train this model.
        """
        if self.gradual_training:
            r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
            trainer.config.r = r
            self.decoder.set_r(r)
            if trainer.config.bidirectional_decoder:
                trainer.model.decoder_backward.set_r(r)
            print(f"\n > Number of output frames: {self.decoder.r}")