|
import argparse |
|
import faulthandler |
|
import gc |
|
import os |
|
import tempfile |
|
|
|
import torch |
|
import whisperx |
|
|
|
from whisperx.asr import FasterWhisperPipeline |
|
|
|
|
|
|
|
def get_device():
    """Return the torch device string: "cuda" when a GPU is visible, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
def generate_subtitles_from_audio(
    audio_file_path: str,
    model: FasterWhisperPipeline,
    batch_size: int = 8,
    language: str = "ru",
):
    """Transcribe an audio file with a preloaded WhisperX pipeline.

    Args:
        audio_file_path: Path to the audio file to transcribe.
        model: Preloaded WhisperX ASR pipeline.
        batch_size: Inference batch size forwarded to ``model.transcribe``.
        language: Language code for transcription. Was hard-coded to "ru";
            the default preserves the original behavior while letting
            callers transcribe other languages.

    Returns:
        The transcription result produced by ``model.transcribe``.
    """
    audio = whisperx.load_audio(audio_file_path)
    return model.transcribe(audio, batch_size=batch_size, language=language)
|
|
|
|
|
def generate_subtitles_from_video(
    video_path: str,
    model_name: str = "base",
    batch_size: int = 8,
    compute_type: str = "int8",
):
    """Extract the audio track of a video and transcribe it with WhisperX.

    Args:
        video_path: Path to the input video file.
        model_name: Whisper model identifier passed to ``whisperx.load_model``.
        batch_size: Inference batch size.
        compute_type: Numeric type for the model (e.g. "int8", "float16").

    Returns:
        The transcription result from ``generate_subtitles_from_audio``.
    """
    # mkstemp() returns an OPEN OS-level file descriptor plus the path; the
    # original code discarded the fd, leaking it for the process lifetime.
    fd, audio_file = tempfile.mkstemp()
    os.close(fd)

    device = get_device()

    try:
        print("Loading model:")
        model = whisperx.load_model(model_name, device, compute_type=compute_type, language="ru")
        print("Parsing audio:")
        # NOTE(review): parse_audio is neither defined nor imported in this
        # file — presumably it extracts the audio track into audio_file;
        # confirm it is actually in scope at runtime.
        parse_audio(video_path, audio_file)
        print("Generating subtitles:")
        result = generate_subtitles_from_audio(audio_file, model, batch_size=batch_size)
    finally:
        # Remove the temp file even when loading/parsing/transcription fails;
        # the original only cleaned up on the success path.
        os.remove(audio_file)

    # Drop the model reference and force a collection to release its memory
    # before returning (helpful with large GPU/CPU model weights).
    del model
    gc.collect()
    return result
|
|
|
|
|
def add_whisper_args(arg_parser: argparse.ArgumentParser):
    """Register the CLI options consumed by the transcription pipeline."""
    add = arg_parser.add_argument  # local alias to cut repetition
    add("video", help="video file")
    add(
        "--compute_type",
        default="int8",
        choices=["int8", "float16", "float32"],
        help="Base type for model",
    )
    add("--whisper_model", default="large-v2", help="model to use")
    add("--batch_size", type=int, default=4, help="Batch size for inference")
|
|
|
|
|
if __name__ == "__main__":
    # Dump Python tracebacks on hard crashes (useful with native ASR stacks).
    faulthandler.enable()

    cli = argparse.ArgumentParser(description="Get video subtitles from a video")
    add_whisper_args(cli)
    opts = cli.parse_args()

    subtitles = generate_subtitles_from_video(
        opts.video,
        opts.whisper_model,
        opts.batch_size,
        opts.compute_type,
    )
    print(subtitles)
|
|