|
from dataclasses import dataclass, field |
|
from typing import Dict, List, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
from coqpit import Coqpit |
|
from torch import nn |
|
from torch.nn.utils import weight_norm |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data.distributed import DistributedSampler |
|
from trainer.trainer_utils import get_optimizer, get_scheduler |
|
|
|
from TTS.utils.io import load_fsspec |
|
from TTS.vocoder.datasets import WaveGradDataset |
|
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock |
|
from TTS.vocoder.models.base_vocoder import BaseVocoder |
|
from TTS.vocoder.utils.generic_utils import plot_results |
|
|
|
|
|
@dataclass
class WavegradArgs(Coqpit):
    """Architecture hyper-parameters read by ``Wavegrad.__init__`` via ``config.model_params``.

    Attributes:
        in_channels: Input channels of ``x_conv``, the convolution applied to the spectrogram.
        out_channels: Output channels of the final ``out_conv``.
        use_weight_norm: When true, ``Wavegrad`` calls ``apply_weight_norm()`` at construction time.
        y_conv_channels: Output channels of ``y_conv``; also the input width of the first ``DBlock``
            and of the first ``FiLM`` layer.
        x_conv_channels: Output channels of ``x_conv`` and input width of the first ``UBlock``.
        dblock_out_channels: Output channels of each ``DBlock``. They are zipped with the *reversed*
            ``upsample_factors``, so the zip stops at the shorter of the two lists.
        ublock_out_channels: Output channels of each ``UBlock``. Iterated in reverse, the same list
            gives the output channels of the ``FiLM`` layers.
        upsample_factors: Factor handed to each ``UBlock``; their product is stored as ``hop_len``.
        upsample_dilations: One list of dilations per ``UBlock``.
    """

    # Scalar settings.
    in_channels: int = 80
    out_channels: int = 1
    use_weight_norm: bool = False
    y_conv_channels: int = 32
    x_conv_channels: int = 768

    # List-valued settings need a factory so instances never share one mutable default.
    dblock_out_channels: List[int] = field(
        default_factory=lambda: [128, 128, 256, 512],
    )
    ublock_out_channels: List[int] = field(
        default_factory=lambda: [512, 512, 256, 128, 128],
    )
    upsample_factors: List[int] = field(
        default_factory=lambda: [4, 4, 4, 2, 2],
    )
    upsample_dilations: List[List[int]] = field(
        default_factory=lambda: [
            [1, 2, 1, 2],
            [1, 2, 1, 2],
            [1, 2, 4, 8],
            [1, 2, 4, 8],
            [1, 2, 4, 8],
        ],
    )
|
|
|
|
|
class Wavegrad(BaseVocoder):
    """🐸 🌊 WaveGrad 🌊 model.
    Paper - https://arxiv.org/abs/2009.00713

    Examples:
        Initializing the model.

        >>> from TTS.vocoder.configs import WavegradConfig
        >>> config = WavegradConfig()
        >>> model = Wavegrad(config)

    Paper Abstract:
        This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the
        data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts
        from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned
        on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting
        the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in
        terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations.
        Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive
        baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations.
        Audio samples are available at this https URL.
    """

    def __init__(self, config: Coqpit):
        """Build the network from ``config.model_params``.

        Args:
            config: Model configuration. ``config.model_params`` supplies the channel sizes,
                upsample factors and dilations; the noise-schedule attributes start as ``None``
                and are filled in later by ``compute_noise_level``.
        """
        super().__init__(config)
        self.config = config
        params = config.model_params
        self.use_weight_norm = params.use_weight_norm
        # Total upsampling ratio between spectrogram frames and waveform samples.
        self.hop_len = np.prod(params.upsample_factors)

        # Noise-schedule state; populated by ``compute_noise_level``.
        self.noise_level = None
        self.num_steps = None
        self.beta = None
        self.alpha = None
        self.alpha_hat = None
        self.c1 = None
        self.c2 = None
        self.sigma = None

        # NOTE: sub-modules are created in a fixed order (y_conv, dblocks, film, ublocks,
        # x_conv, out_conv) so that ``state_dict`` key order and weight-init RNG use stay stable.

        # Downsampling path over the (noisy) waveform.
        self.y_conv = Conv1d(1, params.y_conv_channels, 5, padding=2)
        self.dblocks = nn.ModuleList([])
        ic = params.y_conv_channels
        for oc, df in zip(params.dblock_out_channels, reversed(params.upsample_factors)):
            self.dblocks.append(DBlock(ic, oc, df))
            ic = oc

        # FiLM layers producing the (shift, scale) pairs consumed by the upsampling path.
        self.film = nn.ModuleList([])
        ic = params.y_conv_channels
        for oc in reversed(params.ublock_out_channels):
            self.film.append(FiLM(ic, oc))
            ic = oc

        # Upsampling path over the spectrogram.
        self.ublocks = nn.ModuleList([])
        ic = params.x_conv_channels
        for oc, uf, ud in zip(
            params.ublock_out_channels,
            params.upsample_factors,
            params.upsample_dilations,
        ):
            self.ublocks.append(UBlock(ic, oc, uf, ud))
            ic = oc

        self.x_conv = Conv1d(params.in_channels, params.x_conv_channels, 3, padding=1)
        # ``ic`` holds the last UBlock's output width; unlike the loop variable ``oc`` it is
        # always bound, even if ``ublock_out_channels`` is empty.
        self.out_conv = Conv1d(ic, params.out_channels, 3, padding=1)

        if params.use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, spectrogram, noise_scale):
        """Run the network on a waveform, conditioned on a spectrogram and a noise scale.

        Args:
            x: Waveform tensor fed to ``y_conv`` (a 1-input-channel convolution).
            spectrogram: Conditioning features fed to ``x_conv``.
            noise_scale: Noise level passed to every FiLM layer.

        Returns:
            Output of ``out_conv`` applied after the last ``UBlock``.
        """
        x = self.y_conv(x)
        shift_and_scale = [self.film[0](x, noise_scale)]
        for film, dblock in zip(self.film[1:], self.dblocks):
            x = dblock(x)
            shift_and_scale.append(film(x, noise_scale))

        # The upsampling path consumes the FiLM outputs in reverse (deepest first).
        x = self.x_conv(spectrogram)
        for ublock, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)):
            x = ublock(x, film_shift, film_scale)
        return self.out_conv(x)

    def load_noise_schedule(self, path):
        """Load a ``beta`` schedule stored under the ``"beta"`` key of a NumPy file and apply it.

        Args:
            path: Location of the file handed to ``np.load``.
        """
        # NOTE(security): ``allow_pickle=True`` unpickles the file's contents and can execute
        # arbitrary code; only load schedule files from trusted sources.
        beta = np.load(path, allow_pickle=True).item()["beta"]
        self.compute_noise_level(beta)

    @torch.no_grad()
    def inference(self, x, y_n=None):
        """Generate a waveform by iterating over the current noise schedule in reverse.

        ``compute_noise_level`` must have been called beforehand.

        Args:
            x: Conditioning spectrogram.
            y_n: Optional starting signal. When omitted, Gaussian noise of length
                ``hop_len * T`` is drawn; otherwise the value is converted to a float tensor
                and given two leading singleton dimensions.

        Returns:
            The refined signal, clamped to ``[-1, 1]`` after every step.

        Shapes:
            x: :math:`[B, C , T]`
            y_n: :math:`[B, 1, T]`
        """
        if y_n is None:
            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
        else:
            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
        y_n = y_n.type_as(x)
        sqrt_alpha_hat = self.noise_level.to(x)
        for n in range(len(self.alpha) - 1, -1, -1):
            y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
            if n > 0:
                # Every step except the final one re-injects noise scaled by sigma.
                z = torch.randn_like(y_n)
                y_n += self.sigma[n - 1] * z
            y_n.clamp_(-1.0, 1.0)
        return y_n

    def compute_y_n(self, y_0):
        """Compute noisy audio based on noise schedule.

        A random step is drawn per batch item and the noise scale is sampled uniformly between
        the noise levels of that step and the next one.

        Args:
            y_0: Clean waveform; a 3-D input has its dimension 1 squeezed away first.

        Returns:
            Tuple ``(noise, noisy_audio, noise_scale)`` where the first two carry a singleton
            dimension at index 1 and ``noise_scale`` has one value per batch item.
        """
        # Side effect: the stored schedule is moved to the device/dtype of ``y_0``.
        self.noise_level = self.noise_level.to(y_0)
        if len(y_0.shape) == 3:
            y_0 = y_0.squeeze(1)
        # ``high`` is exclusive, so ``s + 1`` below never indexes past the last level.
        s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]])
        l_a, l_b = self.noise_level[s], self.noise_level[s + 1]
        noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
        noise_scale = noise_scale.unsqueeze(1)
        noise = torch.randn_like(y_0)
        noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise
        return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]

    def compute_noise_level(self, beta):
        """Compute noise schedule parameters.

        Sets ``num_steps``, ``beta``, ``alpha``, ``alpha_hat``, ``noise_level`` (the square root
        of ``alpha_hat``) and the sampling coefficients ``c1``, ``c2`` and ``sigma`` as float32
        tensors.

        Args:
            beta: 1-D array (or sequence) of beta values, one per diffusion step.
        """
        # Accept plain sequences as well as arrays; a no-op for ndarray input.
        beta = np.asarray(beta)
        self.num_steps = len(beta)
        alpha = 1 - beta
        alpha_hat = np.cumprod(alpha)
        noise_level = alpha_hat**0.5

        self.beta = torch.tensor(beta.astype(np.float32))
        self.alpha = torch.tensor(alpha.astype(np.float32))
        self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
        self.noise_level = torch.tensor(noise_level.astype(np.float32))

        self.c1 = 1 / self.alpha**0.5
        self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
        # One entry fewer than the number of steps: there is no sigma for the first step.
        self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5

    def remove_weight_norm(self):
        """Strip weight normalization from every block and from the three plain convolutions."""
        for group in (self.dblocks, self.film, self.ublocks):
            for layer in group:
                if len(layer.state_dict()) != 0:
                    try:
                        nn.utils.remove_weight_norm(layer)
                    except ValueError:
                        # The layer itself carries no weight norm; delegate to its own method.
                        layer.remove_weight_norm()

        nn.utils.remove_weight_norm(self.x_conv)
        nn.utils.remove_weight_norm(self.out_conv)
        nn.utils.remove_weight_norm(self.y_conv)

    def apply_weight_norm(self):
        """Apply weight normalization to every block and to the three plain convolutions."""
        for group in (self.dblocks, self.film, self.ublocks):
            for layer in group:
                if len(layer.state_dict()) != 0:
                    layer.apply_weight_norm()

        self.x_conv = weight_norm(self.x_conv)
        self.out_conv = weight_norm(self.out_conv)
        self.y_conv = weight_norm(self.y_conv)

    @staticmethod
    def _linear_betas(schedule):
        """Return linearly spaced betas from a schedule with ``min_val``, ``max_val``, ``num_steps``."""
        return np.linspace(schedule["min_val"], schedule["max_val"], schedule["num_steps"])

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load model weights from a checkpoint and set up the matching noise schedule.

        Args:
            config: Mapping providing ``test_noise_schedule`` / ``train_noise_schedule``.
            checkpoint_path: Path handed to ``load_fsspec``; the weights are read from its
                ``"model"`` entry onto the CPU.
            eval: When true, switch to eval mode, remove weight norm if it was enabled, and
                use the test schedule; otherwise use the train schedule.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            if self.config.model_params.use_weight_norm:
                self.remove_weight_norm()
            schedule = config["test_noise_schedule"]
        else:
            schedule = config["train_noise_schedule"]
        self.compute_noise_level(self._linear_betas(schedule))

    def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
        """Noise the target waveform, predict that noise, and score the prediction.

        Args:
            batch: Dict with ``"input"`` (conditioning features) and ``"waveform"``.
            criterion: Callable invoked as ``criterion(noise, noise_hat)``.

        Returns:
            ``({"model_output": noise_hat}, {"loss": loss})``.
        """
        x = batch["input"]
        y = batch["waveform"]

        # Draw a random noise level per item and corrupt the clean waveform with it.
        noise, x_noisy, noise_scale = self.compute_y_n(y)

        # Predict the noise that was added.
        noise_hat = self.forward(x_noisy, x, noise_scale)

        loss = criterion(noise, noise_hat)
        return {"model_output": noise_hat}, {"loss": loss}

    def train_log(  # pylint: disable=no-self-use
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
    ) -> Tuple[Dict, np.ndarray]:
        """Training-time logging hook; currently a no-op that returns ``None``."""
        pass

    @torch.no_grad()
    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Evaluate one batch by running ``train_step`` with gradients disabled."""
        return self.train_step(batch, criterion)

    def eval_log(  # pylint: disable=no-self-use
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
    ) -> None:
        """Evaluation-time logging hook; currently a no-op."""
        pass

    def test(self, assets: Dict, test_loader: "DataLoader", outputs=None):  # pylint: disable=unused-argument
        """Synthesize test samples with the test noise schedule.

        Args:
            assets: Dict providing ``"audio_processor"``.
            test_loader: Loader whose dataset exposes ``load_test_samples``.
            outputs: Unused.

        Returns:
            ``(figures, audios)`` for the last processed sample, where ``audios`` maps
            ``"test/audio"`` to the generated waveform as a NumPy array. Both are empty
            dicts when the dataset yields no test samples.
        """
        ap = assets["audio_processor"]
        self.compute_noise_level(self._linear_betas(self.config["test_noise_schedule"]))
        samples = test_loader.dataset.load_test_samples(1)

        # Pre-bind the results so an empty sample list no longer raises UnboundLocalError.
        # NOTE(review): assumes ``plot_results`` returns a mapping of figures — confirm.
        figures, audios = {}, {}
        for sample in samples:
            x = sample[0]
            x = x[None, :, :].to(next(self.parameters()).device)
            y = sample[1]
            y = y[None, :]

            # Generate the waveform from the conditioning features alone.
            y_pred = self.inference(x)

            figures = plot_results(y_pred, y, ap, "test")

            sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy()
            audios = {"test/audio": sample_voice}
        return figures, audios

    def get_optimizer(self):
        """Create the optimizer named in the config for this model's parameters."""
        return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)

    def get_scheduler(self, optimizer):
        """Create the learning-rate scheduler named in the config for ``optimizer``."""
        return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)

    @staticmethod
    def get_criterion():
        """Return the training loss, an L1 loss."""
        return torch.nn.L1Loss()

    @staticmethod
    def format_batch(batch: Dict) -> Dict:
        """Turn a raw ``(features, waveform)`` batch into the dict used by ``train_step``.

        The waveform gains a singleton dimension at index 1.
        """
        m, y = batch[0], batch[1]
        y = y.unsqueeze(1)
        return {"input": m, "waveform": y}

    def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: bool, samples: List, verbose: bool, num_gpus: int):
        """Build a ``DataLoader`` over a ``WaveGradDataset``.

        Args:
            config: Configuration; only ``use_cache`` is read from it here, the remaining
                settings come from ``self.config``.
            assets: Dict providing ``"audio_processor"``.
            is_eval: Selects evaluation mode (non-training dataset, eval worker count).
            samples: Items handed to the dataset.
            verbose: Forwarded to the dataset.
            num_gpus: With more than one GPU a ``DistributedSampler`` is used and loader-level
                shuffling is turned off.

        Returns:
            The configured ``DataLoader``.
        """
        ap = assets["audio_processor"]
        dataset = WaveGradDataset(
            ap=ap,
            items=samples,
            seq_len=self.config.seq_len,
            hop_len=ap.hop_length,
            pad_short=self.config.pad_short,
            conv_pad=self.config.conv_pad,
            is_training=not is_eval,
            return_segments=True,
            use_noise_augment=False,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=num_gpus <= 1,
            drop_last=False,
            sampler=sampler,
            num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers,
            pin_memory=False,
        )
        return loader

    def on_epoch_start(self, trainer):  # pylint: disable=unused-argument
        """Restore the training noise schedule at the start of every epoch."""
        self.compute_noise_level(self._linear_betas(self.config["train_noise_schedule"]))

    @staticmethod
    def init_from_config(config: "WavegradConfig"):
        """Create a ``Wavegrad`` model from ``config``."""
        return Wavegrad(config)
|
|