arifagustyawan's picture
initial commit
4eafd35
import os
import sys
import torch
import argparse
from halo import Halo
import soundfile as sf
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor
# Put the project root (the parent of this file's directory) on sys.path so
# that project-level modules resolve when this file is run directly.
path_this_file = os.path.split(os.path.abspath(__file__))[0]
pat_project_root = os.path.join(path_this_file, os.pardir)
sys.path.append(pat_project_root)
class Wav2Vec2Inference:
    """Transcribe speech with a Hugging Face wav2vec2 CTC checkpoint.

    The processor and model are loaded once in ``__init__``; ``buffer_to_text``
    and ``file_to_text`` then turn 16 kHz audio into ``(text, confidence)``
    pairs. Progress is reported on the terminal with ``halo`` spinners.
    """

    # Sample rate (Hz) passed to the processor and required of input files.
    SAMPLING_RATE = 16_000

    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
        """
        Initializes the class with the provided parameters.

        Args:
            model_name (str): Model id or local path passed to
                ``from_pretrained`` for both the processor and the model.
            hotwords (list, optional): Words forwarded as ``hotwords`` to the
                processor's LM decoding path. Defaults to ``None``, meaning no
                hotwords (stored as an empty list).
            use_lm_if_possible (bool, optional): Load the processor with
                ``AutoProcessor`` and decode with its ``decoder`` when it has
                one. When False, a plain ``Wav2Vec2Processor`` is loaded and
                greedy (argmax) decoding is used. Defaults to True.
            use_gpu (bool, optional): Run the model on CUDA when it is
                available. Defaults to True.
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        if use_lm_if_possible:
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        # `None` default instead of a shared mutable `[]`, so separate
        # instances never share (and accidentally mutate) one default list.
        self.hotwords = [] if hotwords is None else hotwords
        self.use_lm_if_possible = use_lm_if_possible

    def buffer_to_text(self, audio_buffer):
        """
        Transcribes the given audio buffer into text.

        Args:
            audio_buffer: Mono audio samples (list, NumPy array or tensor)
                sampled at ``SAMPLING_RATE`` Hz.

        Returns:
            tuple: ``(transcription, confidence)`` as ``(str, float)``. On the
            LM path the confidence is the decoder's ``lm_score`` divided by the
            number of space-separated words; on the greedy path it is
            ``confidence_score`` as a float. ``("", 0.0)`` is returned for an
            empty buffer or when any step fails (the error is shown on the
            spinner).
        """
        spinner = Halo(text="Transcribing audio...", spinner="dots")
        spinner.start()
        try:
            if len(audio_buffer) == 0:
                # Keep the (text, confidence) shape that callers unpack, and
                # stop the spinner rather than leaving it running.
                spinner.warn("Empty audio buffer; nothing to transcribe.")
                return "", 0.0
            inputs = self.processor(
                torch.tensor(audio_buffer),
                sampling_rate=self.SAMPLING_RATE,
                return_tensors="pt",
                padding=True,
            )
            # Feature extractors configured without attention masks omit this
            # key; only move/forward it when it is present.
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.device)
            with torch.no_grad():
                logits = self.model(
                    inputs.input_values.to(self.device),
                    attention_mask=attention_mask,
                ).logits
            if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
                decoded = self.processor.decode(
                    logits[0].cpu().numpy(),
                    hotwords=self.hotwords,
                    output_word_offsets=True,
                )
                transcription = decoded.text
                # lm_score is a plain float, not a tensor: calling .item() on
                # it raised AttributeError and sent every LM result to the
                # except branch.
                confidence = float(decoded.lm_score) / len(transcription.split(" "))
            else:
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = self.processor.batch_decode(predicted_ids)[0]
                confidence = self.confidence_score(logits, predicted_ids).item()
            spinner.succeed("Audio transcribed successfully!")
            return transcription, confidence
        except Exception as e:
            spinner.fail(f"Error during transcription: {str(e)}")
            return "", 0.0

    def confidence_score(self, logits, predicted_ids):
        """
        Calculate the confidence score for the predicted IDs based on the logits.

        Args:
            logits (torch.Tensor): Model logits whose last dimension is the
                vocabulary; assumed (batch, time, vocab) — TODO confirm for
                callers other than buffer_to_text.
            predicted_ids (torch.Tensor): Chosen token id per position, shaped
                like ``logits`` without its last dimension.

        Returns:
            torch.Tensor: 0-d tensor with the mean softmax probability of the
            predicted tokens, ignoring the tokenizer's word-delimiter and pad
            tokens. It is ``0.0`` when every position is a delimiter or pad
            (e.g. silence), instead of the NaN a 0/0 average would give.
        """
        probs = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = probs.gather(-1, predicted_ids.unsqueeze(-1)).squeeze(-1)
        tokenizer = self.processor.tokenizer
        # The pad token is typically the CTC blank in wav2vec2 vocabularies;
        # neither it nor the word delimiter counts as a character.
        mask = (predicted_ids != tokenizer.word_delimiter_token_id) & (
            predicted_ids != tokenizer.pad_token_id
        )
        character_scores = pred_scores.masked_select(mask)
        if character_scores.numel() == 0:
            return pred_scores.new_zeros(())
        return character_scores.mean()

    def file_to_text(self, filename):
        """
        Reads an audio file and converts it to text using buffer_to_text.

        Args:
            filename (str): Path to an audio file readable by ``soundfile``
                and sampled at ``SAMPLING_RATE`` Hz. Multi-channel audio is
                averaged down to mono.

        Returns:
            tuple: ``(transcription, confidence)`` as ``(str, float)``. If the
            file cannot be read or has the wrong sample rate, ``("", 0.0)`` is
            returned and the reason is shown on the spinner.
        """
        spinner = Halo(text="Reading audio file...", spinner="dots")
        spinner.start()
        try:
            audio_input, samplerate = sf.read(filename)
            if samplerate != self.SAMPLING_RATE:
                # Explicit check instead of `assert`: asserts are stripped
                # under `python -O` and carry an empty message.
                raise ValueError(
                    f"expected a {self.SAMPLING_RATE} Hz file, got {samplerate} Hz"
                )
            if audio_input.ndim > 1:
                # soundfile returns (frames, channels) for multi-channel
                # files; the model expects a single mono channel.
                audio_input = audio_input.mean(axis=1)
        except Exception as e:
            spinner.fail(f"Error reading audio file: {str(e)}")
            return "", 0.0
        # Finish this spinner before buffer_to_text starts its own, so two
        # spinners never animate on the same terminal line.
        spinner.succeed("File read successfully!")
        return self.buffer_to_text(audio_input)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Transcribe a 16 kHz audio file with a wav2vec2 CTC model."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id",
    )
    parser.add_argument("--filename", type=str, default="assets/halo.wav")
    args = parser.parse_args()

    # Load processor + model once; exit with status 1 if that fails.
    with Halo(text="Initializing Wav2Vec2 Inference...", spinner="dots") as init_spinner:
        try:
            asr = Wav2Vec2Inference(args.model_name)
            init_spinner.succeed("Wav2Vec2 Inference initialized successfully!")
        except Exception as e:
            init_spinner.fail(f"Error initializing Wav2Vec2 Inference: {str(e)}")
            sys.exit(1)

    # file_to_text / buffer_to_text drive their own spinners; wrapping them in
    # another Halo would animate two spinners on the same terminal line.
    transcription, confidence = asr.file_to_text(args.filename)
    # \033[94m ... \033[0m colours the labels bright blue on ANSI terminals.
    print("\033[94mTranscription:\033[0m", transcription)
    print("\033[94mConfidence:\033[0m", confidence)