"""Custom interface for AISHELL-1 CTC inference.

An external tokenizer is used, so some special tokens
need to be handled explicitly during decoding.

Authors
 * Yingzhi Wang 2022
"""
|
|
|
import torch |
|
from speechbrain.inference.interfaces import Pretrained |
|
|
|
|
|
class CustomEncoderDecoderASR(Pretrained):
    """A ready-to-use CTC ASR interface that decodes with an external tokenizer.

    The class can be used either to run only the acoustic stack
    (``encode_batch()``) or to run the full pipeline (``transcribe_batch()`` /
    ``transcribe_file()``), which additionally applies log-softmax, the
    configured decoder, and the external tokenizer.

    The given YAML must contain the fields listed in ``HPARAMS_NEEDED`` and
    ``MODULES_NEEDED`` below.

    Example
    -------
    Illustrative only (placeholders, deliberately not a runnable doctest)::

        asr_model = CustomEncoderDecoderASR.from_hparams(
            source="<path-or-hub-id-of-the-model>",
            savedir="<local-cache-dir>",
        )
        tokens = asr_model.transcribe_file("<path/to/audio-file>")
    """

    # Names this class dereferences on ``self.hparams`` / ``self.mods``.
    # The ``Pretrained`` base class is expected to check these lists when the
    # model is built, so a misconfigured YAML fails early instead of in the
    # middle of inference.
    HPARAMS_NEEDED = ["tokenizer", "log_softmax", "decoder"]
    MODULES_NEEDED = ["wav2vec2", "enc", "ctc_lin"]

    # Special tokens of the external tokenizer.
    # Tokens in _SKIPPED_TOKENS are dropped; the first token found in
    # _STOP_TOKENS ends the transcription (it and everything after it are
    # discarded).
    _SKIPPED_TOKENS = ("[CLS]",)
    _STOP_TOKENS = ("[SEP]", "[PAD]")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The tokenizer comes from the hyperparameters (external tokenizer).
        self.tokenizer = self.hparams.tokenizer

    def transcribe_file(self, path):
        """Transcribes the given audio file into a sequence of tokens.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe.

        Returns
        -------
        list
            The tokens predicted for the audio file, with special tokens
            removed (see ``transcribe_batch``).
        """
        waveform = self.load_audio(path)

        # Fake a batch of one item; the only item has relative length 1.0.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words = self.transcribe_batch(batch, rel_length)
        return predicted_words[0]

    def encode_batch(self, wavs):
        """Encodes the input audio into the CTC output layer activations.

        The waveforms should already be in the model's desired format.
        Relative lengths are not used here: the whole (padded) batch is
        passed through ``wav2vec2``, ``enc`` and ``ctc_lin`` in that order.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.

        Returns
        -------
        torch.tensor
            The output of the ``ctc_lin`` module for the batch
            (log-softmax is NOT applied here).
        """
        wavs = wavs.float()
        wavs = wavs.to(self.device)
        outputs = self.mods.wav2vec2(wavs)
        outputs = self.mods.enc(outputs)
        outputs = self.mods.ctc_lin(outputs)
        return outputs

    def _strip_special_tokens(self, tokens):
        """Returns ``tokens`` up to the first stop token, without skipped tokens."""
        kept = []
        for token in tokens:
            if token in self._SKIPPED_TOKENS:
                continue
            if token in self._STOP_TOKENS:
                break
            kept.append(token)
        return kept

    def transcribe_batch(self, wavs, wav_lens):
        """Transcribes the input audio into sequences of tokens.

        The waveforms should already be in the model's desired format.
        The whole method runs under ``torch.no_grad()``.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Passed to the decoder (used for ignoring padding).

        Returns
        -------
        list
            One list of tokens per waveform in the batch. ``[CLS]`` tokens
            are dropped and each sequence is cut at the first ``[SEP]`` or
            ``[PAD]`` token.
        """
        with torch.no_grad():
            wav_lens = wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs)
            p_ctc = self.hparams.log_softmax(encoder_out)
            sequences = self.hparams.decoder(p_ctc, wav_lens)

            # Map each decoded id sequence back to tokens, then remove the
            # tokenizer's special tokens.
            predicted_words_list = [
                self._strip_special_tokens(
                    self.tokenizer.convert_ids_to_tokens(sequence)
                )
                for sequence in sequences
            ]

        return predicted_words_list

    def forward(self, wavs, wav_lens):
        """Runs full transcription - note: no gradients through decoding"""
        return self.transcribe_batch(wavs, wav_lens)
|
|