arifagustyawan committed on
Commit
4eafd35
1 Parent(s): 0e72979

initial commit

Files changed (4)
  1. app.py +59 -0
  2. assets/halo.wav +0 -0
  3. requirements.txt +7 -0
  4. src/inference.py +143 -0
app.py ADDED
@@ -0,0 +1,59 @@
+ import gradio as gr
+ from src.inference import Wav2Vec2Inference
+ import librosa
+ import os
+ import soundfile
+
+ model_name = "arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id"
+ asr = Wav2Vec2Inference(model_name)
+
+ def convert(inputfile, outfile):
+     """Resample the input audio to the 16 kHz rate the model expects."""
+     target_sr = 16000
+     data, sample_rate = librosa.load(inputfile, sr=None)  # keep the native rate
+     data = librosa.resample(data, orig_sr=sample_rate, target_sr=target_sr)
+     soundfile.write(outfile, data, target_sr)
+
+ def parse_transcription_record(wav_file):
+     """Transcribe microphone audio (received as a file path)."""
+     filename = os.path.splitext(wav_file)[0]
+     convert(wav_file, filename + "16k.wav")
+     transcription, confidence = asr.file_to_text(filename + "16k.wav")
+     return transcription, confidence
+
+ def parse_transcription_file(wav_file):
+     """Transcribe an uploaded file (gr.File with type="filepath" passes a path string)."""
+     filename = os.path.splitext(wav_file)[0]
+     convert(wav_file, filename + "16k.wav")
+     transcription, confidence = asr.file_to_text(filename + "16k.wav")
+     return transcription, confidence
+
+ examples = [
+     [os.path.join("assets", "halo.wav")]
+ ]
+
+ record_audio = gr.Interface(
+     fn=parse_transcription_record,
+     inputs=gr.Audio(sources=["microphone"], type="filepath", label="Click button to record audio"),
+     outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Confidence")],
+     analytics_enabled=False,
+     allow_flagging="never",
+     title="Automatic Speech Recognition",
+     description="Click the button below to record audio!",
+ )
+
+ upload_file = gr.Interface(
+     fn=parse_transcription_file,
+     inputs=gr.File(type="filepath", label="Upload file here"),
+     outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Confidence")],
+     examples=examples,
+     analytics_enabled=False,
+     title="Automatic Speech Recognition",
+     allow_flagging="never",
+     description="Upload or drag and drop the audio file here!",
+ )
+
+ demo = gr.TabbedInterface([record_audio, upload_file], ["Record Audio", "Upload Audio"])
+
+ if __name__ == "__main__":
+     demo.launch()
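For a quick sanity check outside the Gradio UI, the same pipeline can be exercised directly from a Python shell. A minimal sketch, assuming the model weights download successfully and that the bundled assets/halo.wav is already 16 kHz (the CLI in src/inference.py passes it to file_to_text directly and asserts that rate):

    from src.inference import Wav2Vec2Inference

    asr = Wav2Vec2Inference("arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id")
    # file_to_text asserts 16 kHz input; use convert() from app.py for other rates
    text, confidence = asr.file_to_text("assets/halo.wav")
    print(text, confidence)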
assets/halo.wav ADDED
Binary file (77.9 kB).
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ datasets
+ transformers
+ torch  # imported by src/inference.py; transformers does not install it automatically
+ huggingface-hub
+ soundfile
+ halo
+ gradio
+ librosa
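The dependencies install with `pip install -r requirements.txt`. Note that `halo` is the terminal-spinner library imported in src/inference.py, unrelated to the assets/halo.wav sample.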
src/inference.py ADDED
@@ -0,0 +1,143 @@
+ import os
+ import sys
+ import torch
+ import argparse
+ from halo import Halo
+ import soundfile as sf
+ from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor
+
+ path_this_file = os.path.dirname(os.path.abspath(__file__))
+ path_project_root = os.path.join(path_this_file, "..")
+ sys.path.append(path_project_root)
+
+ class Wav2Vec2Inference:
+     def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
+         """
+         Initializes the inference wrapper.
+
+         Args:
+             model_name (str): The name of the model to be used.
+             hotwords (list, optional): A list of hotwords. Defaults to an empty list.
+             use_lm_if_possible (bool, optional): Whether to use a language model
+                 when the checkpoint ships one. Defaults to True.
+             use_gpu (bool, optional): Whether to use the GPU. Defaults to True.
+         """
+         self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
+         if use_lm_if_possible:
+             # AutoProcessor returns a Wav2Vec2ProcessorWithLM if LM files are present
+             self.processor = AutoProcessor.from_pretrained(model_name)
+         else:
+             self.processor = Wav2Vec2Processor.from_pretrained(model_name)
+         self.model = AutoModelForCTC.from_pretrained(model_name)
+         self.model.to(self.device)
+         self.hotwords = hotwords if hotwords is not None else []
+         self.use_lm_if_possible = use_lm_if_possible
+
+     def buffer_to_text(self, audio_buffer):
+         """
+         Transcribes the given audio buffer into text.
+
+         Args:
+             audio_buffer (array-like): Raw 16 kHz audio samples.
+
+         Returns:
+             tuple: The transcribed text (str) and the confidence score (float).
+         """
+         spinner = Halo(text="Transcribing audio...", spinner="dots")
+         spinner.start()
+
+         try:
+             if len(audio_buffer) == 0:
+                 spinner.fail("Empty audio buffer.")
+                 return "", 0.0
+
+             inputs = self.processor(torch.tensor(audio_buffer), sampling_rate=16_000,
+                                     return_tensors="pt", padding=True)
+
+             with torch.no_grad():
+                 logits = self.model(inputs.input_values.to(self.device),
+                                     attention_mask=inputs.attention_mask.to(self.device)).logits
+
+             if hasattr(self.processor, "decoder") and self.use_lm_if_possible:
+                 transcription = self.processor.decode(logits[0].cpu().numpy(),
+                                                       hotwords=self.hotwords,
+                                                       output_word_offsets=True)
+                 # Normalize the LM score by the word count for a per-word confidence
+                 confidence = transcription.lm_score / len(transcription.text.split(" "))
+                 transcription = transcription.text
+             else:
+                 predicted_ids = torch.argmax(logits, dim=-1)
+                 transcription = self.processor.batch_decode(predicted_ids)[0]
+                 confidence = self.confidence_score(logits, predicted_ids)
+
+             spinner.succeed("Audio transcribed successfully!")
+             # float() handles both the LM branch (plain float) and the tensor branch
+             return transcription, float(confidence)
+         except Exception as e:
+             spinner.fail(f"Error during transcription: {str(e)}")
+             return "", 0.0
+
+     def confidence_score(self, logits, predicted_ids):
+         """
+         Calculate the confidence score for the predicted IDs based on the logits.
+
+         Args:
+             logits (torch.Tensor): The logits tensor.
+             predicted_ids (torch.Tensor): The predicted IDs tensor.
+
+         Returns:
+             torch.Tensor: The average confidence over the predicted characters,
+             excluding padding and word-delimiter tokens.
+         """
+         scores = torch.nn.functional.softmax(logits, dim=-1)
+         pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
+         mask = torch.logical_and(
+             predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
+             predicted_ids.not_equal(self.processor.tokenizer.pad_token_id))
+
+         character_scores = pred_scores.masked_select(mask)
+         total_average = torch.sum(character_scores) / len(character_scores)
+         return total_average
+
+     def file_to_text(self, filename):
+         """
+         Reads an audio file and converts it to text using the buffer_to_text method.
+
+         Args:
+             filename (str): The path to the audio file.
+
+         Returns:
+             tuple: The transcription (str) and its confidence (float). If the
+             file cannot be read, an empty string and 0.0 are returned.
+         """
+         spinner = Halo(text="Reading audio file...", spinner="dots")
+         spinner.start()
+
+         try:
+             audio_input, samplerate = sf.read(filename)
+             assert samplerate == 16000, "Expected 16 kHz audio; resample the file first."
+             spinner.succeed("File read successfully!")
+             return self.buffer_to_text(audio_input)
+         except Exception as e:
+             spinner.fail(f"Error reading audio file: {str(e)}")
+             return "", 0.0
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_name", type=str,
+                         default="arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id")
+     parser.add_argument("--filename", type=str, default="assets/halo.wav")
+     args = parser.parse_args()
+
+     with Halo(text="Initializing Wav2Vec2 Inference...", spinner="dots") as init_spinner:
+         try:
+             asr = Wav2Vec2Inference(args.model_name)
+             init_spinner.succeed("Wav2Vec2 Inference initialized successfully!")
+         except Exception as e:
+             init_spinner.fail(f"Error initializing Wav2Vec2 Inference: {str(e)}")
+             sys.exit(1)
+
+     with Halo(text="Performing audio transcription...", spinner="dots"):
+         transcription, confidence = asr.file_to_text(args.filename)
+
+     print("\033[94mTranscription:\033[0m", transcription)
+     print("\033[94mConfidence:\033[0m", confidence)
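confidence_score() reports the mean softmax probability of the greedily chosen token at each frame, with padding and word-delimiter frames masked out. A self-contained toy version of that computation, using hypothetical token ids in place of the ones the processor's tokenizer would supply:

    import torch

    # (batch, time, vocab) logits; ids 0 and 4 stand in for <pad> and the
    # word delimiter in this sketch -- the real ids come from the tokenizer.
    logits = torch.randn(1, 6, 32)
    predicted_ids = torch.argmax(logits, dim=-1)            # greedy CTC path
    scores = torch.softmax(logits, dim=-1)
    pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1)).squeeze(-1)
    mask = (predicted_ids != 0) & (predicted_ids != 4)      # keep character frames only
    print(pred_scores.masked_select(mask).mean())

The module also runs standalone via its argparse block: `python src/inference.py --filename assets/halo.wav` transcribes the bundled sample with the default model.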