File size: 4,286 Bytes
530ac2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from typing import List

import ffmpeg
from src.config import ModelConfig
from src.hooks.progressListener import ProgressListener
from src.modelCache import ModelCache
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer

class DummyWhisperContainer(AbstractWhisperContainer):
    """
    A stand-in whisper container that never loads a real model.

    Its methods only print "[Dummy]" messages, and the callbacks it creates return
    placeholder transcripts. This makes it possible to exercise the surrounding
    pipeline without downloading or running Whisper.
    """

    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                       download_root: str = None,
                       cache: ModelCache = None, models: List[ModelConfig] = None):
        # A literal `[]` default would be one list object shared by every instance.
        # Use None as the sentinel and hand the base class a fresh empty list instead.
        if models is None:
            models = []
        super().__init__(model_name, device, compute_type, download_root, cache, models)

    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
        passing the container to a subprocess.

        The dummy implementation downloads nothing; it only prints a message.
        """
        print("[Dummy] Ensuring that the model is downloaded")

    def _create_model(self):
        """
        Pretend to create the model: print a message and return None (there is no real model).
        """
        print("[Dummy] Creating dummy whisper model " + self.model_name + " for device " + str(self.device))
        return None

    def create_callback(self, language: str = None, task: str = None, 
                        prompt_strategy: AbstractPromptStrategy = None, 
                        **decodeOptions: dict) -> AbstractWhisperCallback:
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        prompt_strategy: AbstractPromptStrategy
            The prompt strategy to use. If not specified, the prompt from Whisper will be used.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A DummyWhisperCallback bound to this container. language, task and prompt_strategy
        are forwarded to it as keyword options alongside decodeOptions.
        """
        return DummyWhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
    
class DummyWhisperCallback(AbstractWhisperCallback):
    """
    Callback that returns a fixed placeholder transcript instead of running a model.

    The transcript is shaped like a (faster-)whisper result.
    """

    def __init__(self, model_container: DummyWhisperContainer, **decodeOptions: dict):
        self.model_container = model_container
        # All keyword arguments are kept unchanged. When this callback is built by
        # DummyWhisperContainer.create_callback, that includes language, task and prompt_strategy.
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Pretend to transcribe the given audio file or data.

        Returns a single placeholder segment that spans the whole audio.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The index of the segment being transcribed. It is only used in the placeholder text.
        prompt: str
            The prompt for this segment. The dummy implementation ignores it.
        detected_language: str
            The language detected so far, if any. It is reported as the result's language;
            "en" is used when it is None.
        progress_listener: ProgressListener
            A callback to receive progress updates. Only on_finished() is called.

        Returns
        -------
        A dict with "segments", "text", "language", "language_probability" and "duration" keys.
        The "end" of the single segment and the "duration" are the audio length in seconds, as a float.
        """
        print("[Dummy] Invoking dummy whisper callback for segment " + str(segment_index))

        # Estimate length in seconds
        if isinstance(audio, str):
            # ffprobe's JSON output reports the duration as a string (e.g. "12.345000").
            # Convert it so "end"/"duration" are numeric, matching the in-memory branch below.
            audio_length = float(ffmpeg.probe(audio)["format"]["duration"])
        else:
            # Format is pcm_s16le at a sample rate of 16000, loaded as a float32 array,
            # so the number of samples divided by 16000 gives seconds.
            audio_length = len(audio) / 16000

        placeholder_text = "Dummy text for segment " + str(segment_index)

        # Convert the segments to a format that is easier to serialize
        whisper_segments = [{
            "text": placeholder_text,
            "start": 0,
            "end": audio_length,

            # Extra fields added by faster-whisper
            "words": []
        }]

        result = {
            "segments": whisper_segments,
            "text": placeholder_text,
            "language": "en" if detected_language is None else detected_language,

            # Extra fields added by faster-whisper
            "language_probability": 1.0,
            "duration": audio_length,
        }

        if progress_listener is not None:
            progress_listener.on_finished()
        return result