Spaces:
Sleeping
Sleeping
cyberspyde
committed on
Commit
•
99d311a
1
Parent(s):
c858f8e
model update
Browse files
main.py
CHANGED
@@ -1,35 +1,12 @@
|
|
1 |
from flask import Flask, request, jsonify
|
2 |
-
from transformers import
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
|
6 |
app = Flask(__name__)
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
USE_ONNX = False # change this to True if you want to test onnx model
|
12 |
-
silero_vad_path = 'snakers4/silero-vad'
|
13 |
-
vad_model, vad_utils = torch.hub.load(silero_vad_path,
|
14 |
-
model='silero_vad',
|
15 |
-
force_reload=True,
|
16 |
-
onnx=USE_ONNX)
|
17 |
-
|
18 |
-
(get_speech_timestamps,
|
19 |
-
save_audio,
|
20 |
-
read_audio,
|
21 |
-
VADIterator,
|
22 |
-
collect_chunks) = vad_utils
|
23 |
-
STT_SAMPLE_RATE = 16000
|
24 |
-
|
25 |
-
|
26 |
-
def int2float(sound):
|
27 |
-
abs_max = np.abs(sound).max()
|
28 |
-
sound = sound.astype('float32')
|
29 |
-
if abs_max > 0:
|
30 |
-
sound *= 1/32768
|
31 |
-
sound = sound.squeeze() # depends on the use case
|
32 |
-
return sound
|
33 |
|
34 |
@app.route('/', methods=['GET'])
|
35 |
def index():
|
@@ -38,21 +15,16 @@ def index():
|
|
38 |
@app.route('/transcribe', methods=['POST'])
|
39 |
def transcribe():
|
40 |
data_frames = request.data
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
inputs = processor(final_audio_data, return_tensors="pt", sampling_rate=16000, max_new_tokens=100)
|
48 |
-
input_features = inputs.input_features
|
49 |
-
generated_ids = model.generate(inputs=input_features)
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
transcription = str(e)
|
55 |
-
return str(transcription), {'Content-Type': 'application/json'}
|
56 |
|
57 |
if __name__ == '__main__':
|
58 |
app.run(host='0.0.0.0', port=7860)
|
|
|
1 |
from flask import Flask, request, jsonify
|
2 |
+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
|
6 |
app = Flask(__name__)
|
7 |
+
processor = Wav2Vec2Processor.from_pretrained("oyqiz/uzbek_stt")
|
8 |
+
model = Wav2Vec2ForCTC.from_pretrained("oyqiz/uzbek_stt")
|
9 |
+
SAMPLE_RATE = 16000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
@app.route('/', methods=['GET'])
|
12 |
def index():
|
|
|
15 |
@app.route('/transcribe', methods=['POST'])
def transcribe():
    """Transcribe the raw audio bytes in the POST body with the CTC model.

    The request body is read as 16-bit signed integer samples in native
    byte order (that is how ``np.frombuffer`` interprets ``np.int16``).
    NOTE(review): the samples are presumably mono audio recorded at
    ``SAMPLE_RATE`` Hz -- nothing here resamples or verifies that, so
    confirm against the client.

    Returns:
        The decoded transcription string on success, or an
        ``(error message, 400)`` pair when the body is empty or does not
        contain a whole number of 16-bit samples.
    """
    data_frames = request.data
    sample_width = np.dtype(np.int16).itemsize

    # np.frombuffer raises ValueError on a trailing partial sample, and the
    # model cannot run on zero samples; answer with a client error instead
    # of letting either case surface as an unhandled 500.
    if not data_frames or len(data_frames) % sample_width:
        return 'Request body must be non-empty 16-bit integer audio samples', 400

    audio_np = np.frombuffer(data_frames, dtype=np.int16)
    # Scale the integer samples to floats by the int16 maximum.
    audio_np = audio_np / np.iinfo(np.int16).max
    inputs = processor(audio_np, sampling_rate=SAMPLE_RATE, return_tensors="pt")

    with torch.no_grad():
        # The processor output only carries an attention mask when its
        # feature extractor is configured to return one; attribute access
        # would raise AttributeError otherwise, so fall back to None.
        logits = model(
            inputs.input_values,
            attention_mask=inputs.get("attention_mask"),
        ).logits

    # Greedy CTC decoding: take the most likely token at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return transcription
|
|
|
|
|
28 |
|
29 |
if __name__ == '__main__':
|
30 |
app.run(host='0.0.0.0', port=7860)
|