aheba31 committed on
Commit e43d0df
1 Parent(s): d32c845
Files changed (1)
  1. inference.py +153 -1
inference.py CHANGED
@@ -1,4 +1,156 @@
 import torch
 from speechbrain.pretrained import Pretrained
 
-class EncoderClassifier(Pretrained):
+class EncoderClassifier(Pretrained):
+    """A ready-to-use class for utterance-level classification (e.g., speaker-id,
+    language-id, emotion recognition, keyword spotting, etc.).
+
+    The class assumes that an encoder called "embedding_model" and a model
+    called "classifier" are defined in the yaml file. If you want to
+    convert the predicted index into a corresponding text label, please
+    provide the path of the label_encoder in a variable called 'lab_encoder_file'
+    within the yaml.
+
+    The class can be used either to run only the encoder (encode_batch()) to
+    extract embeddings or to run a classification step (classify_batch()).
+
+    Example
+    -------
+    >>> import torchaudio
+    >>> from speechbrain.pretrained import EncoderClassifier
+    >>> # Model is downloaded from the speechbrain HuggingFace repo
+    >>> tmpdir = getfixture("tmpdir")
+    >>> classifier = EncoderClassifier.from_hparams(
+    ...     source="speechbrain/spkrec-ecapa-voxceleb",
+    ...     savedir=tmpdir,
+    ... )
+
+    >>> # Compute embeddings
+    >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
+    >>> embeddings = classifier.encode_batch(signal)
+
+    >>> # Classification
+    >>> prediction = classifier.classify_batch(signal)
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def encode_batch(self, wavs, wav_lens=None, normalize=False):
+        """Encodes the input audio into a single vector embedding.
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = <this>.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model. Make sure the sample rate is fs=16000 Hz.
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+        normalize : bool
+            If True, it normalizes the embeddings with the statistics
+            contained in mean_var_norm_emb.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        # Manage single waveforms in input
+        if len(wavs.shape) == 1:
+            wavs = wavs.unsqueeze(0)
+
+        # Assign full length if wav_lens is not assigned
+        if wav_lens is None:
+            wav_lens = torch.ones(wavs.shape[0], device=self.device)
+
+        # Storing waveform in the specified device
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        wavs = wavs.float()
+
+        # Computing features and embeddings
+        feats = self.mods.compute_features(wavs)
+        feats = self.mods.mean_var_norm(feats, wav_lens)
+        embeddings = self.mods.embedding_model(feats, wav_lens)
+        if normalize:
+            embeddings = self.hparams.mean_var_norm_emb(
+                embeddings, torch.ones(embeddings.shape[0], device=self.device)
+            )
+        return embeddings
+
+    def classify_batch(self, wavs, wav_lens=None):
+        """Performs classification on top of the encoded features.
+
+        It returns the posterior probabilities, the index, and, if the label
+        encoder is specified, the text label as well.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels] or [batch, time]
+            depending on the model. Make sure the sample rate is fs=16000 Hz.
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        out_prob
+            The log posterior probabilities of each class ([batch, N_class])
+        score
+            The value of the log-posterior for the best class ([batch,])
+        index
+            The indexes of the best class ([batch,])
+        text_lab
+            List with the text labels corresponding to the indexes
+            (a label encoder should be provided).
+        """
+        emb = self.encode_batch(wavs, wav_lens)
+        out_prob = self.mods.classifier(emb).squeeze(1)
+        score, index = torch.max(out_prob, dim=-1)
+        text_lab = self.hparams.label_encoder.decode_torch(index)
+        return out_prob, score, index, text_lab
+
+    def classify_file(self, path):
+        """Classifies the given audio file into the given set of labels.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file to classify.
+
+        Returns
+        -------
+        out_prob
+            The log posterior probabilities of each class ([batch, N_class])
+        score
+            The value of the log-posterior for the best class ([batch,])
+        index
+            The indexes of the best class ([batch,])
+        text_lab
+            List with the text labels corresponding to the indexes
+            (a label encoder should be provided).
+        """
+        waveform = self.load_audio(path)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        emb = self.encode_batch(batch, rel_length)
+        out_prob = self.mods.classifier(emb).squeeze(1)
+        score, index = torch.max(out_prob, dim=-1)
+        text_lab = self.hparams.label_encoder.decode_torch(index)
+        return out_prob, score, index, text_lab
+
+    def forward(self, wavs, wav_lens=None):
+        """Runs the classification"""
+        return self.classify_batch(wavs, wav_lens)
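
The docstrings above describe `wav_lens` as lengths relative to the longest waveform in the batch. Below is a minimal sketch (not part of the commit) of building such a padded batch and its relative lengths before calling `encode_batch()`; the file names `utt1.wav`/`utt2.wav` and the `pretrained_model` save directory are placeholders, and the model source is the one used in the docstring example.

```python
import torch
import torchaudio
from speechbrain.pretrained import EncoderClassifier

# Placeholder source/savedir, taken from the docstring example above.
classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    savedir="pretrained_model",
)

# Hypothetical 16 kHz mono recordings of different lengths.
sig1, fs = torchaudio.load("utt1.wav")
sig2, fs = torchaudio.load("utt2.wav")
sig1, sig2 = sig1.squeeze(0), sig2.squeeze(0)

# Zero-pad to the longest waveform; each entry of wav_lens is
# len(waveform) / max_length, so the longest one gets 1.0.
max_len = max(sig1.shape[0], sig2.shape[0])
batch = torch.zeros(2, max_len)
batch[0, : sig1.shape[0]] = sig1
batch[1, : sig2.shape[0]] = sig2
wav_lens = torch.tensor([sig1.shape[0], sig2.shape[0]]) / max_len

embeddings = classifier.encode_batch(batch, wav_lens)
```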
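For end-to-end use, `classify_file()` wraps loading, batching, and classification for a single file. A hedged usage sketch, reusing `classifier` from above; it assumes the loaded hyperparameters also define a `classifier` module and a `label_encoder`, which an embedding-only model such as the spkrec one may not provide.

```python
# "utt1.wav" is a placeholder path; the four outputs mirror classify_batch().
out_prob, score, index, text_lab = classifier.classify_file("utt1.wav")
print(index.item(), text_lab[0])  # best class index and its text label
```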