# local_transcript.py
import argparse
import faulthandler
import gc
import os
import subprocess
import tempfile

import torch
import whisperx
from whisperx.asr import FasterWhisperPipeline

def get_device() -> str:
    # Prefer CUDA when available, otherwise fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # MPS is left disabled: faster-whisper runs on CTranslate2, which has no
    # Metal backend, so "mps" would fail at model load time.
    # device = "mps" if torch.backends.mps.is_available() else device
    return device
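
def parse_audio(video_path: str, audio_file_path: str) -> None:
    # Minimal sketch of the audio-extraction step: pull the audio track out of
    # the video as 16 kHz mono WAV, the sample rate Whisper models expect.
    # Assumes the ffmpeg binary is on PATH, which whisperx already requires.
    subprocess.run(
        [
            "ffmpeg",
            "-y",            # overwrite the temp file created by mkstemp
            "-i", video_path,
            "-vn",           # drop the video stream
            "-ac", "1",      # downmix to mono
            "-ar", "16000",  # resample to 16 kHz
            "-f", "wav",     # be explicit about the WAV container
            audio_file_path,
        ],
        check=True,
        capture_output=True,
    )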

def generate_subtitles_from_audio(
    audio_file_path: str,
    model: FasterWhisperPipeline,
    batch_size: int = 8,
):
    # whisperx.load_audio decodes the file with ffmpeg into a 16 kHz mono
    # float32 numpy array, which the pipeline consumes directly.
    audio = whisperx.load_audio(audio_file_path)
    # Language is pinned to Russian so the model skips auto-detection.
    result = model.transcribe(audio, batch_size=batch_size, language="ru")
    return result

def generate_subtitles_from_video(
    video_path: str,
    model_name: str = "base",
    batch_size: int = 8,
    compute_type: str = "int8",
):
    # mkstemp returns an open file descriptor along with the path; close the
    # descriptor so it is not leaked while ffmpeg writes to the path.
    fd, audio_file = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    device = get_device()
    print("Loading model:")
    model = whisperx.load_model(
        model_name, device, compute_type=compute_type, language="ru"
    )
    print("Parsing audio:")
    parse_audio(video_path, audio_file)
    print("Generating subtitles:")
    result = generate_subtitles_from_audio(audio_file, model, batch_size=batch_size)
    os.remove(audio_file)
    # Drop the model reference and force a collection so the weights are
    # released before the caller continues.
    del model
    gc.collect()
    return result

def add_whisper_args(arg_parser: argparse.ArgumentParser):
    arg_parser.add_argument("video", help="Path to the input video file")
    arg_parser.add_argument(
        "--compute_type",
        help="Compute type (quantization) for the model",
        default="int8",
        choices=["int8", "float16", "float32"],
    )
    arg_parser.add_argument("--whisper_model", help="Whisper model to use", default="large-v2")
    arg_parser.add_argument("--batch_size", help="Batch size for inference", default=4, type=int)

if __name__ == "__main__":
    faulthandler.enable()
    parser = argparse.ArgumentParser(description="Get video subtitles from a video")
    add_whisper_args(parser)
    args = parser.parse_args()
    print(
        generate_subtitles_from_video(
            args.video, args.whisper_model, args.batch_size, args.compute_type
        )
    )
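
# Example invocation (the file name is hypothetical):
#   python local_transcript.py lecture.mp4 --whisper_model large-v2 --batch_size 4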