# asr / server.py
# maolin.liu
# [feature] Support choosing the audio file path.
# 2278032
import asyncio
import base64
import io
import logging
import os
import pathlib
import typing
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, Request, UploadFile, File, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator, ValidationInfo
from starlette.websockets import WebSocketDisconnect, WebSocketState
@asynccontextmanager
async def register_init(app: FastAPI):
    """
    Application lifespan hook: run startup initialisation.

    Loads the ASR model before the application starts serving, then yields
    control to FastAPI for the lifetime of the app.

    :param app: the FastAPI application (not used directly here).
    :return: an async context manager for ``FastAPI(lifespan=...)``.
    """
    startup_message = 'Loading ASR model...'
    print(startup_message)
    setup_asr_model()
    # Code placed after the yield would run at shutdown; nothing to release.
    yield None
def register_middleware(app: FastAPI):
    """
    Attach the HTTP middleware stack to ``app``.

    Registration order matters. Starlette wraps each newly added middleware
    around the previously added ones. CORS is therefore registered last so that
    it is the outermost layer, and GZip compression sits inside it.

    :param app: the FastAPI application to configure.
    """
    app.add_middleware(GZipMiddleware)

    # Fully permissive CORS: any origin, method and header, with credentials.
    cors_options = dict(
        allow_origins=['*'],
        allow_credentials=True,
        allow_methods=['*'],
        allow_headers=['*'],
    )
    app.add_middleware(CORSMiddleware, **cors_options)
def create_app():
    """
    Build the FastAPI application with its lifespan hook and middleware.

    :return: the configured ``FastAPI`` instance.
    """
    application = FastAPI(lifespan=register_init)
    register_middleware(application)
    return application
app = create_app()
# Whisper checkpoint to load (default 'large-v3'). WHISPER_MODEL_SIZE is the
# preferred variable, because a hyphenated name cannot be exported from POSIX
# shells. The original WHISPER-MODEL-SIZE is still honoured as a fallback.
model_size = (
    os.environ.get('WHISPER_MODEL_SIZE')
    or os.environ.get('WHISPER-MODEL-SIZE', 'large-v3')
)
# Shared model instance, created lazily by setup_asr_model().
asr_model: typing.Optional[WhisperModel] = None
def setup_asr_model():
    """
    Load the Whisper model into the module-level ``asr_model`` once.

    Later calls return the already-loaded instance. By default the model runs
    on the GPU with FP16 (``device='cuda'``, ``compute_type='float16'``). Set
    the ``WHISPER_DEVICE`` and ``WHISPER_COMPUTE_TYPE`` environment variables
    to override these, e.g. ``cpu`` / ``int8`` on hosts without a GPU.

    :return: the shared ``WhisperModel`` instance.
    """
    global asr_model
    if asr_model is None:
        device = os.environ.get('WHISPER_DEVICE', 'cuda')
        compute_type = os.environ.get('WHISPER_COMPUTE_TYPE', 'float16')
        logging.info('Loading ASR model...')
        asr_model = WhisperModel(model_size, device=device, compute_type=compute_type)
        logging.info('Load ASR model finished.')
    return asr_model
class TranscribeRequestParams(BaseModel):
    """
    Payload of a transcription request (HTTP JSON body or WebSocket message).

    How ``audio_file`` is interpreted depends on ``using_file_content``:

    - true: ``audio_file`` holds base64-encoded audio bytes;
    - false: ``audio_file`` is a path on the server's file system.
    """

    uuid: str = Field(title='Request Unique Id.')
    # Base64-encoded audio content, or a server-side file path.
    audio_file: str
    language: typing.Literal['en', 'zh',]
    # True: ``audio_file`` is base64 content; False: it is a file path.
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """
        Ensure that a path-mode request points at an existing regular file.

        Raises ``ValueError``, which pydantic reports as a ``ValidationError``.
        A ``FileNotFoundError`` would not be converted by pydantic and would
        escape as an unhandled server error.

        Always returns ``self``: pydantic uses the return value of an
        ``after`` model validator as the validated model.
        """
        if self.using_file_content:
            return self
        # NOTE(review): clients may name any path readable by the server
        # process; consider restricting paths to an allowed directory.
        if not pathlib.Path(self.audio_file).is_file():
            raise ValueError(f'Audio file does not exist: {self.audio_file}')
        return self
@app.post('/transcribe')
async def transcribe_api(
        request: Request,
        obj: TranscribeRequestParams
):
    """
    Transcribe the audio described by a JSON ``TranscribeRequestParams`` body.

    ``obj.audio_file`` is base64-decoded when ``obj.using_file_content`` is
    true; otherwise it is passed to the model as a file path.

    :return: ``{"if_success": True, "uuid": ..., "transcribed_text": ...}`` on
        success, or ``{"if_success": False, "uuid": ..., "msg": <error>}``.
    """

    def _transcribe(audio, language):
        # Runs in a worker thread. ``transcribe`` yields segments lazily and
        # decodes while they are consumed, so consumption must happen here too.
        segments, _ = asr_model.transcribe(audio, language=language)
        # Keep every segment; stopping after the first one truncates
        # longer recordings.
        return ''.join(segment.text for segment in segments)

    try:
        audio_file = obj.audio_file
        if obj.using_file_content:
            audio_file = io.BytesIO(base64.b64decode(obj.audio_file))
        # Inference is blocking; keep it off the event loop so concurrent
        # requests and WebSocket sessions are not stalled.
        transcribed_text = await asyncio.to_thread(_transcribe, audio_file, obj.language)
    except Exception as exc:
        logging.exception(exc)
        response_body = {
            "if_success": False,
            'uuid': obj.uuid,
            'msg': f'{exc}'
        }
    else:
        response_body = {
            "if_success": True,
            'uuid': obj.uuid,
            'transcribed_text': transcribed_text
        }
    return response_body
@app.post('/transcribe-file')
async def transcribe_file_api(
        request: Request,
        uuid: str,
        audio_file: typing.Annotated[UploadFile, File()],
        language: typing.Literal['en', 'zh']
):
    """
    Transcribe an uploaded audio file (multipart upload).

    ``uuid`` and ``language`` are not declared with ``Form()``, so FastAPI
    reads them from the query string; only ``audio_file`` is a form part.

    :param uuid: request id echoed back in the response.
    :param audio_file: uploaded audio; its underlying file object is passed
        directly to the model.
    :param language: language code forwarded to the model.
    :return: ``{"if_success": True, "uuid": ..., "transcribed_text": ...}`` on
        success, or ``{"if_success": False, "uuid": ..., "msg": <error>}``.
    """

    def _transcribe(audio, lang):
        # Runs in a worker thread. The segment generator is consumed here
        # because decoding happens during iteration.
        segments, _ = asr_model.transcribe(audio, language=lang)
        # Keep every segment so longer recordings are not truncated.
        return ''.join(segment.text for segment in segments)

    try:
        # Inference is blocking; keep it off the event loop.
        transcribed_text = await asyncio.to_thread(_transcribe, audio_file.file, language)
    except Exception as exc:
        logging.exception(exc)
        response_body = {
            "if_success": False,
            'uuid': uuid,
            'msg': f'{exc}'
        }
    else:
        response_body = {
            "if_success": True,
            'uuid': uuid,
            'transcribed_text': transcribed_text
        }
    return response_body
@app.websocket('/transcribe')
async def transcribe_ws_api(
        websocket: WebSocket
):
    """
    Serve transcription requests over a WebSocket connection.

    Each incoming JSON message is validated as ``TranscribeRequestParams``.
    Each message is answered with one JSON reply of the same shape as the
    HTTP ``/transcribe`` response: ``if_success``, ``uuid``, and either
    ``transcribed_text`` or ``msg``.

    Malformed JSON and invalid parameters get an error reply instead of
    terminating the session. The loop ends quietly when the client
    disconnects.
    """

    def _transcribe(audio, language):
        # Runs in a worker thread. The segment generator is consumed here
        # because decoding happens during iteration.
        segments, _ = asr_model.transcribe(audio, language=language)
        # Keep every segment so longer recordings are not truncated.
        return ''.join(segment.text for segment in segments)

    await websocket.accept()
    try:
        while websocket.client_state == WebSocketState.CONNECTED:
            try:
                request_params = await websocket.receive_json()
            except ValueError as exc:
                # The frame was not valid JSON; report it and keep serving.
                logging.exception(exc)
                await websocket.send_json({
                    "if_success": False,
                    'uuid': '',
                    'msg': f'{exc}'
                })
                continue
            try:
                form = TranscribeRequestParams.model_validate(request_params)
            except (ValidationError, OSError) as exc:
                # OSError covers file-system failures raised while checking
                # a path-mode ``audio_file``.
                logging.exception(exc)
                # The payload may not be a JSON object at all (e.g. a list).
                request_uuid = (
                    request_params.get('uuid', '')
                    if isinstance(request_params, dict) else ''
                )
                await websocket.send_json({
                    "if_success": False,
                    'uuid': request_uuid,
                    'msg': f'{exc}'
                })
                continue
            try:
                audio_file = form.audio_file
                if form.using_file_content:
                    audio_file = io.BytesIO(base64.b64decode(form.audio_file))
                # Inference is blocking; keep it off the event loop so other
                # connections keep being served.
                transcribed_text = await asyncio.to_thread(_transcribe, audio_file, form.language)
            except Exception as exc:
                logging.exception(exc)
                response_body = {
                    "if_success": False,
                    'uuid': form.uuid,
                    'msg': f'{exc}'
                }
            else:
                response_body = {
                    "if_success": True,
                    'uuid': form.uuid,
                    'transcribed_text': transcribed_text
                }
            await websocket.send_json(response_body)
    except WebSocketDisconnect:
        # The client closed the connection; there is nothing left to clean up.
        logging.info('ASR websocket client disconnected.')
if __name__ == '__main__':
    # Run the server directly. HOST and PORT can be overridden from the
    # environment; by default it listens on all interfaces, port 8080.
    listen_host = os.environ.get('HOST', '0.0.0.0')
    listen_port = int(os.environ.get('PORT', 8080))
    uvicorn.run(app, host=listen_host, port=listen_port)