SajjadAyoubi committed
Commit 1d88608
1 Parent(s): b55f94e

Update app.py

Files changed (1)
  1. app.py +9 -16
app.py CHANGED
@@ -6,36 +6,29 @@ from pyctcdecode import build_ctcdecoder
 
 # Define ASR MODEL
 class Speech2Text:
-    def __init__(self):
-        self.vocab = list(processor.tokenizer.get_vocab().keys())
+    def __init__(self, model_name='masoudmzb/wav2vec2-xlsr-multilingual-53-fa'):
+        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).eval()
+        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
+        self.vocab = list(self.processor.tokenizer.get_vocab().keys())
         self.decoder = build_ctcdecoder(self.vocab, kenlm_model_path='kenlm.scorer')
 
     def wav2feature(self, path):
         speech_array, sampling_rate = torchaudio.load(path)
-        speech_array = librosa.resample(speech_array.squeeze().numpy(),
-                                        sampling_rate, processor.feature_extractor.sampling_rate)
-        return processor(speech_array, return_tensors="pt",
-                         sampling_rate=processor.feature_extractor.sampling_rate)
+        speech_array = librosa.resample(speech_array.squeeze().numpy(), sampling_rate, self.processor.feature_extractor.sampling_rate)
+        return self.processor(speech_array, return_tensors="pt", sampling_rate=self.processor.feature_extractor.sampling_rate)
 
     def feature2logits(self, features):
         with torch.no_grad():
-            return wav2vec_model(features.input_values[0].to(device)).logits.numpy()[0]
+            return self.model(features.input_values[0]).logits.numpy()[0]
 
     def __call__(self, path):
         logits = self.feature2logits(self.wav2feature(path))
         return self.decoder.decode(logits)
 
-#Loading the model and the tokenizer
-model_name = 'masoudmzb/wav2vec2-xlsr-multilingual-53-fa'
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-wav2vec_model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device).eval()
-processor = Wav2Vec2Processor.from_pretrained(model_name)
+# Create an instance
 s2t = Speech2Text()
 
-def asr(path):
-    return s2t(path)
-
-gr.Interface(asr,
+gr.Interface(lambda path: s2t(path),
              inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record Your Beautiful Persian Voice"),
             outputs = gr.outputs.Textbox(label="Output Text"),
             title="Persian ASR using Wav2Vec 2.0 & N-gram LM",