import gradio as gr
from transformers import pipeline
import torch
import librosa
import json
import os
import whisper
# Hugging Face token for accessing the gated Sunbird model
auth_token = os.environ.get("HF_TOKEN")
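# Values are ISO 639-3 language codes; for the non-English options they double
# as the MMS adapter names loaded in transcribe_audio below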
target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}
languages = list(target_lang_options.keys())
# Helper function to format word timestamps and group them into fixed-length intervals
def format_and_group_timestamps(chunks, interval=5.0):
    grouped = {}
    transcript = ""
    for chunk in chunks:
        start, end = chunk['timestamp']
        word = chunk['text']
        transcript += f"{word} "
        # Bucket each word by the interval its start time falls into
        interval_start = int(start // interval) * interval
        if interval_start not in grouped:
            grouped[interval_start] = []
        grouped[interval_start].append((start, end, word))
    formatted_output = (
        f"Transcript: {transcript.strip()}\n\n-------\n\n"
        f"Word-stamped transcript (every {interval:g} seconds):\n\n"
    )
    for interval_start, words in grouped.items():
        formatted_output += f"({interval_start}, {interval_start + interval}) -- {' '.join([w[2] for w in words])}\n"
    return formatted_output
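# For reference, the chunks consumed above follow the transformers ASR pipeline's
# word-timestamp format (values here are illustrative, not real output):
#   [{'text': 'hello', 'timestamp': (0.12, 0.48)},
#    {'text': 'world', 'timestamp': (0.52, 0.95)}]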
# Transcription entry point: Whisper handles English, Sunbird MMS the other languages
def transcribe_audio(input_file, language, chunk_length_s=10, stride_length_s=(4, 2), return_timestamps="word"):
    target_lang_code = target_lang_options[language]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if target_lang_code == "eng":
        # Use Whisper for English; transcribe() returns segment-level rather than
        # word-level timestamps, so only the plain text is surfaced here
        model = whisper.load_model("small")
        result = model.transcribe(input_file)
        return result["text"]
    else:
        # Use the Sunbird MMS model with a per-language adapter
        model_id = "Sunbird/sunbird-mms"
        pipe = pipeline(model=model_id, device=device, token=auth_token)
        pipe.tokenizer.set_target_lang(target_lang_code)
        pipe.model.load_adapter(target_lang_code)
        # chunk_length_s / stride_length_s split long audio into overlapping
        # windows so the model can transcribe inputs of arbitrary length
        output = pipe(input_file, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, return_timestamps=return_timestamps)
        return format_and_group_timestamps(output['chunks'])
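# Note: the Whisper and MMS models are reloaded on every request. A minimal
# caching sketch (an assumption about acceptable memory use, not part of the
# original app) would memoize the loader, e.g.:
#
#   from functools import lru_cache
#
#   @lru_cache(maxsize=1)
#   def get_mms_pipeline():
#       return pipeline(model="Sunbird/sunbird-mms",
#                       device="cuda" if torch.cuda.is_available() else "cpu",
#                       token=auth_token)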
# Gradio interface setup
description = '''ASR with salt-mms'''
iface = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Audio(sources=["upload"], type="filepath", label="Upload a file to transcribe"),
        gr.Dropdown(choices=languages, label="Language", value="English"),
    ],
    outputs=gr.Textbox(label="Transcription"),
    description=description,
)
# Launch the interface
iface.launch()