File size: 6,196 Bytes
4eafd35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import sys
import torch
import argparse
from halo import Halo
import soundfile as sf
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor

# Resolve the directory holding this file, then append its parent directory
# to the import search path.
path_this_file = os.path.dirname(os.path.abspath(__file__))
pat_project_root = os.path.join(path_this_file, os.pardir)
sys.path.append(pat_project_root)

class Wav2Vec2Inference:
    """Transcribe audio with a pretrained CTC model loaded through ``transformers``.

    The processor and model are loaded once in ``__init__``. ``buffer_to_text``
    and ``file_to_text`` then turn an in-memory buffer or an audio file into a
    ``(transcription, confidence)`` pair.
    """

    # Sampling rate (Hz) passed to the processor and required of input files.
    SAMPLING_RATE = 16_000

    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
        """
        Load the processor and the CTC model, and move the model to the device.

        Args:
            model_name (str): Name or path handed to ``from_pretrained``.
            hotwords (list, optional): Hotwords forwarded to the processor's
                ``decode`` call on the language-model path. Defaults to None,
                which is treated as an empty list.
            use_lm_if_possible (bool, optional): Load the processor through
                ``AutoProcessor`` and use its ``decoder`` when it has one.
                When False, ``Wav2Vec2Processor`` is loaded and greedy
                (argmax) decoding is used. Defaults to True.
            use_gpu (bool, optional): Use CUDA when it is available.
                Defaults to True.

        Returns:
            None
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        if use_lm_if_possible:
            # NOTE(review): presumably AutoProcessor resolves to an LM-backed
            # processor (one exposing ``decoder``) when the checkpoint ships one.
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        # A None default avoids sharing one list object between all instances.
        self.hotwords = [] if hotwords is None else hotwords
        self.use_lm_if_possible = use_lm_if_possible

    def buffer_to_text(self, audio_buffer):
        """
        Transcribes the given audio buffer into text.

        Args:
            audio_buffer (list): Audio samples; converted with ``torch.tensor``
                and passed to the processor with ``SAMPLING_RATE``.

        Returns:
            tuple: The transcribed text (str) and the confidence score (float).
                ``("", 0.0)`` is returned for an empty buffer and when any
                error occurs during transcription.
        """
        spinner = Halo(text="Transcribing audio...", spinner="dots")
        spinner.start()

        try:
            if len(audio_buffer) == 0:
                # Nothing to decode: stop the spinner and keep the documented
                # (text, confidence) shape so callers can always unpack.
                spinner.stop()
                return "", 0.0

            inputs = self.processor(
                torch.tensor(audio_buffer),
                sampling_rate=self.SAMPLING_RATE,
                return_tensors="pt",
                padding=True,
            )

            # Some feature extractors are configured not to return an
            # attention mask; pass None to the model in that case.
            attention_mask = getattr(inputs, "attention_mask", None)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.device)

            with torch.no_grad():
                logits = self.model(
                    inputs.input_values.to(self.device),
                    attention_mask=attention_mask,
                ).logits

            if hasattr(self.processor, 'decoder') and self.use_lm_if_possible:
                decoded = self.processor.decode(
                    logits[0].cpu().numpy(),
                    hotwords=self.hotwords,
                    output_word_offsets=True,
                )
                # LM score divided by the number of space-separated tokens.
                confidence = decoded.lm_score / len(decoded.text.split(" "))
                transcription = decoded.text
            else:
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = self.processor.batch_decode(predicted_ids)[0]
                confidence = self.confidence_score(logits, predicted_ids)

            spinner.succeed("Audio transcribed successfully!")
            # float() accepts both the 0-dim tensor from confidence_score and
            # the plain number produced on the language-model path.
            return transcription, float(confidence)
        except Exception as e:
            spinner.fail(f"Error during transcription: {str(e)}")
            return "", 0.0

    def confidence_score(self, logits, predicted_ids):
        """
        Calculate the confidence score for the predicted IDs based on the logits.

        The score is the mean softmax probability of the predicted tokens,
        ignoring positions predicted as the tokenizer's word delimiter or pad
        token.

        Parameters:
            logits (torch.Tensor): The logits tensor; softmax is taken over the
                last dimension.
            predicted_ids (torch.Tensor): The predicted IDs tensor, indexing
                the last dimension of ``logits``.

        Returns:
            torch.Tensor: 0-dim tensor holding the average confidence, or zero
                when every position is a delimiter or pad token.
        """
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        tokenizer = self.processor.tokenizer
        mask = torch.logical_and(
            predicted_ids.not_equal(tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(tokenizer.pad_token_id),
        )

        character_scores = pred_scores.masked_select(mask)
        if character_scores.numel() == 0:
            # No scored tokens: return 0 instead of the NaN a 0/0 would give.
            return pred_scores.new_zeros(())
        return torch.sum(character_scores) / character_scores.numel()

    def file_to_text(self, filename):
        """
        Reads an audio file and converts it to text using the buffer_to_text method.

        Parameters:
            filename (str): The path to the audio file.

        Returns:
            tuple: The transcription (str) and its confidence (float). If the
                file cannot be read or its sample rate is not ``SAMPLING_RATE``,
                an empty string and a confidence of 0.0 are returned.
        """
        spinner = Halo(text="Reading audio file...", spinner="dots")
        spinner.start()

        try:
            audio_input, samplerate = sf.read(filename)
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O`` and carry no message for the user.
            if samplerate != self.SAMPLING_RATE:
                raise ValueError(
                    f"expected a sample rate of {self.SAMPLING_RATE} Hz, got {samplerate} Hz"
                )
        except Exception as e:
            spinner.fail(f"Error reading audio file: {str(e)}")
            return "", 0.0

        # Finish this spinner before transcribing: buffer_to_text runs its own
        # spinner and two live spinners would write over each other.
        spinner.succeed("File read successfully!")
        return self.buffer_to_text(audio_input)
    
if __name__ == "__main__":
    # Command-line options: the checkpoint to load and the file to transcribe.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_name",
        type=str,
        default="arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id",
    )
    cli.add_argument(
        "--filename",
        type=str,
        default="assets/halo.wav",
    )
    options = cli.parse_args()

    # Build the inference object behind a spinner; exit with status 1 if
    # construction raises.
    with Halo(text="Initializing Wav2Vec2 Inference...", spinner="dots") as init_spinner:
        try:
            asr = Wav2Vec2Inference(options.model_name)
            init_spinner.succeed("Wav2Vec2 Inference initialized successfully!")
        except Exception as e:
            init_spinner.fail(f"Error initializing Wav2Vec2 Inference: {str(e)}")
            sys.exit(1)

    # Transcribe the requested file while a spinner is shown.
    with Halo(text="Performing audio transcription...", spinner="dots"):
        transcription, confidence = asr.file_to_text(options.filename)

    # Print both results with a blue ANSI-coloured label.
    for label, value in (
        ("\033[94mTranscription:\033[0m", transcription),
        ("\033[94mConfidence:\033[0m", confidence),
    ):
        print(label, value)