import gradio as gr
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import numpy as np
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
import plotly.graph_objects as go

load_dotenv()
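# NOTE (added comment): HF_KEY is expected as an environment variable (e.g. in a local
# .env file or a Space secret); it is used below to download the gated pyannote
# diarization model, which requires accepting its terms on the Hugging Face Hub.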
# Check and set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Model and pipeline setup
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)
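# Optional (assumption, not in the original code): for long recordings, the transformers
# ASR pipeline can transcribe in chunks by passing chunk_length_s to pipeline()
# (e.g. chunk_length_s=15); left out here to keep the original behaviour.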
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
)
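# Optional (assumption, not in the original code): pyannote pipelines run on CPU unless
# explicitly moved; on a GPU machine the pipeline can be sent to the GPU with
#     diarization_pipeline.to(torch.device(device))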
# Returns diarization info such as segment start and end times, and speaker id
def diarization_info(res):
    starts = []
    ends = []
    speakers = []

    for segment, _, speaker in res.itertracks(yield_label=True):
        starts.append(segment.start)
        ends.append(segment.end)
        speakers.append(speaker)

    return starts, ends, speakers
# Plot diarization results on a graph
def plot_diarization(starts, ends, speakers):
    fig = go.Figure()

    # Define a color map for different speakers
    num_speakers = len(set(speakers))
    colors = [f"hsl({h},80%,60%)" for h in np.linspace(0, 360, num_speakers)]

    # Plot each segment with its speaker's color
    for start, end, speaker in zip(starts, ends, speakers):
        speaker_id = list(set(speakers)).index(speaker)
        fig.add_trace(
            go.Scatter(
                x=[start, end],
                y=[speaker_id, speaker_id],
                mode="lines",
                line=dict(color=colors[speaker_id], width=15),
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Speaker Diarization",
        xaxis=dict(title="Time (s)"),
        yaxis=dict(title="Speaker"),
        height=600,
        width=800,
    )
    return fig
def transcribe(sr, data):
    # Audio arrives as 16-bit integer PCM; scale it to float32 in [-1, 1] for the model
    processed_data = np.array(data).astype(np.float32) / 32767.0

    # Results from the ASR pipeline
    transcription_res = pipe({"sampling_rate": sr, "raw": processed_data})["text"]

    return transcription_res
def transcribe_diarize(audio):
    sr, data = audio
    processed_data = np.array(data).astype(np.float32) / 32767.0

    # Assumes mono audio: pyannote expects a (channel, time) waveform tensor
    waveform_tensor = torch.tensor(processed_data[np.newaxis, :])

    # Full transcription of the whole recording
    transcription_res = transcribe(sr, data)

    # Results from the diarization pipeline
    diarization_res = diarization_pipeline(
        {"waveform": waveform_tensor, "sample_rate": sr}
    )

    # Get diarization information
    starts, ends, speakers = diarization_info(diarization_res)

    # Accumulate the per-segment transcriptions
    diarized_transcription = ""

    # Get transcription results for each speaker segment
    for start_time, end_time, speaker_id in zip(starts, ends, speakers):
        segment = data[int(start_time * sr) : int(end_time * sr)]
        diarized_transcription += f"{speaker_id} {round(start_time, 2)}:{round(end_time, 2)} \t {transcribe(sr, segment)}\n"

    # Plot diarization
    diarization_plot = plot_diarization(starts, ends, speakers)

    return transcription_res, diarized_transcription, diarization_plot
# Create the Gradio interface
demo = gr.Interface(
    fn=transcribe_diarize,
    inputs=gr.Audio(sources=["upload", "microphone"]),
    outputs=[
        gr.Textbox(lines=3, label="Text Transcription"),
        gr.Textbox(label="Diarized Transcription"),
        gr.Plot(label="Visualization"),
    ],
    examples=["sample1.wav"],
    title="Automatic Speech Recognition with Diarization 🗣️",
    description="Transcribe your speech to text with Distil-Whisper and separate speakers with pyannote diarization. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
)
if __name__ == "__main__":
    demo.launch()
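# Dependencies inferred from the imports above (a requirements.txt sketch, versions unpinned):
#   gradio
#   torch
#   transformers
#   numpy
#   pyannote.audio
#   python-dotenv
#   plotly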