File size: 12,345 Bytes
039d7b8
 
 
01c308f
039d7b8
 
 
 
01c308f
039d7b8
 
 
 
 
7320deb
 
 
039d7b8
 
3aaea69
039d7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed0d9df
 
 
 
 
 
039d7b8
 
 
 
 
 
 
 
ed0d9df
039d7b8
 
bb23e2c
039d7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
from typing import Iterator

from io import StringIO
import os
import pathlib
import tempfile

# External programs
import whisper
import ffmpeg

# UI
import gradio as gr

from download import ExceededMaximumDuration, download_url
from utils import slugify, write_srt, write_vtt
from vad import NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription

# --- Limits -------------------------------------------------------------
# Longest accepted input, in seconds. Use -1 to switch the limit off.
# (The original default was 600.)
DEFAULT_INPUT_AUDIO_MAX_DURATION = 3605

# Remove every uploaded/recorded source file once it has been processed,
# so that the server does not slowly fill its disk.
DELETE_UPLOADED_FILES = True

# Gradio may truncate file names without preserving the extension, so the
# file-name prefix is shortened here instead (the extension is kept).
MAX_FILE_PREFIX_LENGTH = 17

# Languages offered in the UI dropdown (display names; lowercased before
# being handed to Whisper). Order is preserved; the UI sorts them.
LANGUAGES = [
    "English", "Chinese", "German", "Spanish", "Russian",
    "Korean", "French", "Japanese", "Portuguese", "Turkish",
    "Polish", "Catalan", "Dutch", "Arabic", "Swedish",
    "Italian", "Indonesian", "Hindi", "Finnish", "Vietnamese",
    "Hebrew", "Ukrainian", "Greek", "Malay", "Czech",
    "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
    "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian",
    "Latin", "Maori", "Malayalam", "Welsh", "Slovak",
    "Telugu", "Persian", "Latvian", "Bengali", "Serbian",
    "Azerbaijani", "Slovenian", "Kannada", "Estonian", "Macedonian",
    "Breton", "Basque", "Icelandic", "Armenian", "Nepali",
    "Mongolian", "Bosnian", "Kazakh", "Albanian", "Swahili",
    "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
    "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan",
    "Georgian", "Belarusian", "Tajik", "Sindhi", "Gujarati",
    "Amharic", "Yiddish", "Lao", "Uzbek", "Faroese",
    "Haitian Creole", "Pashto", "Turkmen", "Nynorsk", "Maltese",
    "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan", "Tagalog",
    "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
    "Hausa", "Bashkir", "Javanese", "Sundanese",
]

class WhisperTranscriber:
    """Gradio-facing wrapper that resolves an audio source, runs Whisper (optionally
    behind a VAD) and writes SRT/VTT/TXT outputs.

    Whisper models are cached per model name; the Silero VAD model is created lazily
    the first time a Silero VAD mode is requested.
    """

    def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
        # Loaded Whisper models keyed by model name ("tiny", "base", ...).
        self.model_cache = dict()

        # Created on demand by _create_silero_config.
        self.vad_model = None
        # Maximum input length in seconds; for local files a value <= 0 disables the check.
        self.inputAudioMaxDuration = inputAudioMaxDuration
        # When True, the source audio file is removed after each request.
        self.deleteUploadedFiles = deleteUploadedFiles

    def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
        """Gradio callback: transcribe the selected source and return (files, text, vtt).

        The source is urlData if given, otherwise uploadFile, otherwise microphoneData.
        If the source is too long, an empty file list and an "[ERROR]" message are
        returned instead of raising.
        """
        try:
            source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)

            try:
                # An unselected dropdown may arrive as None (or ""): let Whisper auto-detect.
                selectedLanguage = languageName.lower() if languageName else None
                selectedModel = modelName if modelName is not None else "base"

                model = self.model_cache.get(selectedModel)
                if model is None:
                    model = whisper.load_model(selectedModel)
                    self.model_cache[selectedModel] = model

                # Execute whisper
                result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)

                # Write result into a fresh, uniquely named temporary directory
                downloadDirectory = tempfile.mkdtemp()

                filePrefix = slugify(sourceName, allow_unicode=True)
                download, text, vtt = self.write_result(result, filePrefix, downloadDirectory)

                return download, text, vtt

            finally:
                # Cleanup source
                if self.deleteUploadedFiles:
                    self.__delete_source(source)

        except ExceededMaximumDuration as e:
            return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"

    def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None, 
                        vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions):
        """Transcribe audio_path with model, optionally splitting the audio with a VAD first.

        Args:
            model: a loaded Whisper model.
            audio_path: path of the audio file to transcribe.
            language: language to force, or None/"" to use the detected language.
            task: passed to model.transcribe; overridden by decodeOptions["task"] if present.
            vad: "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps" or
                "periodic-vad"; any other value (e.g. None or "none") transcribes the
                whole file in a single call.
            vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow: VAD settings
                (seconds). For "periodic-vad", vadMaxMergeSize is the period length.
            **decodeOptions: forwarded to model.transcribe. "initial_prompt" is removed
                and only prepended to the prompt of the first segment.

        Returns:
            The transcription result; write_result expects it to contain "text",
            "segments" and "language".
        """
        initial_prompt = decodeOptions.pop('initial_prompt', None)

        if 'task' in decodeOptions:
            task = decodeOptions.pop('task')

        def whisperCallable(audio, segment_index, prompt, detected_language):
            # The user-supplied initial prompt only seeds the first segment; later
            # segments get just the prompt handed in by the VAD transcriber.
            segment_prompt = self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
            return model.transcribe(audio,
                                    language=language if language else detected_language,
                                    task=task,
                                    initial_prompt=segment_prompt,
                                    **decodeOptions)

        # How each Silero VAD mode treats the non-speech gaps between speech segments.
        silero_strategies = {
            # Non-speech gaps are transcribed as well
            'silero-vad': NonSpeechStrategy.CREATE_SEGMENT,
            # Non-speech gaps are simply ignored
            'silero-vad-skip-gaps': NonSpeechStrategy.SKIP,
            # Speech segments are expanded into the non-speech gaps
            'silero-vad-expand-into-gaps': NonSpeechStrategy.EXPAND_SEGMENT,
        }

        if vad in silero_strategies:
            # _create_silero_config also initializes self.vad_model on first use.
            config = self._create_silero_config(silero_strategies[vad], vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
            return self.vad_model.transcribe(audio_path, whisperCallable, config)

        if vad == 'periodic-vad':
            # Very simple VAD: mark every vadMaxMergeSize-second window as speech. This makes it
            # less likely that Whisper enters an infinite loop, but it may break in the middle
            # of a sentence, causing some artifacts.
            periodic_vad = VadPeriodicTranscription()
            config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
            return periodic_vad.transcribe(audio_path, whisperCallable, config)

        # No VAD: transcribe the whole file in one call
        return whisperCallable(audio_path, 0, None, None)

    def _concat_prompt(self, prompt1, prompt2):
        """Join two optional prompts with a space; a None part is skipped (None if both are None)."""
        if prompt1 is None:
            return prompt2
        if prompt2 is None:
            return prompt1
        return prompt1 + " " + prompt2

    def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
        """Return a TranscriptionConfig for Silero VAD, creating self.vad_model on first use.

        vadPadding is applied symmetrically to both sides of each segment.
        """
        if self.vad_model is None:
            self.vad_model = VadSileroTranscription()

        return TranscriptionConfig(non_speech_strategy=non_speech_strategy,
                                   max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
                                   segment_padding_left=vadPadding, segment_padding_right=vadPadding,
                                   max_prompt_window=vadPromptWindow)

    def write_result(self, result: dict, source_name: str, output_dir: str):
        """Write "<source_name>-subs.srt", "-subs.vtt" and "-transcript.txt" into output_dir.

        Returns:
            (list of written file paths, transcript text, VTT subtitle text)
        """
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(output_dir, exist_ok=True)

        text = result["text"]
        language = result["language"]
        languageMaxLineWidth = self.__get_max_line_width(language)

        print("Max line width " + str(languageMaxLineWidth))
        vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
        srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)

        output_files = [
            self.__create_file(srt, output_dir, source_name + "-subs.srt"),
            self.__create_file(vtt, output_dir, source_name + "-subs.vtt"),
            self.__create_file(text, output_dir, source_name + "-transcript.txt"),
        ]

        return output_files, text, vtt

    def clear_cache(self):
        """Drop all cached Whisper models and the Silero VAD model."""
        self.model_cache = dict()
        self.vad_model = None

    def __get_source(self, urlData, uploadFile, microphoneData):
        """Resolve the audio file to process and a shortened display name for it.

        Priority: urlData (downloaded), then uploadFile, then microphoneData.

        Raises:
            ExceededMaximumDuration: the local file is longer than inputAudioMaxDuration
                (the file is deleted first when deleteUploadedFiles is set).
            ValueError: no source was provided at all.
        """
        if urlData:
            # Download from URL (YouTube, etc.)
            source = download_url(urlData, self.inputAudioMaxDuration)[0]
        else:
            # File input
            source = uploadFile if uploadFile is not None else microphoneData

            if source is None:
                raise ValueError("No audio source provided: enter a URL, upload a file or record from the microphone")

            if self.inputAudioMaxDuration > 0:
                # Calculate audio length
                audioDuration = float(ffmpeg.probe(source)["format"]["duration"])

                if audioDuration > self.inputAudioMaxDuration:
                    # The caller never learns this path, so clean up the rejected upload here.
                    if self.deleteUploadedFiles:
                        self.__delete_source(source)
                    raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")

        file_path = pathlib.Path(source)
        sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix

        return source, sourceName

    def __delete_source(self, source: str):
        """Best-effort removal of a source file.

        Failures are reported rather than raised: this runs inside a `finally`, where an
        exception would replace the transcription result (or the original error).
        """
        print("Deleting source file " + source)
        try:
            os.remove(source)
        except OSError as e:
            print("Failed to delete source file " + source + ": " + str(e))

    def __get_max_line_width(self, language: str) -> int:
        """Return the maximum subtitle line width (in characters) for language."""
        if language and language.lower() in ["japanese", "ja", "chinese", "zh"]:
            # Chinese characters and kana are wider, so limit line length to 40 characters
            return 40
        # TODO: Add more languages
        # 80 latin characters should fit on a 1080p/720p screen
        return 80

    def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
        """Render segments as a "vtt" or "srt" subtitle string.

        Raises:
            ValueError: format is neither "vtt" nor "srt".
        """
        segmentStream = StringIO()

        if format == 'vtt':
            write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        elif format == 'srt':
            write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        else:
            raise ValueError("Unknown format " + format)

        return segmentStream.getvalue()

    def __create_file(self, text: str, directory: str, fileName: str) -> str:
        """Write text as UTF-8 to directory/fileName (overwriting it) and return the path."""
        path = os.path.join(directory, fileName)
        with open(path, 'w', encoding="utf-8") as file:
            file.write(text)

        return path
       
 # translate_checkbox = gr.inputs.Checkbox(label = "Translate to English", default=False)
 # transcription_tb = gr.Textbox(label="Transcription", lines=10, max_lines=20)
 # translation_tb = gr.Textbox(label="Translation", lines=10, max_lines=20)
 # detected_lang = gr.outputs.HTML(label="Detected Language")
 


def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
    """Build the Gradio interface around a WhisperTranscriber and launch it.

    Args:
        inputAudioMaxDuration: maximum input length in seconds. Values <= 0 disable the
            check for local files and omit the limit from the description.
        share: forwarded to ``demo.launch(share=...)``.
        server_name: forwarded to ``demo.launch(server_name=...)``.
    """
    ui = WhisperTranscriber(inputAudioMaxDuration)

    # Description text (Chinese UI): what Whisper is, and a recommendation to use
    # Silero VAD for non-English audio longer than 20 minutes.
    ui_description = "Whisper是一个语音转文字模型,经过多个语音数据集的训练而成。也可以进行多语言的识别任务和翻译(多种语言翻译成英文)" 
    ui_description += "\n\n\n\n对于时长大于20分钟的非英语音频文件,建议选择VAD选项中的Silero VAD (语音活动检测器)。"

    if inputAudioMaxDuration > 0:
        # Shows the maximum audio duration in seconds
        ui_description += "\n\n" + "音频最大时长: " + str(inputAudioMaxDuration) + " 秒"

    # The input order must match the positional parameters of WhisperTranscriber.transcribe_webui.
    demo = gr.Interface(fn=ui.transcribe_webui, description=ui_description, inputs=[
        gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
        # No default: an empty selection lets Whisper detect the language.
        gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
        gr.Text(label="URL (YouTube, etc.)"),
        gr.Audio(source="upload", type="filepath", label="Upload Audio"), 
        gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
        # Default to "transcribe" so a None task is never forwarded to Whisper.
        gr.Dropdown(choices=["transcribe", "translate"], value="transcribe", label="Task"),
        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value="none", label="VAD"),
        gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
        gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
        gr.Number(label="VAD - Padding (s)", precision=None, value=1),
        gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
    ], outputs=[
        gr.File(label="Download"),
        gr.Text(label="Transcription"), 
        gr.Text(label="Segments")
    ])

    demo.launch(share=share, server_name=server_name)

if __name__ == "__main__":
    # Script entry point: start the web UI with the default input-duration limit.
    create_ui(inputAudioMaxDuration=DEFAULT_INPUT_AUDIO_MAX_DURATION)