# File: indic/TTS/encoder/models/base_encoder.py
# Commit 6127b48 ("Init") by azamat; 5.35 kB.
import numpy as np
import torch
import torchaudio
from coqpit import Coqpit
from torch import nn
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.utils.generic_utils import set_init_dict
from TTS.utils.io import load_fsspec
class PreEmphasis(nn.Module):
    """First-order pre-emphasis filter: ``y[t] = x[t] - coefficient * x[t - 1]``.

    The missing sample before ``x[0]`` is obtained by reflect padding, so the
    first output is ``x[0] - coefficient * x[1]``.
    """

    def __init__(self, coefficient=0.97):
        super().__init__()
        self.coefficient = coefficient
        # Conv1d weight of shape (out_channels=1, in_channels=1, kernel_width=2).
        # It is kept as a buffer so it moves with the module between devices
        # and is stored in the state dict under the key "filter".
        kernel = torch.FloatTensor([-self.coefficient, 1.0]).view(1, 1, 2)
        self.register_buffer("filter", kernel)

    def forward(self, x):
        """Filter every row of the 2-D tensor ``x`` along its last dimension."""
        assert x.dim() == 2
        # Add a singleton channel axis and pad one sample on the left only.
        padded = torch.nn.functional.pad(x[:, None, :], (1, 0), mode="reflect")
        filtered = torch.nn.functional.conv1d(padded, self.filter)
        # Drop the channel axis again.
        return filtered[:, 0, :]
class BaseEncoder(nn.Module):
    """Base `encoder` class. Every new `encoder` model must inherit this.

    It defines common `encoder` specific functions: building the optional
    torch mel-spectrogram front-end, inference/embedding helpers, loss
    selection and checkpoint loading.

    Subclasses are expected to implement ``forward(x, l2_norm)`` (called by
    :meth:`inference`) and to set the ``use_torch_spec`` and ``audio_config``
    attributes read by :meth:`compute_embedding`.
    """

    # pylint: disable=W0102
    def __init__(self):
        super().__init__()

    def get_torch_mel_spectrogram_class(self, audio_config):
        """Build the waveform -> mel-spectrogram front-end.

        Args:
            audio_config: mapping providing the keys ``preemphasis``,
                ``sample_rate``, ``fft_size``, ``win_length``, ``hop_length``
                and ``num_mels``.

        Returns:
            An ``nn.Sequential`` (a module instance, despite the method name)
            that applies :class:`PreEmphasis` (which requires 2-D input) and
            then ``torchaudio.transforms.MelSpectrogram`` with a Hamming
            window. All other MelSpectrogram options are left at torchaudio's
            defaults, and no dB conversion is applied here.
        """
        return torch.nn.Sequential(
            PreEmphasis(audio_config["preemphasis"]),
            torchaudio.transforms.MelSpectrogram(
                sample_rate=audio_config["sample_rate"],
                n_fft=audio_config["fft_size"],
                win_length=audio_config["win_length"],
                hop_length=audio_config["hop_length"],
                window_fn=torch.hamming_window,
                n_mels=audio_config["num_mels"],
            ),
        )

    @torch.no_grad()
    def inference(self, x, l2_norm=True):
        """Run ``forward(x, l2_norm)`` with autograd disabled."""
        return self.forward(x, l2_norm)

    @torch.no_grad()
    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
        """Embed one utterance by averaging ``num_eval`` evenly spaced windows.

        Args:
            x: batch-size-1 input whose dim 1 is time. When
                ``self.use_torch_spec`` is set this is the raw waveform
                (presumably ``1 x T``); otherwise presumably ``1 x T x D``
                features -- TODO confirm against subclasses.
            num_frames: window length in feature frames. With
                ``use_torch_spec`` it is converted to samples by multiplying
                with ``audio_config["hop_length"]``. It is clipped to the input
                length.
            num_eval: number of windows. Their start offsets are evenly spaced
                between 0 and ``T - num_frames``, so windows may overlap (or
                coincide when the input is short).
            return_mean: if True, average the window embeddings along dim 0,
                keeping that dim.
            l2_norm: forwarded to :meth:`inference`.

        Returns:
            The per-window embeddings stacked along dim 0, or their mean with
            shape ``1 x ...`` when ``return_mean`` is True.
        """
        # map to the waveform size
        if self.use_torch_spec:
            num_frames = num_frames * self.audio_config["hop_length"]

        max_len = x.shape[1]
        num_frames = min(num_frames, max_len)

        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
        windows = []
        for offset in offsets:
            start = int(offset)
            end = int(start + num_frames)
            windows.append(x[:, start:end])
        frames_batch = torch.cat(windows, dim=0)

        embeddings = self.inference(frames_batch, l2_norm=l2_norm)
        if return_mean:
            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
        return embeddings

    def get_criterion(self, c: "Coqpit", num_classes=None):
        """Instantiate the training loss named by ``c.loss``.

        Args:
            c: config with a ``loss`` attribute. For ``"softmaxproto"`` it must
                also provide ``model_params["proj_dim"]``.
            num_classes: number of classes, used only by
                ``SoftmaxAngleProtoLoss``.

        Returns:
            A ``GE2ELoss`` (softmax variant), ``AngleProtoLoss`` or
            ``SoftmaxAngleProtoLoss`` instance.

        Raises:
            ValueError: if ``c.loss`` is not one of the supported names.
                ``ValueError`` subclasses ``Exception``, so existing handlers
                still catch it.
        """
        if c.loss == "ge2e":
            return GE2ELoss(loss_method="softmax")
        if c.loss == "angleproto":
            return AngleProtoLoss()
        if c.loss == "softmaxproto":
            return SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
        raise ValueError("The %s not is a loss supported" % c.loss)

    def load_checkpoint(
        self, config: "Coqpit", checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None
    ):  # pylint: disable=redefined-builtin  (`eval` is kept for API compatibility)
        """Restore model (and optionally criterion) weights from a checkpoint.

        Args:
            config: model config. It is passed to ``set_init_dict`` for partial
                loading and read for ``map_classid_to_classname`` in eval mode.
            checkpoint_path: path/URL understood by ``load_fsspec``. The
                checkpoint is always loaded onto the CPU first.
            eval: if True, a model-weight mismatch is re-raised and the model
                is switched to eval mode. If False, mismatches fall back to
                partial initialization.
            use_cuda: move the model (and criterion, if any) to the GPU.
            criterion: optional loss module whose state is restored from
                ``state["criterion"]`` when present. Load errors are printed
                and ignored.

        Returns:
            ``criterion`` when ``eval`` is True, otherwise
            ``(criterion, state["step"])``.

        Raises:
            KeyError / RuntimeError: in eval mode, when ``state["model"]`` is
                missing or does not match this model.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        try:
            self.load_state_dict(state["model"])
        except (KeyError, RuntimeError):
            # If eval raise the error
            if eval:
                raise
            print(" > Partial model initialization.")
            model_dict = self.state_dict()
            # Previously passed the undefined name `c`, which made this
            # fallback path always fail with NameError.
            model_dict = set_init_dict(model_dict, state["model"], config)
            self.load_state_dict(model_dict)
            del model_dict

        # load the criterion for restore_path
        if criterion is not None and "criterion" in state:
            try:
                criterion.load_state_dict(state["criterion"])
            except (KeyError, RuntimeError) as error:
                print(" > Criterion load ignored because of:", error)

        # instance and load the criterion for the encoder classifier in inference time
        if (
            eval
            and criterion is None
            and "criterion" in state
            and getattr(config, "map_classid_to_classname", None) is not None
        ):
            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
            criterion.load_state_dict(state["criterion"])

        if use_cuda:
            self.cuda()
            if criterion is not None:
                criterion = criterion.cuda()

        if eval:
            self.eval()
            assert not self.training
            return criterion
        return criterion, state["step"]