import os
import sys
import argparse

import torch
import soundfile as sf
from halo import Halo
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor

# Make the project root importable when this script is run directly.
path_this_file = os.path.dirname(os.path.abspath(__file__))
path_project_root = os.path.join(path_this_file, "..")
sys.path.append(path_project_root)


class Wav2Vec2Inference:
    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
        """
        Initializes the class with the provided parameters.

        Args:
            model_name (str): The name of the model to be used.
            hotwords (list, optional): A list of hotwords. Defaults to an empty list.
            use_lm_if_possible (bool, optional): Whether to use a language model
                if the checkpoint ships one. Defaults to True.
            use_gpu (bool, optional): Whether to use the GPU. Defaults to True.

        Returns:
            None
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        if use_lm_if_possible:
            # AutoProcessor resolves to Wav2Vec2ProcessorWithLM when the checkpoint
            # includes a language model, otherwise to a plain processor.
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        self.hotwords = hotwords if hotwords is not None else []
        self.use_lm_if_possible = use_lm_if_possible

    def buffer_to_text(self, audio_buffer):
        """
        Transcribes the given audio buffer into text.

        Args:
            audio_buffer (np.ndarray): A 1-D array of 16 kHz audio samples.

        Returns:
            tuple: A tuple containing the transcribed text (str) and the
                confidence score (float).
        """
        spinner = Halo(text="Transcribing audio...", spinner="dots")
        spinner.start()
        try:
            if len(audio_buffer) == 0:
                spinner.warn("Empty audio buffer, nothing to transcribe.")
                return "", 0.0

            inputs = self.processor(torch.tensor(audio_buffer),
                                    sampling_rate=16_000,
                                    return_tensors="pt",
                                    padding=True)

            with torch.no_grad():
                logits = self.model(inputs.input_values.to(self.device),
                                    attention_mask=inputs.attention_mask.to(self.device)).logits

            if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
                # Beam-search decoding through the bundled language model.
                transcription = self.processor.decode(logits[0].cpu().numpy(),
                                                      hotwords=self.hotwords,
                                                      output_word_offsets=True)
                # lm_score is a plain float; average it over the number of words.
                confidence = transcription.lm_score / len(transcription.text.split(" "))
                transcription = transcription.text
            else:
                # Greedy CTC decoding: pick the most likely token per frame.
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = self.processor.batch_decode(predicted_ids)[0]
                confidence = self.confidence_score(logits, predicted_ids).item()

            spinner.succeed("Audio transcribed successfully!")
            return transcription, float(confidence)
        except Exception as e:
            spinner.fail(f"Error during transcription: {str(e)}")
            return "", 0.0

    def confidence_score(self, logits, predicted_ids):
        """
        Calculate the confidence score for the predicted IDs based on the logits.

        Parameters:
            logits (torch.Tensor): The logits tensor.
            predicted_ids (torch.Tensor): The predicted IDs tensor.

        Returns:
            torch.Tensor: The average confidence score over the predicted characters.
        """
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        # Ignore padding and word-delimiter tokens when averaging.
        mask = torch.logical_and(
            predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(self.processor.tokenizer.pad_token_id))

        character_scores = pred_scores.masked_select(mask)
        total_average = torch.sum(character_scores) / len(character_scores)
        return total_average

    def file_to_text(self, filename):
        """
        Reads an audio file and converts it to text using the buffer_to_text method.

        Parameters:
            filename (str): The path to the audio file.

        Returns:
            tuple: A tuple containing the transcription (str) and the confidence
                (float) of the transcription. If there is an error reading the
                audio file, an empty string and a confidence of 0.0 are returned.
        """
        spinner = Halo(text="Reading audio file...", spinner="dots")
        spinner.start()
        try:
            audio_input, samplerate = sf.read(filename)
            assert samplerate == 16000, "The model expects 16 kHz audio."
            spinner.succeed("File read successfully!")
        except Exception as e:
            spinner.fail(f"Error reading audio file: {str(e)}")
            return "", 0.0
        return self.buffer_to_text(audio_input)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str,
                        default="arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id")
    parser.add_argument("--filename", type=str, default="assets/halo.wav")
    args = parser.parse_args()

    with Halo(text="Initializing Wav2Vec2 Inference...", spinner="dots") as init_spinner:
        try:
            asr = Wav2Vec2Inference(args.model_name)
            init_spinner.succeed("Wav2Vec2 Inference initialized successfully!")
        except Exception as e:
            init_spinner.fail(f"Error initializing Wav2Vec2 Inference: {str(e)}")
            sys.exit(1)

    # file_to_text manages its own spinners, so no extra spinner is needed here.
    transcription, confidence = asr.file_to_text(args.filename)
    print("\033[94mTranscription:\033[0m", transcription)
    print("\033[94mConfidence:\033[0m", confidence)
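
# Usage sketch (the script name "wav2vec2_inference.py" is illustrative; use
# whatever this file is saved as, and point --filename at a 16 kHz mono WAV):
#
#   python wav2vec2_inference.py \
#       --model_name arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id \
#       --filename assets/halo.wav
#
# Or, programmatically:
#
#   asr = Wav2Vec2Inference("arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id")
#   text, conf = asr.file_to_text("assets/halo.wav")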