
End of Speech Detection with Wav2Vec 2.0

The End-of-Speech model is based on the open-source Wav2Vec 2.0 model from Meta AI. A convolutional feature encoder turns chunks of raw audio into latent speech representations, and a Transformer captures context across the resulting sequence of representations. This helps the model pick up intonation cues such as pitch declines and final lengthening (and the pause that follows), and therefore detect when an end-of-speech event occurs, much as humans do.
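To make the architecture concrete, the sketch below counts how many latent vectors the convolutional encoder hands to the Transformer for one input chunk. It assumes the default wav2vec 2.0 encoder geometry from the Hugging Face `Wav2Vec2Config` defaults (seven unpadded 1-D convolutions). Whether this checkpoint keeps those defaults is an assumption, and `num_latent_frames` is a hypothetical helper name.

```python
# Default wav2vec 2.0 feature-encoder geometry (Wav2Vec2Config defaults);
# assumed, not confirmed, for this particular checkpoint.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_latent_frames(num_samples: int) -> int:
    """Number of latent vectors the convolutional encoder emits for a raw clip."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # Output length of an unpadded 1-D convolution
        length = (length - kernel) // stride + 1
    return length
```

Under these assumptions, both a 700 ms input (11,200 samples) and a 704 ms input (11,264 samples) yield 34 latent frames, each advancing by 320 samples (20 ms). The Transformer attends over this short sequence to judge whether the speech is ending.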

Training Data

The training data is built from the Common Voice 16.0 English dataset released by the Mozilla Foundation under the permissive CC0 1.0 license.

To train Wav2Vec 2.0 for end-of-speech detection, we needed a sufficiently large dataset containing both end-of-speech and not-end-of-speech samples. No open-source dataset provided such samples, so we constructed one. Each Common Voice sample contains exactly one spoken sentence.

However, the samples contain extra noise or silence at the beginning and end. To remove it and keep only the audio that corresponds to the spoken sentence, we need the sentence timestamps, or better yet, word-level timestamps. We obtain these with whisperX, which tells us when the sentence starts and ends so that everything before and after can be trimmed.
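A minimal sketch of the trimming step follows. It assumes the word-level timestamps are already available as a list of dicts with `"start"` and `"end"` times in seconds, roughly the shape of whisperX's aligned word output. The exact field names, and the helper `trim_to_speech`, are assumptions for illustration rather than the authors' code.

```python
import numpy as np

SAMPLE_RATE = 16_000


def trim_to_speech(audio, words, sample_rate=SAMPLE_RATE):
    """Keep only the span from the first word's start to the last word's end.

    `words` is assumed to be a list of dicts with "start"/"end" times in seconds;
    entries without timing information are skipped.
    """
    timed = [w for w in words if "start" in w and "end" in w]
    if not timed:
        return audio[:0]  # no aligned words: nothing to keep
    start = int(round(min(w["start"] for w in timed) * sample_rate))
    end = int(round(max(w["end"] for w in timed) * sample_rate))
    # Clamp to the clip boundaries before slicing
    return audio[max(start, 0):min(end, len(audio))]
```

Taking the minimum start and maximum end over all words, rather than trusting list order, keeps the sketch robust if the aligner returns words slightly out of order.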

After cleaning, we checked a random subset of samples to validate the procedure. We then labeled the last 700/704 ms of each sample as end of speech and everything before it as not end of speech.
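The labeling rule can be sketched as follows. The final window of a trimmed utterance becomes the `eos` example. The model card does not say how the earlier audio is cut, so this sketch *assumes* it is tiled backwards into non-overlapping `not_eos` windows of the same length. `label_windows` is a hypothetical name.

```python
import numpy as np

WINDOW = 11_264  # 704 ms at 16 kHz (11 x 64 ms); use 11_200 for 700 ms

EOS, NOT_EOS = 0, 1  # class ids as listed under "Output"


def label_windows(speech, window=WINDOW):
    """Split a trimmed utterance into (segment, label) pairs.

    The last `window` samples form the single eos example; the audio before it
    is tiled backwards into non-overlapping not_eos windows, and a partial
    window at the very start is dropped (an assumption of this sketch).
    """
    n = len(speech)
    if n < window:
        return []  # too short to yield a full window
    examples = [(speech[n - window:], EOS)]
    end = n - window
    while end - window >= 0:
        examples.append((speech[end - window:end], NOT_EOS))
        end -= window
    return examples
```

Tiling backwards from the end keeps every `not_eos` window aligned to the same grid as the `eos` window, so no sample is labeled twice.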

Finally, we added overlapping segments to the dataset by shifting the 700/704 ms window in both directions.
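One way to realize "shifting the window in both directions" is sketched below. Starting from an anchor position, it slides the window left and right in fixed steps and keeps only shifts that stay inside the clip. The 64 ms step, the number of steps, and the helper `shifted_windows` are hypothetical choices. The model card does not state how shifted windows are labeled, so the sketch only produces the segments and leaves labeling to the caller.

```python
import numpy as np

WINDOW = 11_264  # 704 ms at 16 kHz
HOP = 1_024      # 64 ms shift per step; hypothetical choice


def shifted_windows(audio, anchor_start, max_steps=2, window=WINDOW, hop=HOP):
    """Return (start, segment) pairs for a window slid left and right of anchor_start.

    Only shifts that keep the whole window inside `audio` are returned, so
    neighbouring segments overlap by `window - hop` samples.
    """
    out = []
    for step in range(-max_steps, max_steps + 1):
        start = anchor_start + step * hop
        if 0 <= start and start + window <= len(audio):
            out.append((start, audio[start:start + window]))
    return out
```

With a 64 ms hop, consecutive segments share 640 ms of audio, which multiplies the number of training examples per utterance without recording new data.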

Input

The model is trained on raw-audio inputs of 700 ms and 704 ms (11 × 64 ms), sampled at 16 kHz, i.e. 11,200 and 11,264 samples respectively. During experiments we also tested 300 ms, 500 ms, and 1 s inputs; 700/704 ms proved the best trade-off between sufficient accuracy and the shortest possible chunk.
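For live audio, one way to produce 704 ms inputs is to receive the stream in 64 ms frames (1,024 samples at 16 kHz) and keep a rolling buffer of the latest 11 frames. This streaming layout and the `RollingWindow` class are hypothetical. The model card only specifies the window length and the sample rate.

```python
import numpy as np

FRAME = 1_024       # 64 ms at 16 kHz
WINDOW_FRAMES = 11  # 11 x 64 ms = 704 ms = 11,264 samples


class RollingWindow:
    """Keep the most recent 704 ms of a 16 kHz stream, fed in 64 ms frames."""

    def __init__(self):
        self.buffer = np.zeros(0, dtype=np.float32)

    def push(self, frame):
        """Append one frame; return the current window once it is full, else None."""
        self.buffer = np.concatenate([self.buffer, np.asarray(frame, dtype=np.float32)])
        size = FRAME * WINDOW_FRAMES
        self.buffer = self.buffer[-size:]  # drop samples older than 704 ms
        return self.buffer.copy() if len(self.buffer) == size else None
```

After the first 11 frames, every new frame yields a fresh 704 ms window, so the detector can be queried every 64 ms.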

Output

The model classifies each audio input into one of two classes: eos (id 0) and not_eos (id 1).
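The classifier emits one logit per class. The sketch below turns the two logits into probabilities with a numerically stable softmax and makes a binary decision. The 0.5 threshold and the helper names are hypothetical. The usage code below this section returns both probabilities and leaves the decision rule to the caller.

```python
import numpy as np

ID2LABEL = {0: "eos", 1: "not_eos"}  # class ids as listed above


def logits_to_probs(logits):
    """Numerically stable softmax over the two class logits."""
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max())  # subtracting the max avoids overflow
    return {ID2LABEL[i]: float(p) for i, p in enumerate(e / e.sum())}


def is_end_of_speech(probs, threshold=0.5):
    """Flag end of speech when P(eos) reaches the (tunable) threshold."""
    return probs["eos"] >= threshold
```

Raising the threshold trades fewer false end-of-speech triggers for slightly later detection.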

Usage

import os

import numpy as np
import onnxruntime as rt
import torch
import torch.nn.functional as F
import torchaudio
from transformers import AutoConfig, Wav2Vec2Processor

SAMPLE_RATE = 16000  # the model was trained on 16 kHz audio


class EndOfSpeechDetection:
    processor: Wav2Vec2Processor
    config: AutoConfig
    session: rt.InferenceSession

    def load_model(self, path, use_gpu=False):
        # `path` is a local directory holding model.onnx plus the processor/config files
        processor = Wav2Vec2Processor.from_pretrained(path)
        config = AutoConfig.from_pretrained(path)

        sess_options = rt.SessionOptions()
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

        # ROCm targets AMD GPUs (NVIDIA GPUs use "CUDAExecutionProvider");
        # the CPU provider is kept as a fallback
        providers = (
            ["ROCMExecutionProvider", "CPUExecutionProvider"]
            if use_gpu
            else ["CPUExecutionProvider"]
        )
        session = rt.InferenceSession(
            os.path.join(path, "model.onnx"), sess_options, providers=providers
        )
        return processor, config, session

    def predict(self, segment, file_type="pcm"):
        if file_type == "pcm":
            # Headerless PCM: mono float32 samples at 16 kHz
            speech_array = np.fromfile(segment, dtype=np.float32)
        else:
            # WAV files: resample to 16 kHz if needed and keep the first channel
            waveform, sample_rate = torchaudio.load(segment)
            if sample_rate != SAMPLE_RATE:
                waveform = torchaudio.functional.resample(
                    waveform, sample_rate, SAMPLE_RATE
                )
            speech_array = waveform[0].numpy()

        features = self.processor(
            speech_array, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True
        )
        input_values = features.input_values.detach().cpu().numpy()

        # Run the ONNX graph; the output holds one logit per class, shape (1, num_labels)
        logits = self.session.run(
            [self.session.get_outputs()[-1].name],
            {self.session.get_inputs()[-1].name: input_values},
        )[0]
        probs = F.softmax(torch.tensor(logits), dim=1)[0]

        # Map class ids (0: eos, 1: not_eos) to their probabilities
        return {self.config.id2label[i]: probs[i].item() for i in range(len(probs))}


if __name__ == "__main__":
    eos = EndOfSpeechDetection()
    eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")
    print(eos.predict("some.pcm", file_type="pcm"))
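The `pcm` branch of `predict` reads headerless 32-bit float samples, whereas many recorders produce 16-bit integer PCM. The sketch below converts int16 samples to float32 in [-1, 1) and writes them as a raw little-endian `.pcm` file. The helper name `write_float32_pcm` and the divide-by-32768 scaling are assumptions following the common audio convention, not something the model card specifies.

```python
import numpy as np


def write_float32_pcm(samples_int16, path):
    """Write 16-bit integer samples as headerless little-endian float32 PCM in [-1, 1)."""
    audio = np.asarray(samples_int16, dtype=np.int16).astype(np.float32) / 32768.0
    audio.astype("<f4").tofile(path)  # raw samples, no header
    return audio
```

The input should already be mono and sampled at 16 kHz, since the `pcm` branch does not resample.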

Latency and Memory Optimization

  • Knowledge distillation
  • ONNX-format weights
    • The weights are converted to the ONNX format to optimize CPU and GPU performance.
    • Tested on an AMD Instinct MI100 GPU: under 10 ms of inference per 704 ms audio chunk.
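To check a latency figure like this on your own hardware, a simple benchmark can time repeated calls on a 704 ms chunk and report the median latency and real-time factor (RTF, latency divided by audio duration). In this sketch `infer` is any callable you supply, for example a wrapper around `session.run`. The function name and the warm-up/run counts are hypothetical.

```python
import statistics
import time

import numpy as np

CHUNK_MS = 704
CHUNK_SAMPLES = CHUNK_MS * 16  # 11,264 samples at 16 kHz


def benchmark(infer, runs=50, warmup=5):
    """Time infer(chunk) on a random 704 ms chunk; report median latency and RTF."""
    chunk = np.random.default_rng(0).standard_normal(CHUNK_SAMPLES).astype(np.float32)
    for _ in range(warmup):  # let lazy initialisation and caches settle
        infer(chunk)
    timings = []
    for _ in range(runs):
        t0 = time.perf_counter()
        infer(chunk)
        timings.append((time.perf_counter() - t0) * 1000.0)
    median_ms = statistics.median(timings)
    return {"median_ms": median_ms, "rtf": median_ms / CHUNK_MS}
```

A median of 10 ms per 704 ms chunk corresponds to an RTF of about 0.014, i.e. roughly 70 times faster than real time.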

Evaluation

Accuracy: 0.95 on 8,120 test samples.

| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| eos | 0.94 | 0.95 | 0.95 | 4060 |
| not_eos | 0.95 | 0.94 | 0.95 | 4060 |
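The per-class metrics in the table can be recomputed from predictions with the sketch below (label ids 0 = eos, 1 = not_eos). Because both classes have the same support, accuracy equals the mean of the two recalls: (0.95 + 0.94) / 2 = 0.945, consistent with the reported 0.95 after rounding. The function `per_class_report` is a hypothetical helper, and the example inputs in any test are toy data, not the evaluation set.

```python
import numpy as np


def per_class_report(y_true, y_pred, labels=(0, 1)):
    """Precision, recall, F1 and support per class, plus overall accuracy."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    report = {}
    for c in labels:
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = {"precision": float(precision), "recall": float(recall),
                     "f1": float(f1), "support": int(np.sum(y_true == c))}
    report["accuracy"] = float(np.mean(y_true == y_pred))
    return report
```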
Model repository: telnyx/wav2vec2-end-of-speech-detection