Arnaudding001 committed on
Commit
039d7b8
1 Parent(s): 5a2469e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +252 -86
app.py CHANGED
@@ -1,88 +1,254 @@
 
 
 
1
  import os
2
- import gradio as gr
 
 
 
3
  import whisper
4
- from whisper import tokenizer
5
- import time
6
-
7
- current_size = 'base'
8
- model = whisper.load_model(current_size)
9
- AUTO_DETECT_LANG = "Auto Detection"
10
-
11
- def transcribe(audio, state={}, model_size='base', delay=1.2, lang=None, translate=False):
12
- time.sleep(delay - 1)
13
-
14
- global current_size
15
- global model
16
- if model_size != current_size:
17
- current_size = model_size
18
- model = whisper.load_model(current_size)
19
-
20
- transcription = model.transcribe(
21
- audio,
22
- language = lang if lang != AUTO_DETECT_LANG else None
23
- )
24
- state['transcription'] += transcription['text'] + " "
25
-
26
- if translate:
27
- x = whisper.load_audio(audio)
28
- x = whisper.pad_or_trim(x)
29
- mel = whisper.log_mel_spectrogram(x).to(model.device)
30
-
31
- options = whisper.DecodingOptions(task = "translation")
32
- translation = whisper.decode(model, mel, options)
33
-
34
- state['translation'] += translation.text + " "
35
-
36
- return state['transcription'], state['translation'], state, f"detected language: {transcription['language']}"
37
-
38
-
39
- title = "OpenAI's Whisper Real-time Demo"
40
- description = "A simple demo of OpenAI's [**Whisper**](https://github.com/openai/whisper) speech recognition model. This demo runs on a CPU. For faster inference choose 'tiny' model size and set the language explicitly."
41
-
42
- model_size = gr.Dropdown(label="Model size", choices=['base', 'tiny', 'small', 'medium', 'large'], value='base')
43
-
44
- delay_slider = gr.inputs.Slider(minimum=1, maximum=5, default=1.2, label="Rate of transcription")
45
-
46
- available_languages = sorted(tokenizer.TO_LANGUAGE_CODE.keys())
47
- available_languages = [lang.capitalize() for lang in available_languages]
48
- available_languages = [AUTO_DETECT_LANG]+available_languages
49
-
50
- lang_dropdown = gr.inputs.Dropdown(choices=available_languages, label="Language", default=AUTO_DETECT_LANG, type="value")
51
-
52
- if lang_dropdown==AUTO_DETECT_LANG:
53
- lang_dropdown=None
54
-
55
- translate_checkbox = gr.inputs.Checkbox(label="Translate to English", default=False)
56
-
57
-
58
-
59
- transcription_tb = gr.Textbox(label="Transcription", lines=10, max_lines=20)
60
- translation_tb = gr.Textbox(label="Translation", lines=10, max_lines=20)
61
- detected_lang = gr.outputs.HTML(label="Detected Language")
62
-
63
- state = gr.State({"transcription": "", "translation": ""})
64
-
65
- gr.Interface(
66
- fn=transcribe,
67
- inputs=[
68
- gr.Audio(source="microphone", type="filepath", streaming=True),
69
- state,
70
- model_size,
71
- delay_slider,
72
- lang_dropdown,
73
- translate_checkbox
74
- ],
75
- outputs=[
76
- transcription_tb,
77
- translation_tb,
78
- state,
79
- detected_lang
80
- ],
81
- live=True,
82
- allow_flagging='never',
83
- title=title,
84
- description=description,
85
- ).launch(
86
- # enable_queue=True,
87
- # debug=True
88
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterator
2
+
3
+ from io import StringIO
4
  import os
5
+ import pathlib
6
+ import tempfile
7
+
8
+ # External programs
9
  import whisper
10
+ import ffmpeg
11
+
12
+ # UI
13
+ import gradio as gr
14
+
15
+ from src.download import ExceededMaximumDuration, download_url
16
+ from src.utils import slugify, write_srt, write_vtt
17
+ from src.vad import NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
18
+
19
+ # Limitations (set to -1 to disable)
20
+ DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
21
+
22
+ # Whether or not to automatically delete all uploaded files, to save disk space
23
+ DELETE_UPLOADED_FILES = True
24
+
25
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
26
+ MAX_FILE_PREFIX_LENGTH = 17
27
+
28
+ LANGUAGES = [
29
+ "English", "Chinese", "German", "Spanish", "Russian", "Korean",
30
+ "French", "Japanese", "Portuguese", "Turkish", "Polish", "Catalan",
31
+ "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi",
32
+ "Finnish", "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay",
33
+ "Czech", "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
34
+ "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin",
35
+ "Maori", "Malayalam", "Welsh", "Slovak", "Telugu", "Persian",
36
+ "Latvian", "Bengali", "Serbian", "Azerbaijani", "Slovenian",
37
+ "Kannada", "Estonian", "Macedonian", "Breton", "Basque", "Icelandic",
38
+ "Armenian", "Nepali", "Mongolian", "Bosnian", "Kazakh", "Albanian",
39
+ "Swahili", "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
40
+ "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan", "Georgian",
41
+ "Belarusian", "Tajik", "Sindhi", "Gujarati", "Amharic", "Yiddish",
42
+ "Lao", "Uzbek", "Faroese", "Haitian Creole", "Pashto", "Turkmen",
43
+ "Nynorsk", "Maltese", "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan",
44
+ "Tagalog", "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
45
+ "Hausa", "Bashkir", "Javanese", "Sundanese"
46
+ ]
47
+
48
+ class WhisperTranscriber:
49
+ def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
50
+ self.model_cache = dict()
51
+
52
+ self.vad_model = None
53
+ self.inputAudioMaxDuration = inputAudioMaxDuration
54
+ self.deleteUploadedFiles = deleteUploadedFiles
55
+
56
+ def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
57
+ try:
58
+ source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
59
+
60
+ try:
61
+ selectedLanguage = languageName.lower() if len(languageName) > 0 else None
62
+ selectedModel = modelName if modelName is not None else "base"
63
+
64
+ model = self.model_cache.get(selectedModel, None)
65
+
66
+ if not model:
67
+ model = whisper.load_model(selectedModel)
68
+ self.model_cache[selectedModel] = model
69
+
70
+ # Execute whisper
71
+ result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
72
+
73
+ # Write result
74
+ downloadDirectory = tempfile.mkdtemp()
75
+
76
+ filePrefix = slugify(sourceName, allow_unicode=True)
77
+ download, text, vtt = self.write_result(result, filePrefix, downloadDirectory)
78
+
79
+ return download, text, vtt
80
+
81
+ finally:
82
+ # Cleanup source
83
+ if self.deleteUploadedFiles:
84
+ print("Deleting source file " + source)
85
+ os.remove(source)
86
+
87
+ except ExceededMaximumDuration as e:
88
+ return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
89
+
90
+ def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
91
+ vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
92
+
93
+ initial_prompt = decodeOptions.pop('initial_prompt', None)
94
+
95
+ if ('task' in decodeOptions):
96
+ task = decodeOptions.pop('task')
97
+
98
+ # Callable for processing an audio file
99
+ whisperCallable = lambda audio, segment_index, prompt, detected_language : model.transcribe(audio, \
100
+ language=language if language else detected_language, task=task, \
101
+ initial_prompt=self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt, \
102
+ **decodeOptions)
103
+
104
+ # The results
105
+ if (vad == 'silero-vad'):
106
+ # Silero VAD where non-speech gaps are transcribed
107
+ process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
108
+ result = self.vad_model.transcribe(audio_path, whisperCallable, process_gaps)
109
+ elif (vad == 'silero-vad-skip-gaps'):
110
+ # Silero VAD where non-speech gaps are simply ignored
111
+ skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
112
+ result = self.vad_model.transcribe(audio_path, whisperCallable, skip_gaps)
113
+ elif (vad == 'silero-vad-expand-into-gaps'):
114
+ # Use Silero VAD where speech-segments are expanded into non-speech gaps
115
+ expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
116
+ result = self.vad_model.transcribe(audio_path, whisperCallable, expand_gaps)
117
+ elif (vad == 'periodic-vad'):
118
+ # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
119
+ # it may create a break in the middle of a sentence, causing some artifacts.
120
+ periodic_vad = VadPeriodicTranscription()
121
+ result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
122
+ else:
123
+ # Default VAD
124
+ result = whisperCallable(audio_path, 0, None, None)
125
+
126
+ return result
127
+
128
+ def _concat_prompt(self, prompt1, prompt2):
129
+ if (prompt1 is None):
130
+ return prompt2
131
+ elif (prompt2 is None):
132
+ return prompt1
133
+ else:
134
+ return prompt1 + " " + prompt2
135
+
136
+ def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
137
+ # Use Silero VAD
138
+ if (self.vad_model is None):
139
+ self.vad_model = VadSileroTranscription()
140
+
141
+ config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
142
+ max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
143
+ segment_padding_left=vadPadding, segment_padding_right=vadPadding,
144
+ max_prompt_window=vadPromptWindow)
145
+
146
+ return config
147
+
148
+ def write_result(self, result: dict, source_name: str, output_dir: str):
149
+ if not os.path.exists(output_dir):
150
+ os.makedirs(output_dir)
151
+
152
+ text = result["text"]
153
+ language = result["language"]
154
+ languageMaxLineWidth = self.__get_max_line_width(language)
155
+
156
+ print("Max line width " + str(languageMaxLineWidth))
157
+ vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
158
+ srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
159
+
160
+ output_files = []
161
+ output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
162
+ output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
163
+ output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
164
+
165
+ return output_files, text, vtt
166
+
167
+ def clear_cache(self):
168
+ self.model_cache = dict()
169
+ self.vad_model = None
170
+
171
+ def __get_source(self, urlData, uploadFile, microphoneData):
172
+ if urlData:
173
+ # Download from YouTube
174
+ source = download_url(urlData, self.inputAudioMaxDuration)[0]
175
+ else:
176
+ # File input
177
+ source = uploadFile if uploadFile is not None else microphoneData
178
+
179
+ if self.inputAudioMaxDuration > 0:
180
+ # Calculate audio length
181
+ audioDuration = ffmpeg.probe(source)["format"]["duration"]
182
+
183
+ if float(audioDuration) > self.inputAudioMaxDuration:
184
+ raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")
185
+
186
+ file_path = pathlib.Path(source)
187
+ sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix
188
+
189
+ return source, sourceName
190
+
191
+ def __get_max_line_width(self, language: str) -> int:
192
+ if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
193
+ # Chinese characters and kana are wider, so limit line length to 40 characters
194
+ return 40
195
+ else:
196
+ # TODO: Add more languages
197
+ # 80 latin characters should fit on a 1080p/720p screen
198
+ return 80
199
+
200
+ def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
201
+ segmentStream = StringIO()
202
+
203
+ if format == 'vtt':
204
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
205
+ elif format == 'srt':
206
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
207
+ else:
208
+ raise Exception("Unknown format " + format)
209
+
210
+ segmentStream.seek(0)
211
+ return segmentStream.read()
212
+
213
+ def __create_file(self, text: str, directory: str, fileName: str) -> str:
214
+ # Write the text to a file
215
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
216
+ file.write(text)
217
+
218
+ return file.name
219
+
220
+
221
+ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
222
+ ui = WhisperTranscriber(inputAudioMaxDuration)
223
+
224
+ ui_description = "Whisper是一个语音转文字模型,经过多个语音数据集的训练而成。也可以进行多语言的识别任务和翻译(多种语言翻译成英文)"
225
+
226
+
227
+ ui_description += "\n\n\n\n对于时长大于10分钟的非英语音频文件,建议选择VAD选项中的Silero VAD (语音活动检测器)。"
228
+
229
+ if inputAudioMaxDuration > 0:
230
+ ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
231
+
232
+
233
+ demo = gr.Interface(fn=ui.transcribe_webui, description=ui_description, inputs=[
234
+ gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
235
+ gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
236
+ gr.Text(label="URL (YouTube, etc.)"),
237
+ gr.Audio(source="upload", type="filepath", label="Upload Audio"),
238
+ gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
239
+ gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
240
+ gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], label="VAD"),
241
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
242
+ gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
243
+ gr.Number(label="VAD - Padding (s)", precision=None, value=1),
244
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
245
+ ], outputs=[
246
+ gr.File(label="Download"),
247
+ gr.Text(label="Transcription"),
248
+ gr.Text(label="Segments")
249
+ ])
250
+
251
+ demo.launch(share=share, server_name=server_name)
252
+
253
+ if __name__ == '__main__':
254
+ create_ui(DEFAULT_INPUT_AUDIO_MAX_DURATION)