|
import numpy as np |
|
import torch |
|
import torchaudio |
|
from coqpit import Coqpit |
|
from torch import nn |
|
|
|
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss |
|
from TTS.utils.generic_utils import set_init_dict |
|
from TTS.utils.io import load_fsspec |
|
|
|
|
|
class PreEmphasis(nn.Module):
    """First-order pre-emphasis filter for a batch of waveforms.

    For t >= 1 the output is ``y[t] = x[t] - coefficient * x[t - 1]``. The input
    is reflect-padded by one sample on the left, so the output keeps the input
    length and ``y[0] = x[0] - coefficient * x[1]``.
    """

    def __init__(self, coefficient=0.97):
        super().__init__()
        self.coefficient = coefficient
        # conv1d weight laid out as (out_channels=1, in_channels=1, kernel_size=2).
        kernel = torch.FloatTensor([-self.coefficient, 1.0])
        self.register_buffer("filter", kernel.view(1, 1, 2))

    def forward(self, x):
        """Filter ``x``, which must be a 2-D tensor (dim 1 is filtered)."""
        assert x.dim() == 2
        # Add a channel axis for conv1d, pad one reflected sample in front.
        padded = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
        emphasized = torch.nn.functional.conv1d(padded, self.filter)
        return emphasized.squeeze(1)
|
|
|
|
|
class BaseEncoder(nn.Module):
    """Base `encoder` class. Every new `encoder` model must inherit this.

    It defines common `encoder` specific functions: the torch mel-spectrogram
    front-end, no-grad inference, windowed embedding computation, loss
    construction and checkpoint loading.

    NOTE(review): methods below read ``self.use_torch_spec`` and
    ``self.audio_config`` and call ``self.forward``; none of these are defined
    in this class, so subclasses presumably must provide them.
    """

    def __init__(self):
        super().__init__()

    def get_torch_mel_spectrogram_class(self, audio_config):
        """Build the waveform -> mel-spectrogram front-end.

        Despite the name, this returns a module instance: an ``nn.Sequential``
        of ``PreEmphasis`` followed by ``torchaudio.transforms.MelSpectrogram``
        using a Hamming window.

        Args:
            audio_config: mapping providing ``preemphasis``, ``sample_rate``,
                ``fft_size``, ``win_length``, ``hop_length`` and ``num_mels``.

        Returns:
            torch.nn.Sequential: the feature-extraction module.
        """
        return torch.nn.Sequential(
            PreEmphasis(audio_config["preemphasis"]),
            torchaudio.transforms.MelSpectrogram(
                sample_rate=audio_config["sample_rate"],
                n_fft=audio_config["fft_size"],
                win_length=audio_config["win_length"],
                hop_length=audio_config["hop_length"],
                window_fn=torch.hamming_window,
                n_mels=audio_config["num_mels"],
            ),
        )

    @torch.no_grad()
    def inference(self, x, l2_norm=True):
        """Return ``self.forward(x, l2_norm)`` computed with gradients disabled."""
        return self.forward(x, l2_norm)

    @torch.no_grad()
    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
        """Embed ``num_eval`` evenly spaced windows of ``x``.

        Windows of ``num_frames`` steps are sliced along dim 1 of ``x``. Their
        start offsets are spread uniformly from 0 to ``x.shape[1] - num_frames``.
        The windows are concatenated along dim 0 and embedded in a single
        ``inference`` call.

        Args:
            x: input utterance, documented as ``1xTxD``. NOTE(review): when
                ``self.use_torch_spec`` is set, ``num_frames`` is scaled by
                ``hop_length``, which suggests ``x`` is then raw audio of shape
                ``(1, T)`` -- confirm with callers.
            num_frames: window length in frames, clamped to ``x.shape[1]``.
            num_eval: number of windows to embed.
            return_mean: if True, average the window embeddings over dim 0
                (keeping that dim).
            l2_norm: forwarded to ``inference``.

        Returns:
            Tensor: per-window embeddings, or their mean of shape ``(1, ...)``.
        """
        if self.use_torch_spec:
            # Convert the window length from spectrogram frames to samples.
            num_frames = num_frames * self.audio_config["hop_length"]

        max_len = x.shape[1]
        # Inputs shorter than one window: every window spans the whole input.
        num_frames = min(num_frames, max_len)

        offsets = np.linspace(0, max_len - num_frames, num=num_eval)
        starts = (int(offset) for offset in offsets)
        frames_batch = torch.cat([x[:, start : int(start + num_frames)] for start in starts], dim=0)

        embeddings = self.inference(frames_batch, l2_norm=l2_norm)

        if return_mean:
            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
        return embeddings

    def get_criterion(self, c: Coqpit, num_classes=None):
        """Instantiate the loss selected by ``c.loss``.

        Args:
            c: config. ``c.loss`` must be ``"ge2e"``, ``"angleproto"`` or
                ``"softmaxproto"``. For ``"softmaxproto"``,
                ``c.model_params["proj_dim"]`` is also read.
            num_classes: class count, used only by ``"softmaxproto"``.

        Returns:
            The loss module.

        Raises:
            ValueError: if ``c.loss`` is not supported. This subclasses
                ``Exception``, so existing ``except Exception`` handlers still
                catch it.
        """
        if c.loss == "ge2e":
            return GE2ELoss(loss_method="softmax")
        if c.loss == "angleproto":
            return AngleProtoLoss()
        if c.loss == "softmaxproto":
            return SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
        raise ValueError("The %s not is a loss supported" % c.loss)

    def load_checkpoint(
        self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None
    ):
        """Restore model (and optionally criterion) weights from a checkpoint.

        Args:
            config: model config. It is passed to ``set_init_dict`` on partial
                loads. In eval mode it is also used to rebuild the criterion
                from ``map_classid_to_classname``.
            checkpoint_path: location handed to ``load_fsspec``; tensors are
                mapped to CPU.
            eval: if True, a checkpoint that does not load strictly raises
                instead of being partially loaded, and the model is switched to
                eval mode.
            use_cuda: move the model, and the criterion if any, to GPU.
            criterion: optional loss module restored from ``state["criterion"]``
                when present. Load errors are printed and ignored.

        Returns:
            ``criterion`` when ``eval`` is True, otherwise
            ``(criterion, state["step"])``.

        Raises:
            KeyError, RuntimeError: when ``eval`` is True and
                ``state["model"]`` cannot be loaded strictly.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        try:
            self.load_state_dict(state["model"])
        except (KeyError, RuntimeError):
            # Evaluation requires an exact match; never silently partial-load.
            if eval:
                raise

            print(" > Partial model initialization.")
            model_dict = self.state_dict()
            # Must be the ``config`` argument: no name ``c`` exists in this
            # scope, so passing ``c`` raised NameError on this path.
            model_dict = set_init_dict(model_dict, state["model"], config)
            self.load_state_dict(model_dict)
            del model_dict

        # Best-effort restore of a caller-provided criterion.
        if criterion is not None and "criterion" in state:
            try:
                criterion.load_state_dict(state["criterion"])
            except (KeyError, RuntimeError) as error:
                print(" > Criterion load ignored because of:", error)

        # In eval mode with no criterion given, rebuild one sized to the class
        # map stored in the config and restore its saved state.
        if (
            eval
            and criterion is None
            and "criterion" in state
            and getattr(config, "map_classid_to_classname", None) is not None
        ):
            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
            criterion.load_state_dict(state["criterion"])

        if use_cuda:
            self.cuda()
            if criterion is not None:
                criterion = criterion.cuda()

        if eval:
            self.eval()
            assert not self.training
            return criterion

        return criterion, state["step"]
|
|