sdelangen committed
Commit 416616e
Parent: 04e927d

Update README.md

Files changed (1):
  1. README.md +111 -0

README.md CHANGED
@@ -166,6 +166,117 @@ for text_chunk in asr.transcribe_file_streaming(args.audio_path, config):
 ```
 </details>
 
+ <details>
+ <summary>Live ASR decoding from a browser using Gradio</summary>
+
+ This is a simple, hacky demo of live ASR in the browser using Gradio's live microphone streaming feature.
+ If you run this, note that browsers may refuse to stream microphone audio over an insecure connection, unless it is localhost.
+ If you are running this on a remote server, you can use SSH port forwarding to expose the remote's port on your machine, as sketched below.
+
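+ For example, assuming the default port 9431 and a hypothetical remote named `remote-host`, a tunnel along these lines makes the demo reachable at `http://localhost:9431`:
+
+ `ssh -L 9431:localhost:9431 remote-host`
+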
+ Run using:
+
+ `python3 gradio-asr.py --model-source speechbrain/asr-streaming-conformer-librispeech --ip=localhost --device=cpu`
+
+ ```python
+ from argparse import ArgumentParser
+ from dataclasses import dataclass
+ import logging
+
+ parser = ArgumentParser()
+ parser.add_argument("--model-source", required=True)
+ parser.add_argument("--device", default="cpu")
+ parser.add_argument("--ip", default="127.0.0.1")
+ parser.add_argument("--port", default=9431, type=int)
+ parser.add_argument("--chunk-size", default=24, type=int)
+ parser.add_argument("--left-context-chunks", default=4, type=int)
+ parser.add_argument("--num-threads", default=None, type=int)
+ parser.add_argument("--verbose", "-v", default=False, action="store_true")
+ args = parser.parse_args()
+
+ if args.verbose:
+     logging.getLogger().setLevel(logging.INFO)
+
+ logging.info("Loading libraries")
+
+ from speechbrain.inference.ASR import StreamingASR, ASRStreamingContext
+ from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+ import torch
+ import gradio as gr
+ import torchaudio
+ import numpy as np
+
+ device = args.device
+
+ if args.num_threads is not None:
+     torch.set_num_threads(args.num_threads)
+
+ logging.info(f"Loading model from \"{args.model_source}\" onto device {device}")
+
+ asr = StreamingASR.from_hparams(args.model_source, run_opts={"device": device})
+ config = DynChunkTrainConfig(args.chunk_size, args.left_context_chunks)
+
+ @dataclass
+ class GradioStreamingContext:
+     # wraps the model's streaming state plus the demo's own audio buffer
+     context: ASRStreamingContext
+     chunk_size: int
+     waveform_buffer: torch.Tensor
+     decoded_text: str
+
+ def transcribe(stream, new_chunk):
+     sr, y = new_chunk
+
+     y = y.astype(np.float32)
+     y = torch.tensor(y, dtype=torch.float32, device=device)
+     y /= max(1, torch.max(torch.abs(y)).item())  # norm by max abs() within chunk & avoid NaN
+     if len(y.shape) > 1:
+         y = torch.mean(y, dim=1)  # downmix to mono
+
+     # HACK: we are making poor use of the resampler across chunk boundaries,
+     # which may degrade accuracy.
+     # NOTE: we should also absolutely avoid recreating a resampler every time
+     resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=asr.audio_normalizer.sample_rate)
+     y = resampler(y)  # janky resample (probably to 16kHz)
+
+     if stream is None:
+         stream = GradioStreamingContext(
+             context=asr.make_streaming_context(config),
+             chunk_size=asr.get_chunk_size_frames(config),
+             waveform_buffer=y,
+             decoded_text="",
+         )
+     else:
+         stream.waveform_buffer = torch.concat((stream.waveform_buffer, y))
+
+     while stream.waveform_buffer.size(0) > stream.chunk_size:
+         chunk = stream.waveform_buffer[:stream.chunk_size]
+         stream.waveform_buffer = stream.waveform_buffer[stream.chunk_size:]
+
+         # fake batch dim
+         chunk = chunk.unsqueeze(0)
+
+         # list of transcribed strings, of size 1 because the batch size is 1
+         with torch.no_grad():
+             transcribed = asr.transcribe_chunk(stream.context, chunk)
+         stream.decoded_text += transcribed[0]
+
+     return stream, stream.decoded_text
+
+ # NOTE: latency seems relatively high, which may be due to this:
+ # https://github.com/gradio-app/gradio/issues/6526
+
+ demo = gr.Interface(
+     transcribe,
+     ["state", gr.Audio(sources=["microphone"], streaming=True)],
+     ["state", "text"],
+     live=True,
+ )
+
+ demo.launch(server_name=args.ip, server_port=args.port)
+ ```
+
+ </details>
+
 ### Inference on GPU
 To perform inference on the GPU, add `run_opts={"device":"cuda"}` when calling the `from_hparams` method.
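
For reference, a minimal sketch of that call for this model (same `StreamingASR` class and model source as in the demo above):

```python
from speechbrain.inference.ASR import StreamingASR

# run_opts={"device": "cuda"} loads and runs the model on the GPU
asr = StreamingASR.from_hparams(
    source="speechbrain/asr-streaming-conformer-librispeech",
    run_opts={"device": "cuda"},
)
```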