from whisperx.alignment import (
    DEFAULT_ALIGN_MODELS_TORCH as DAMT,
    DEFAULT_ALIGN_MODELS_HF as DAMHF,
)
from whisperx.utils import TO_LANGUAGE_CODE
import whisperx
import torch
import gc
import os
import soundfile as sf
from IPython.utils import capture
from .language_configuration import EXTRA_ALIGN, INVERTED_LANGUAGES
from .logging_setup import logger
from .postprocessor import sanitize_file_name
from .utils import remove_directory_contents, run_command

import spaces
import copy
import random
import time


def random_sleep():
    # On ZeroGPU Spaces, wait a short, randomized time before launching
    # another GPU task.
    if os.environ.get("ZERO_GPU") == "TRUE":
        print("Random sleep")
        sleep_time = round(random.uniform(7.2, 9.9), 1)
        time.sleep(sleep_time)


@spaces.GPU
def load_and_transcribe_audio(
    asr_model,
    audio,
    compute_type,
    language,
    asr_options,
    batch_size,
    segment_duration_limit,
):
    model = whisperx.load_model(
        asr_model,
        os.environ.get("SONITR_DEVICE")
        if os.environ.get("ZERO_GPU") != "TRUE"
        else "cuda",
        compute_type=compute_type,
        language=language,
        asr_options=asr_options,
    )

    # whisperx's chunk_size is the audio chunk length in seconds, so
    # segment_duration_limit effectively limits the duration of each segment.
    result = model.transcribe(
        audio,
        batch_size=batch_size,
        chunk_size=segment_duration_limit,
        print_progress=True,
    )

    # Free the model and GPU memory before returning.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    return result


def load_align_and_align_segments(result, audio, DAMHF):
    model_a, metadata = whisperx.load_align_model(
        language_code=result["language"],
        device=os.environ.get("SONITR_DEVICE")
        if os.environ.get("ZERO_GPU") != "TRUE"
        else "cuda",
        model_name=None
        if result["language"] in DAMHF.keys()
        else EXTRA_ALIGN[result["language"]],
    )

    alignment_result = whisperx.align(
        result["segments"],
        model_a,
        metadata,
        audio,
        os.environ.get("SONITR_DEVICE")
        if os.environ.get("ZERO_GPU") != "TRUE"
        else "cuda",
        return_char_alignments=True,
        print_progress=False,
    )

    del model_a
    gc.collect()
    torch.cuda.empty_cache()

    return alignment_result


@spaces.GPU
def diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers):
    # On ZeroGPU Spaces, make sure the pipeline runs on CUDA inside this
    # GPU-decorated call.
    if os.environ.get("ZERO_GPU") == "TRUE":
        diarize_model.model.to(torch.device("cuda"))
    diarize_segments = diarize_model(
        audio_wav,
        min_speakers=min_speakers,
        max_speakers=max_speakers,
    )
    return diarize_segments


ASR_MODEL_OPTIONS = [
    "tiny",
    "base",
    "small",
    "medium",
    "large",
    "large-v1",
    "large-v2",
    "large-v3",
    "distil-large-v2",
    "Systran/faster-distil-whisper-large-v3",
    "tiny.en",
    "base.en",
    "small.en",
    "medium.en",
    "distil-small.en",
    "distil-medium.en",
    "OpenAI_API_Whisper",
]

COMPUTE_TYPE_GPU = [
    "default",
    "auto",
    "int8",
    "int8_float32",
    "int8_float16",
    "int8_bfloat16",
    "float16",
    "bfloat16",
    "float32",
]

COMPUTE_TYPE_CPU = [
    "default",
    "auto",
    "int8",
    "int8_float32",
    "int16",
    "float32",
]

WHISPER_MODELS_PATH = './WHISPER_MODELS'


def openai_api_whisper(
    input_audio_file,
    source_lang=None,
    chunk_duration=1800
):
    info = sf.info(input_audio_file)
    duration = info.duration

    output_directory = "./whisper_api_audio_parts"
    os.makedirs(output_directory, exist_ok=True)
    remove_directory_contents(output_directory)

    if duration > chunk_duration:
        # Split long audio into chunks the API can handle.
        cm = f'ffmpeg -i "{input_audio_file}" -f segment -segment_time {chunk_duration} -c:a libvorbis "{output_directory}/output%03d.ogg"'
        run_command(cm)

        chunk_files = sorted(
            [f"{output_directory}/{f}" for f in os.listdir(output_directory) if f.endswith('.ogg')]
        )
    else:
        one_file = f"{output_directory}/output000.ogg"
        cm = f'ffmpeg -i "{input_audio_file}" -c:a libvorbis {one_file}'
        run_command(cm)
        chunk_files = [one_file]

    from openai import OpenAI
    client = OpenAI()

    segments = []
    language = source_lang if source_lang else None
    for i, chunk in enumerate(chunk_files):
        with open(chunk, "rb") as audio_file:
            transcription = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                language=language,
                response_format="verbose_json",
                timestamp_granularities=["segment"],
            )

        try:
            transcript_dict = transcription.model_dump()
        except Exception:
            transcript_dict = transcription.to_dict()

        if language is None:
            logger.info(f'Language detected: {transcript_dict["language"]}')
            language = TO_LANGUAGE_CODE[transcript_dict["language"]]

        # Offset timestamps by the start time of the current chunk.
        chunk_time = chunk_duration * i

        for seg in transcript_dict["segments"]:
            if "start" in seg.keys():
                segments.append(
                    {
                        "text": seg["text"],
                        "start": seg["start"] + chunk_time,
                        "end": seg["end"] + chunk_time,
                    }
                )

    audio = whisperx.load_audio(input_audio_file)
    result = {"segments": segments, "language": language}

    return audio, result


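# Note: openai_api_whisper() relies on the OpenAI client reading the
# OPENAI_API_KEY environment variable for authentication.
# Usage sketch (hypothetical path):
#   audio, result = openai_api_whisper("video_audio.wav", source_lang="en")

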
def find_whisper_models():
    path = WHISPER_MODELS_PATH
    folders = []

    if os.path.exists(path):
        for folder in os.listdir(path):
            folder_path = os.path.join(path, folder)
            if (
                os.path.isdir(folder_path)
                and 'model.bin' in os.listdir(folder_path)
            ):
                folders.append(folder)
    return folders


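# Expected layout for locally converted models, as produced by the
# CTranslate2 conversion in transcribe_speech():
#   ./WHISPER_MODELS/<sanitized_model_name>/model.bin
#   ./WHISPER_MODELS/<sanitized_model_name>/tokenizer_config.json (or tokenizer.json)
#   ./WHISPER_MODELS/<sanitized_model_name>/preprocessor_config.json

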
def transcribe_speech(
    audio_wav,
    asr_model,
    compute_type,
    batch_size,
    SOURCE_LANGUAGE,
    literalize_numbers=True,
    segment_duration_limit=15,
):
    """
    Transcribe speech using a whisper model.

    Parameters:
    - audio_wav (str): Path to the audio file in WAV format.
    - asr_model (str): The whisper model to be loaded.
    - compute_type (str): Type of compute to be used (e.g., 'int8', 'float16').
    - batch_size (int): Batch size for transcription.
    - SOURCE_LANGUAGE (str): Source language for transcription.
    - literalize_numbers (bool): Whether to suppress numerals so that numbers
        are transcribed as words.
    - segment_duration_limit (int): Maximum duration, in seconds, of each
        transcription segment.

    Returns:
    - Tuple containing:
        - audio: Loaded audio file.
        - result: Transcription result as a dictionary.
    """
    if asr_model == "OpenAI_API_Whisper":
        if literalize_numbers:
            logger.info(
                "OpenAI's API Whisper does not support "
                "the literalization of numbers."
            )
        return openai_api_whisper(audio_wav, SOURCE_LANGUAGE)

    # An initial prompt in Simplified Chinese nudges the model toward
    # Simplified output when the source language is "zh".
    prompt = "以下是普通话的句子。" if SOURCE_LANGUAGE == "zh" else None
    SOURCE_LANGUAGE = (
        SOURCE_LANGUAGE if SOURCE_LANGUAGE != "zh-TW" else "zh"
    )
    asr_options = {
        "initial_prompt": prompt,
        "suppress_numerals": literalize_numbers
    }

    if asr_model not in ASR_MODEL_OPTIONS:
        # Custom Hugging Face model: convert it to CTranslate2 format once
        # and cache it under WHISPER_MODELS_PATH.
        base_dir = WHISPER_MODELS_PATH
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)
        model_dir = os.path.join(base_dir, sanitize_file_name(asr_model))

        if not os.path.exists(model_dir):
            from ctranslate2.converters import TransformersConverter

            quantization = "float32"

            try:
                converter = TransformersConverter(
                    asr_model,
                    low_cpu_mem_usage=True,
                    copy_files=[
                        "tokenizer_config.json", "preprocessor_config.json"
                    ]
                )
                converter.convert(
                    model_dir,
                    quantization=quantization,
                    force=False
                )
            except Exception as error:
                if "File tokenizer_config.json does not exist" in str(error):
                    converter._copy_files = [
                        "tokenizer.json", "preprocessor_config.json"
                    ]
                    converter.convert(
                        model_dir,
                        quantization=quantization,
                        force=True
                    )
                else:
                    raise error

        asr_model = model_dir
        logger.info(f"ASR Model: {str(model_dir)}")

    audio = whisperx.load_audio(audio_wav)

    result = load_and_transcribe_audio(
        asr_model, audio, compute_type, SOURCE_LANGUAGE, asr_options,
        batch_size, segment_duration_limit
    )

    if result["language"] == "zh" and not prompt:
        # Without the Simplified Chinese prompt, label the output as
        # Traditional Chinese.
        result["language"] = "zh-TW"
        logger.info("Chinese - Traditional (zh-TW)")

    return audio, result


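# Usage sketch (hypothetical values):
#   audio, result = transcribe_speech(
#       "audio.wav", "large-v3", "float16", 8, "en"
#   )

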
def align_speech(audio, result):
    """
    Aligns speech segments based on the provided audio and result metadata.

    Parameters:
    - audio (array): The audio data in a suitable format for alignment.
    - result (dict): Metadata containing information about the segments
        and language.

    Returns:
    - result (dict): Updated metadata after aligning the segments with
        the audio. This includes character-level alignments if
        'return_char_alignments' is set to True.

    Notes:
    - This function uses language-specific models to align speech segments.
    - It performs language compatibility checks and selects the
        appropriate alignment model.
    - Cleans up memory by releasing resources after alignment.
    """
    # Merge the torch and Hugging Face default alignment model tables.
    DAMHF.update(DAMT)
    if (
        result["language"] not in DAMHF.keys()
        and result["language"] not in EXTRA_ALIGN.keys()
    ):
        logger.warning(
            "Automatic detection: Source language not compatible with align"
        )
        raise ValueError(
            f"Detected language {result['language']} incompatible, "
            "you can select the source language to avoid this error."
        )
    if (
        result["language"] in EXTRA_ALIGN.keys()
        and EXTRA_ALIGN[result["language"]] == ""
    ):
        lang_name = (
            INVERTED_LANGUAGES[result["language"]]
            if result["language"] in INVERTED_LANGUAGES.keys()
            else result["language"]
        )
        logger.warning(
            "No compatible wav2vec2 model found "
            f"for the language '{lang_name}', skipping alignment."
        )
        return result

    result = load_align_and_align_segments(result, audio, DAMHF)

    return result


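# Usage sketch: align right after transcription so word- and character-level
# timestamps are refined before diarization.
#   result = align_speech(audio, result)

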
diarization_models = {
    "pyannote_3.1": "pyannote/speaker-diarization-3.1",
    "pyannote_2.1": "pyannote/speaker-diarization@2.1",
    "disable": "",
}


def reencode_speakers(result):
    # Relabel speakers so they are numbered consecutively from SPEAKER_00,
    # in order of first appearance.
    if not result["segments"] or result["segments"][0]["speaker"] == "SPEAKER_00":
        return result

    speaker_mapping = {}
    counter = 0

    logger.debug("Reencode speakers")

    for segment in result["segments"]:
        old_speaker = segment["speaker"]
        if old_speaker not in speaker_mapping:
            speaker_mapping[old_speaker] = f"SPEAKER_{counter:02d}"
            counter += 1
        segment["speaker"] = speaker_mapping[old_speaker]

    return result


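# Example: segments labeled SPEAKER_03, SPEAKER_03, SPEAKER_07 are remapped
# to SPEAKER_00, SPEAKER_00, SPEAKER_01.

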
def diarize_speech(
    audio_wav,
    result,
    min_speakers,
    max_speakers,
    YOUR_HF_TOKEN,
    model_name="pyannote/speaker-diarization@2.1",
):
    """
    Performs speaker diarization on speech segments.

    Parameters:
    - audio_wav (array): Audio data in WAV format to perform speaker
        diarization.
    - result (dict): Metadata containing information about speech segments
        and alignments.
    - min_speakers (int): Minimum number of speakers expected in the audio.
    - max_speakers (int): Maximum number of speakers expected in the audio.
    - YOUR_HF_TOKEN (str): Your Hugging Face API token for model
        authentication.
    - model_name (str): Name of the speaker diarization model to be used
        (default: "pyannote/speaker-diarization@2.1").

    Returns:
    - result_diarize (dict): Updated metadata after assigning speaker
        labels to segments.

    Notes:
    - This function utilizes a speaker diarization model to label speaker
        segments in the audio.
    - It assigns speakers to word-level segments based on diarization results.
    - Cleans up memory by releasing resources after diarization.
    - If only one speaker is specified, each segment is automatically assigned
        as the first speaker, eliminating the need for diarization inference.
    """

    if max(min_speakers, max_speakers) > 1 and model_name:
        try:
            diarize_model = whisperx.DiarizationPipeline(
                model_name=model_name,
                use_auth_token=YOUR_HF_TOKEN,
                device=os.environ.get("SONITR_DEVICE"),
            )
        except Exception as error:
            error_str = str(error)
            gc.collect()
            torch.cuda.empty_cache()
            if "'NoneType' object has no attribute 'to'" in error_str:
                if model_name == diarization_models["pyannote_2.1"]:
                    raise ValueError(
                        "Accept the license agreement for using Pyannote 2.1."
                        " You need to have an account on Hugging Face and "
                        "accept the license to use the models: "
                        "https://huggingface.co/pyannote/speaker-diarization "
                        "and https://huggingface.co/pyannote/segmentation "
                        "Get your access token here: "
                        "https://hf.co/settings/tokens "
                    )
                elif model_name == diarization_models["pyannote_3.1"]:
                    raise ValueError(
                        "New license for Pyannote 3.1: You need to have an "
                        "account on Hugging Face and accept the license to "
                        "use the models: "
                        "https://huggingface.co/pyannote/speaker-diarization-3.1 "
                        "and https://huggingface.co/pyannote/segmentation-3.0 "
                    )
            else:
                raise error

        random_sleep()
        diarize_segments = diarize_audio(
            diarize_model, audio_wav, min_speakers, max_speakers
        )

        result_diarize = whisperx.assign_word_speakers(
            diarize_segments, result
        )

        for segment in result_diarize["segments"]:
            if "speaker" not in segment:
                segment["speaker"] = "SPEAKER_00"
                logger.warning(
                    f"No speaker detected in {segment['start']}. First TTS "
                    f"will be used for the segment text: {segment['text']} "
                )

        del diarize_model
        gc.collect()
        torch.cuda.empty_cache()
    else:
        result_diarize = result
        result_diarize["segments"] = [
            {**item, "speaker": "SPEAKER_00"}
            for item in result_diarize["segments"]
        ]
    return reencode_speakers(result_diarize)
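

# End-to-end usage sketch (hypothetical values; the HF token is read from an
# environment variable here):
#   audio, result = transcribe_speech("audio.wav", "large-v3", "float16", 8, "en")
#   result = align_speech(audio, result)
#   result_diarize = diarize_speech(
#       "audio.wav",
#       result,
#       min_speakers=1,
#       max_speakers=2,
#       YOUR_HF_TOKEN=os.getenv("YOUR_HF_TOKEN"),
#       model_name=diarization_models["pyannote_3.1"],
#   )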