Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -24,37 +24,42 @@ def format_and_group_timestamps(chunks, interval=5.0):
|
|
24 |
if interval_start not in grouped:
|
25 |
grouped[interval_start] = []
|
26 |
grouped[interval_start].append((start, end, word))
|
27 |
-
|
28 |
formatted_output = f"Transcript: {transcript.strip()}'\n\n-------\n\nword-stamped transcripts (every 5 seconds):\n\n"
|
29 |
for interval_start, words in grouped.items():
|
30 |
formatted_output += f"({interval_start}, {interval_start + interval}) -- {' '.join([w[2] for w in words])}\n"
|
31 |
return formatted_output
|
32 |
|
33 |
-
# Modified transcribe_audio function
|
34 |
def transcribe_audio(input_file, language, chunk_length_s=10, stride_length_s=(4, 2), return_timestamps="word"):
|
35 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
36 |
target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}
|
37 |
target_lang_code = target_lang_options[language]
|
38 |
-
|
39 |
-
# Determine the model_id based on the language
|
40 |
if target_lang_code == "eng":
|
41 |
-
|
|
|
|
|
|
|
|
|
42 |
else:
|
|
|
|
|
43 |
model_id = "Sunbird/sunbird-mms"
|
44 |
-
|
45 |
-
pipe = pipeline(model=model_id, device=device, token=auth_token)
|
46 |
-
pipe.tokenizer.set_target_lang(target_lang_code)
|
47 |
-
pipe.model.load_adapter(target_lang_code)
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
# Interface setup remains the same
|
54 |
description = '''ASR with salt-mms'''
|
55 |
iface = gr.Interface(fn=transcribe_audio,
|
56 |
inputs=[
|
57 |
-
gr.Audio(
|
58 |
gr.Dropdown(choices=list(target_lang_options.keys()), label="Language", value="English")
|
59 |
],
|
60 |
outputs=gr.Textbox(label="Transcription"),
|
|
|
24 |
if interval_start not in grouped:
|
25 |
grouped[interval_start] = []
|
26 |
grouped[interval_start].append((start, end, word))
|
27 |
+
|
28 |
formatted_output = f"Transcript: {transcript.strip()}'\n\n-------\n\nword-stamped transcripts (every 5 seconds):\n\n"
|
29 |
for interval_start, words in grouped.items():
|
30 |
formatted_output += f"({interval_start}, {interval_start + interval}) -- {' '.join([w[2] for w in words])}\n"
|
31 |
return formatted_output
|
32 |
|
33 |
+
# Modified transcribe_audio function to use Whisper for English
|
34 |
def transcribe_audio(input_file, language, chunk_length_s=10, stride_length_s=(4, 2), return_timestamps="word"):
    """Transcribe an audio file, using Whisper for English and Sunbird MMS otherwise.

    Args:
        input_file: Path of the audio file to transcribe.
        language: Display name of the spoken language; must be one of the keys
            of ``target_lang_options`` defined in the body.
        chunk_length_s: Forwarded to the MMS pipeline call (unused for English).
        stride_length_s: Forwarded to the MMS pipeline call (unused for English).
        return_timestamps: Forwarded to the MMS pipeline call (unused for English).

    Returns:
        For English, the ``"text"`` field of the Whisper result. For every other
        language, the string produced by ``format_and_group_timestamps`` from the
        pipeline's ``chunks``.

    Raises:
        KeyError: If ``language`` is not a supported language name.
        ValueError: If ``input_file`` is ``None`` or empty.
    """
    target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}
    target_lang_code = target_lang_options[language]

    # Fail fast: without this guard a missing upload would only be noticed
    # after a model has already been loaded further down.
    if not input_file:
        raise ValueError("No audio file provided for transcription.")

    if target_lang_code == "eng":
        # English goes through Whisper. Only the plain text is returned, so the
        # chunking/timestamp parameters do not apply on this path.
        model = whisper.load_model("small")
        result = model.transcribe(input_file)
        return result["text"]

    # Every other language uses the Sunbird MMS checkpoint, switched to the
    # requested language via the tokenizer target language and model adapter.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = "Sunbird/sunbird-mms"

    auth_token = os.environ.get("HF_TOKEN")
    pipe = pipeline(model=model_id, device=device, token=auth_token)
    pipe.tokenizer.set_target_lang(target_lang_code)
    pipe.model.load_adapter(target_lang_code)

    output = pipe(input_file, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, return_timestamps=return_timestamps)
    formatted_output = format_and_group_timestamps(output['chunks'])
    return formatted_output
|
57 |
|
58 |
# Interface setup remains the same
|
59 |
description = '''ASR with salt-mms'''
|
60 |
iface = gr.Interface(fn=transcribe_audio,
|
61 |
inputs=[
|
62 |
+
gr.Audio(sources="upload", type="filepath", label="upload file to transcribe"),
|
63 |
gr.Dropdown(choices=list(target_lang_options.keys()), label="Language", value="English")
|
64 |
],
|
65 |
outputs=gr.Textbox(label="Transcription"),
|