arifagustyawan committed
Commit 4eafd35
Parent(s): 0e72979
initial commit
- app.py +59 -0
- assets/halo.wav +0 -0
- requirements.txt +7 -0
- src/inference.py +143 -0
app.py
ADDED
@@ -0,0 +1,59 @@
import os

import gradio as gr
import librosa
import soundfile

from src.inference import Wav2Vec2Inference

model_name = "arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id"
asr = Wav2Vec2Inference(model_name)


def convert(inputfile, outfile):
    """Resample an audio file to 16 kHz and write it to outfile."""
    target_sr = 16000
    data, sample_rate = librosa.load(inputfile)
    data = librosa.resample(data, orig_sr=sample_rate, target_sr=target_sr)
    soundfile.write(outfile, data, target_sr)


def parse_transcription_record(wav_file):
    """Transcribe audio recorded from the microphone."""
    filename = os.path.splitext(wav_file)[0]
    convert(wav_file, filename + "16k.wav")
    transcription, confidence = asr.file_to_text(filename + "16k.wav")
    return transcription, confidence


def parse_transcription_file(wav_file):
    """Transcribe an uploaded audio file."""
    # Depending on the Gradio version, gr.File may pass a tempfile-like
    # object (with a .name attribute) or a plain path string.
    path = wav_file.name if hasattr(wav_file, "name") else wav_file
    filename = os.path.splitext(path)[0]
    convert(path, filename + "16k.wav")
    transcription, confidence = asr.file_to_text(filename + "16k.wav")
    return transcription, confidence


examples = [
    [os.path.join("assets", "halo.wav")]
]

record_audio = gr.Interface(
    fn=parse_transcription_record,
    inputs=gr.Audio(sources="microphone", type="filepath", label="Click button to record audio"),
    outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Confidence")],
    analytics_enabled=False,
    allow_flagging="never",
    title="Automatic Speech Recognition",
    description="Click the button below to record audio!",
)

upload_file = gr.Interface(
    fn=parse_transcription_file,
    inputs=gr.File(type="filepath", label="Upload file here"),
    outputs=[gr.Textbox(label="Transcription"), gr.Textbox(label="Confidence")],
    examples=examples,
    analytics_enabled=False,
    allow_flagging="never",
    title="Automatic Speech Recognition",
    description="Upload or drag and drop the audio file here!",
)

demo = gr.TabbedInterface([record_audio, upload_file], ["Record Audio", "Upload Audio"])

if __name__ == "__main__":
    demo.launch()
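Once the dependencies are installed, the two callbacks can be sanity-checked without launching the UI. A minimal sketch (hypothetical, not part of the commit), assuming it is run from the repository root and that assets/halo.wav exists:

# Hypothetical smoke test: exercises the same convert -> file_to_text path
# that the Gradio callbacks use, without starting the web server.
from app import asr, convert

convert("assets/halo.wav", "assets/halo16k.wav")  # resample to 16 kHz
text, confidence = asr.file_to_text("assets/halo16k.wav")
print(text, confidence)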
assets/halo.wav
ADDED
Binary file (77.9 kB).
requirements.txt
ADDED
@@ -0,0 +1,7 @@
datasets
transformers
huggingface-hub
soundfile
halo
gradio
librosa
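These can be installed with pip install -r requirements.txt. Note that torch, which src/inference.py imports directly, is not listed here; transformers does not install it by default, so it may need to be added to this file or installed separately.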
src/inference.py
ADDED
@@ -0,0 +1,143 @@
import os
import sys
import argparse

import torch
import soundfile as sf
from halo import Halo
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor

# Make the project root importable when this file is run directly.
path_this_file = os.path.dirname(os.path.abspath(__file__))
path_project_root = os.path.join(path_this_file, "..")
sys.path.append(path_project_root)


class Wav2Vec2Inference:
    def __init__(self, model_name, hotwords=[], use_lm_if_possible=True, use_gpu=True):
        """
        Initializes the class with the provided parameters.

        Args:
            model_name (str): The name of the model to be used.
            hotwords (list, optional): A list of hotwords. Defaults to an empty list.
            use_lm_if_possible (bool, optional): Whether to use a language model
                if one is available. Defaults to True.
            use_gpu (bool, optional): Whether to use the GPU. Defaults to True.
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        if use_lm_if_possible:
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        self.hotwords = hotwords
        self.use_lm_if_possible = use_lm_if_possible

    def buffer_to_text(self, audio_buffer):
        """
        Transcribes the given audio buffer into text.

        Args:
            audio_buffer (array-like): Raw 16 kHz audio samples.

        Returns:
            tuple: The transcribed text (str) and the confidence score (float).
        """
        spinner = Halo(text="Transcribing audio...", spinner="dots")
        spinner.start()

        try:
            if len(audio_buffer) == 0:
                spinner.fail("Empty audio buffer.")
                return "", 0.0

            inputs = self.processor(torch.tensor(audio_buffer), sampling_rate=16_000,
                                    return_tensors="pt", padding=True)

            with torch.no_grad():
                logits = self.model(inputs.input_values.to(self.device),
                                    attention_mask=inputs.attention_mask.to(self.device)).logits

            if hasattr(self.processor, 'decoder') and self.use_lm_if_possible:
                # Beam-search decoding with the language model; the LM score,
                # normalized by word count, serves as the confidence.
                transcription = self.processor.decode(logits[0].cpu().numpy(),
                                                      hotwords=self.hotwords,
                                                      output_word_offsets=True)
                confidence = transcription.lm_score / len(transcription.text.split(" "))
                transcription = transcription.text
            else:
                # Greedy decoding; the average per-character softmax probability
                # serves as the confidence.
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = self.processor.batch_decode(predicted_ids)[0]
                confidence = self.confidence_score(logits, predicted_ids).item()

            spinner.succeed("Audio transcribed successfully!")
            return transcription, float(confidence)
        except Exception as e:
            spinner.fail(f"Error during transcription: {str(e)}")
            return "", 0.0

    def confidence_score(self, logits, predicted_ids):
        """
        Calculate the confidence score for the predicted IDs based on the logits.

        Parameters:
            logits (torch.Tensor): The logits tensor.
            predicted_ids (torch.Tensor): The predicted IDs tensor.

        Returns:
            torch.Tensor: The average confidence score over the predicted characters.
        """
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        # Ignore padding and word-delimiter tokens when averaging.
        mask = torch.logical_and(
            predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(self.processor.tokenizer.pad_token_id))

        character_scores = pred_scores.masked_select(mask)
        total_average = torch.sum(character_scores) / len(character_scores)
        return total_average

    def file_to_text(self, filename):
        """
        Reads an audio file and converts it to text using the buffer_to_text method.

        Parameters:
            filename (str): The path to the audio file (expected to be 16 kHz).

        Returns:
            tuple: The transcription (str) and the confidence (float). If the
            audio file cannot be read, an empty string and a confidence of 0.0
            are returned.
        """
        spinner = Halo(text="Reading audio file...", spinner="dots")
        spinner.start()

        try:
            audio_input, samplerate = sf.read(filename)
            assert samplerate == 16000, "expected 16 kHz audio"
            transcription, confidence = self.buffer_to_text(audio_input)
            spinner.succeed("File read successfully!")
            return transcription, confidence
        except Exception as e:
            spinner.fail(f"Error reading audio file: {str(e)}")
            return "", 0.0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="arifagustyawan/wav2vec2-large-xlsr-common_voice_13_0-id")
    parser.add_argument("--filename", type=str, default="assets/halo.wav")
    args = parser.parse_args()

    with Halo(text="Initializing Wav2Vec2 Inference...", spinner="dots") as init_spinner:
        try:
            asr = Wav2Vec2Inference(args.model_name)
            init_spinner.succeed("Wav2Vec2 Inference initialized successfully!")
        except Exception as e:
            init_spinner.fail(f"Error initializing Wav2Vec2 Inference: {str(e)}")
            sys.exit(1)

    with Halo(text="Performing audio transcription...", spinner="dots") as transcribe_spinner:
        transcription, confidence = asr.file_to_text(args.filename)

    print("\033[94mTranscription:\033[0m", transcription)
    print("\033[94mConfidence:\033[0m", confidence)
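Given the argparse defaults above, src/inference.py also works as a standalone CLI: running python src/inference.py --filename assets/halo.wav from the repository root transcribes the bundled sample with the default model, printing the transcription and confidence.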