File size: 793 Bytes
8b70882 fa827db b57ac94 fa827db 0f77019 fa827db 85046ed 8b70882 fa827db 8b70882 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import speechbrain as sb
class FeatureScaler(torch.nn.Module):
    """Multiply input features element-wise by a constant scale vector.

    Arguments
    ---------
    num_in : int
        Length of the scale vector; it is broadcast against the last
        dimension of the input in ``forward``.
    scale : float
        Value every entry of the scale vector is set to.
    """

    def __init__(self, num_in, scale):
        super().__init__()
        # Registered as a buffer (not a plain attribute) so that
        # ``.to(device)`` / ``.cuda()`` / ``.double()`` move it together with
        # the module. ``persistent=False`` keeps it out of ``state_dict`` so
        # checkpoints saved before this change still load unchanged.
        self.register_buffer(
            "scaler", torch.ones((num_in,)) * scale, persistent=False
        )

    def forward(self, x):
        """Return ``x * scaler``, broadcasting over the last dimension of ``x``."""
        return x * self.scaler
class CustomInterface(sb.pretrained.interfaces.Pretrained):
    """Pretrained interface producing normalized, scaled features from audio.

    Pipeline: ``hparams.feature_extractor`` -> ``mods.normalizer`` ->
    ``mods.feature_scaler``.
    """

    # NOTE(review): ``feats_from_audio`` also uses ``mods.feature_scaler``,
    # which is not listed here and so is not checked up front — confirm
    # whether it should be added.
    MODULES_NEEDED = ["normalizer"]
    HPARAMS_NEEDED = ["feature_extractor"]

    def feats_from_audio(self, audio, lengths=None):
        """Compute features for a batch of audio, normalize and scale them.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of waveforms with the batch on dimension 0 (presumably
            shaped (batch, time) — TODO confirm against feature_extractor).
        lengths : torch.Tensor, optional
            Passed to the normalizer together with the features, presumably
            relative lengths (1.0 = full length). When omitted, a tensor of
            ones with one entry per batch item is used.

        Returns
        -------
        torch.Tensor
            Output of ``mods.feature_scaler`` on the normalized features.
        """
        if lengths is None:
            # Built per call instead of as a default-argument tensor: a
            # default would be created once, shared by every call, always
            # sized for a single item and always on CPU.
            lengths = torch.ones(audio.shape[0], device=audio.device)
        feats = self.hparams.feature_extractor(audio)
        normalized = self.mods.normalizer(feats, lengths)
        return self.mods.feature_scaler(normalized)

    def feats_from_file(self, path):
        """Load the audio at ``path`` and return its features without a batch dim.

        Arguments
        ---------
        path : str
            Audio file location, forwarded to ``self.load_audio``.

        Returns
        -------
        torch.Tensor
            ``feats_from_audio`` output for a batch of one, with the batch
            dimension removed.
        """
        audio = self.load_audio(path)
        # Add a batch dimension for feats_from_audio, then drop it again.
        return self.feats_from_audio(audio.unsqueeze(0)).squeeze(0)
|