import torch
import speechbrain as sb


class FeatureScaler(torch.nn.Module):
    """Multiply every input feature by a fixed scalar factor.

    Arguments
    ---------
    num_in : int
        Size of the trailing (feature) dimension the scaler broadcasts over.
    scale : float
        Constant factor applied to every feature.
    """

    def __init__(self, num_in, scale):
        super().__init__()
        # Registered as a buffer (not a plain attribute) so that
        # ``.to(device)`` / ``.cuda()`` / dtype casts move it together with
        # the module; a plain tensor attribute would stay on the CPU and
        # cause a device mismatch in ``forward``. ``persistent=False`` keeps
        # it out of ``state_dict``, exactly as before, so existing
        # checkpoints still load.
        self.register_buffer(
            "scaler", torch.ones((num_in,)) * scale, persistent=False
        )

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the scale vector.

        The scale vector has shape ``(num_in,)`` and broadcasts over the
        last dimension of ``x``.
        """
        return x * self.scaler


class CustomInterface(sb.pretrained.interfaces.Pretrained):
    """Pretrained interface that extracts, normalizes and scales features.

    The loaded hparams must provide ``feature_extractor``. The loaded
    modules must provide ``normalizer`` and ``feature_scaler``.
    """

    # Both modules are used by ``feats_from_audio``, so both are declared.
    MODULES_NEEDED = ["normalizer", "feature_scaler"]
    HPARAMS_NEEDED = ["feature_extractor"]

    def feats_from_audio(self, audio, lengths=None):
        """Compute normalized, scaled features for a batch of waveforms.

        Arguments
        ---------
        audio : torch.Tensor
            Waveform batch. Presumably shaped ``(batch, time)``; TODO confirm
            against the configured ``feature_extractor``.
        lengths : torch.Tensor, optional
            Relative lengths passed to ``normalizer``. If omitted, every
            batch entry is treated as full length (``1.0``). For a batch of
            one this equals the previous default ``torch.tensor([1.0])``.

        Returns
        -------
        torch.Tensor
            Output of ``feature_scaler`` applied to the normalized features.
        """
        if lengths is None:
            # Build a fresh tensor per call (no shared default object), with
            # one entry per batch row, on the same device as the input.
            lengths = torch.ones(audio.shape[0], device=audio.device)
        feats = self.hparams.feature_extractor(audio)
        normalized = self.mods.normalizer(feats, lengths)
        return self.mods.feature_scaler(normalized)

    def feats_from_file(self, path):
        """Load the audio file at ``path`` and return its scaled features.

        The loaded waveform gets a leading batch dimension of size one for
        ``feats_from_audio``. That dimension is removed from the result.
        """
        audio = self.load_audio(path)
        return self.feats_from_audio(audio.unsqueeze(0)).squeeze(0)