import argparse
import faulthandler
import gc
import os
import subprocess
import tempfile

import torch
import whisperx
from whisperx.asr import FasterWhisperPipeline


def get_device():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device = "mps" if torch.backends.mps.is_available() else device
    return device


def parse_audio(video_path: str, audio_file: str):
    # NOTE: parse_audio was called below but never defined or imported in the
    # original file. This is a minimal reconstruction, assuming ffmpeg is on
    # PATH: extract a 16 kHz mono PCM WAV track, the format Whisper expects.
    subprocess.run(
        ["ffmpeg", "-y", "-i", video_path,
         "-vn", "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1",
         audio_file],
        check=True,
    )


def generate_subtitles_from_audio(
    audio_file_path: str,
    model: FasterWhisperPipeline,
    batch_size: int = 8,
):
    # Transcribe the extracted audio; the language is hard-coded to Russian.
    audio = whisperx.load_audio(audio_file_path)
    result = model.transcribe(audio, batch_size=batch_size, language="ru")
    return result


def generate_subtitles_from_video(
    video_path: str,
    model_name: str = "base",
    batch_size: int = 8,
    compute_type: str = "int8",
):
    # Use a .wav suffix so ffmpeg can infer the output container, and close
    # the descriptor returned by mkstemp so it does not leak.
    fd, audio_file = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    device = get_device()

    print("Loading model:")
    model = whisperx.load_model(model_name, device, compute_type=compute_type, language="ru")

    print("Parsing audio:")
    parse_audio(video_path, audio_file)

    print("Generating subtitles:")
    result = generate_subtitles_from_audio(audio_file, model, batch_size=batch_size)

    # Clean up the temporary audio file and release model memory.
    os.remove(audio_file)
    del model
    gc.collect()
    return result


def add_whisper_args(arg_parser: argparse.ArgumentParser):
    arg_parser.add_argument("video", help="video file")
    arg_parser.add_argument("--compute_type", help="Compute type for the model",
                            default="int8", choices=["int8", "float16", "float32"])
    arg_parser.add_argument("--whisper_model", help="model to use", default="large-v2")
    arg_parser.add_argument("--batch_size", help="Batch size for inference",
                            default=4, type=int)


if __name__ == "__main__":
    faulthandler.enable()
    parser = argparse.ArgumentParser(description="Get video subtitles from a video")
    add_whisper_args(parser)
    args = parser.parse_args()
    print(
        generate_subtitles_from_video(
            args.video,
            model_name=args.whisper_model,
            batch_size=args.batch_size,
            compute_type=args.compute_type,
        )
    )
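
# Example invocation (the filename generate_subtitles.py is an assumption,
# not given in the original; CUDA is used automatically when available):
#
#   python generate_subtitles.py input.mp4 \
#       --whisper_model large-v2 --batch_size 4 --compute_type int8
#
# The script prints the raw whisperx transcription result, a dict whose
# "segments" list carries the timed text for each detected utterance.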