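"""Kurdish (Kurmanji) speech-to-text Gradio app.

Runs energy-based voice activity detection on an uploaded audio file,
transcribes each voiced segment with a fine-tuned wav2vec2 CTC model,
and returns the result as an SRT subtitle file.
"""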
import os
import torch
import librosa
import numpy as np
import gradio as gr
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
def format_time(milliseconds):
    # Convert a millisecond offset into an SRT timestamp (HH:MM:SS,mmm).
    seconds, milliseconds = divmod(int(milliseconds), 1000)
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
def detect_speech_activity(y, sr, frame_length=1024, hop_length=512, threshold=0.01):
    # Simple energy-based voice activity detection: frames whose RMS energy
    # exceeds the threshold are treated as speech.
    energy = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
    speech_frames = energy > threshold
    speech_regions = []
    in_speech = False
    for i, speech in enumerate(speech_frames):
        if speech and not in_speech:
            start = i
            in_speech = True
        elif not speech and in_speech:
            end = i
            speech_regions.append((start * hop_length / sr, end * hop_length / sr))
            in_speech = False
    # Close a region that is still open at the end of the signal.
    if in_speech:
        speech_regions.append((start * hop_length / sr, len(y) / sr))
    return speech_regions
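# Frame indices map to seconds via index * hop_length / sr; e.g. with the
# defaults (hop_length=512) and sr=16000, frame 100 starts at 3.2 s.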
def post_process_text(text):
    # Collapse runs of whitespace into single spaces and trim the ends.
    return " ".join(text.split())
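# Example: post_process_text("  ji  bo  ") == "ji bo"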
def transcribe_audio(audio_file):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Fine-tuned wav2vec2 CTC model for Kurmanji Kurdish.
    model_name = "Akashpb13/xlsr_kurmanji_kurdish"
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
    # wav2vec2 expects 16 kHz mono audio.
    y, sr = librosa.load(audio_file, sr=16000)
    voiced_segments = detect_speech_activity(y, sr, threshold=0.005)
    srt_content = ""
    index = 1  # SRT cue numbers must stay sequential even when segments are skipped.
    for start, end in voiced_segments:
        segment_audio = y[int(start * sr):int(end * sr)]
        input_values = processor(segment_audio, sampling_rate=sr, return_tensors="pt").input_values
        input_values = input_values.to(device)
        with torch.no_grad():
            logits = model(input_values).logits
        # Greedy CTC decoding: pick the most likely token at each time step.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        transcription = post_process_text(transcription)
        if transcription:
            start_time = format_time(start * 1000)
            end_time = format_time(end * 1000)
            srt_content += f"{index}\n"
            srt_content += f"{start_time} --> {end_time}\n"
            # Break long lines into shorter ones (max ~50 characters).
            words = transcription.split()
            lines = []
            current_line = ""
            for word in words:
                if current_line and len(current_line) + len(word) > 50:
                    lines.append(current_line.strip())
                    current_line = ""
                current_line += word + " "
            if current_line:
                lines.append(current_line.strip())
            srt_content += "\n".join(lines) + "\n\n"
            index += 1
    return srt_content
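# Note: the processor and model are loaded inside transcribe_audio, so every
# request pays the initialisation cost. For a long-running app it is usually
# cheaper to load them once at module level and reuse them across calls.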
def save_srt(audio_file):
    srt_content = transcribe_audio(audio_file)
    output_filename = "output.srt"
    with open(output_filename, "w", encoding="utf-8") as f:
        f.write(srt_content)
    return output_filename, srt_content
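# Note: the output filename is fixed, so simultaneous requests would overwrite
# each other's file; writing to a per-request temporary file would avoid that.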
iface = gr.Interface(
    fn=save_srt,
    inputs=gr.Audio(type="filepath"),
    outputs=[
        gr.File(label="Download SRT"),
        gr.Textbox(label="SRT Content", lines=10)
    ],
    title="Kurdish Speech-to-Text Transcription",
    description="Upload an audio file to generate an SRT subtitle file with Kurdish transcription."
)
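# Running this file directly (python app.py) starts a local Gradio server,
# by default on http://127.0.0.1:7860.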
if __name__ == "__main__":
    iface.launch()