Spaces:
Runtime error
Runtime error
File size: 4,286 Bytes
530ac2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
from typing import List
import ffmpeg
from src.config import ModelConfig
from src.hooks.progressListener import ProgressListener
from src.modelCache import ModelCache
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
class DummyWhisperContainer(AbstractWhisperContainer):
    """
    A stand-in whisper container that does not load a real model.

    Its methods only print what they would do, and the callbacks it creates
    (DummyWhisperCallback) return placeholder transcription results.
    """

    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = None):
        """
        Parameters
        ----------
        model_name: str
            Name of the model; forwarded to the base container.
        device: str
            Device name; forwarded to the base container.
        compute_type: str
            Compute type; forwarded to the base container.
        download_root: str
            Download directory; forwarded to the base container.
        cache: ModelCache
            Optional model cache; forwarded to the base container.
        models: List[ModelConfig]
            Model configurations forwarded to the base container. When omitted,
            a new empty list is used.
        """
        # Create a fresh list per call: a `= []` default would be a single list
        # object shared by every container constructed without `models`.
        super().__init__(model_name, device, compute_type, download_root, cache,
                         [] if models is None else models)

    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
        passing the container to a subprocess.

        The dummy implementation downloads nothing; it only prints a message.
        """
        print("[Dummy] Ensuring that the model is downloaded")

    def _create_model(self):
        """Print a message and return None in place of a real model."""
        print("[Dummy] Creating dummy whisper model " + self.model_name + " for device " + str(self.device))
        return None

    def create_callback(self, language: str = None, task: str = None,
                        prompt_strategy: AbstractPromptStrategy = None,
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        prompt_strategy: AbstractPromptStrategy
            The prompt strategy to use. If not specified, the prompt from Whisper will be used.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A DummyWhisperCallback object. language, task and prompt_strategy are
        passed to it as keyword arguments alongside decodeOptions.
        """
        return DummyWhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
class DummyWhisperCallback(AbstractWhisperCallback):
    """
    Callback that returns a placeholder transcription instead of running a model.

    Each call to invoke() yields a single segment spanning the whole audio.
    """

    # Sample rate assumed for in-memory audio when estimating its duration.
    SAMPLE_RATE = 16000

    def __init__(self, model_container: DummyWhisperContainer, **decodeOptions: dict):
        """
        Parameters
        ----------
        model_container: DummyWhisperContainer
            The container that created this callback.
        decodeOptions: dict
            All remaining keyword arguments, stored as-is. DummyWhisperContainer.create_callback
            also passes language, task and prompt_strategy as keywords, so they are stored here too.
        """
        self.model_container = model_container
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Perform a dummy transcription of the given audio file or data.

        No model is run. The result contains one segment covering the whole audio,
        with placeholder text.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            Path to an audio file (probed with ffmpeg for its duration), or in-memory
            audio samples whose length is divided by SAMPLE_RATE to estimate the duration.
        segment_index: int
            Index of the segment being transcribed; included in the placeholder text.
        prompt: str
            The prompt for this segment. Not used by the dummy implementation.
        detected_language: str
            Language to report in the result; "en" is reported when this is None.
        progress_listener: ProgressListener
            Optional listener whose on_finished() is called once the result is built.

        Returns
        -------
        dict
            A result with "segments", "text", "language", "language_probability" and
            "duration" keys. "end" and "duration" are the audio length in seconds.
        """
        print("[Dummy] Invoking dummy whisper callback for segment " + str(segment_index))

        # Estimate length in seconds
        if isinstance(audio, str):
            # ffprobe's JSON output reports the duration as a string (e.g. "12.345000"),
            # so convert it to keep "end"/"duration" numeric, as in the other branch.
            audio_length = float(ffmpeg.probe(audio)["format"]["duration"])
        else:
            # Format is pcm_s16le at a sample rate of 16000, loaded as a float32 array.
            audio_length = len(audio) / self.SAMPLE_RATE

        text = "Dummy text for segment " + str(segment_index)

        # Convert the segments to a format that is easier to serialize
        whisper_segments = [{
            "text": text,
            "start": 0,
            "end": audio_length,
            # Extra fields added by faster-whisper
            "words": []
        }]

        result = {
            "segments": whisper_segments,
            "text": text,
            "language": "en" if detected_language is None else detected_language,
            # Extra fields added by faster-whisper
            "language_probability": 1.0,
            "duration": audio_length,
        }

        if progress_listener is not None:
            progress_listener.on_finished()
        return result