import base64
import io
import logging
import os
import pathlib
import typing
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, Request, UploadFile, File, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator
from starlette.websockets import WebSocketDisconnect, WebSocketState


@asynccontextmanager
async def register_init(app: FastAPI):
    """
    Startup initialization: load the ASR model before the application starts serving requests.
    """
    setup_asr_model()

    yield


def register_middleware(app: FastAPI):
    # Compress large responses.
    app.add_middleware(GZipMiddleware)

    # Allow cross-origin requests from any origin.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=['*'],
        allow_credentials=True,
        allow_methods=['*'],
        allow_headers=['*'],
    )


def create_app():
    app = FastAPI(lifespan=register_init)
    register_middleware(app)
    return app


app = create_app()

# Whisper model size (e.g. 'large-v3'), overridable via the WHISPER_MODEL_SIZE environment variable.
model_size = os.environ.get('WHISPER_MODEL_SIZE', 'large-v3')

# Loaded lazily on startup by setup_asr_model().
asr_model: typing.Optional[WhisperModel] = None


def setup_asr_model():
    """
    Lazily load the Whisper model (GPU, float16) the first time it is needed.
    """
    global asr_model
    if asr_model is None:
        logging.info('Loading ASR model...')
        asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
        logging.info('Loading ASR model finished.')
    return asr_model


class TranscribeRequestParams(BaseModel):
    uuid: str = Field(title='Request Unique Id.')
    audio_file: str
    language: typing.Literal['en', 'zh']
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        # When the audio is sent inline as base64, there is no path to check.
        if self.using_file_content:
            return self

        if not pathlib.Path(self.audio_file).exists():
            raise ValueError(f'Audio file does not exist: {self.audio_file}')
        return self


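# A sample request body for POST /transcribe (and for the WebSocket endpoint below), as
# defined by TranscribeRequestParams; the uuid and payload values are placeholders only:
#
# {
#     "uuid": "demo-request-1",
#     "audio_file": "<base64-encoded audio bytes, or a server-side file path>",
#     "language": "en",
#     "using_file_content": true
# }
#
# When using_file_content is false, audio_file must be a path that exists on the server.
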
@app.post('/transcribe')
async def transcribe_api(
        request: Request,
        obj: TranscribeRequestParams
):
    try:
        audio_file = obj.audio_file
        if obj.using_file_content:
            # The audio is delivered inline as a base64-encoded string.
            audio_file = io.BytesIO(base64.b64decode(obj.audio_file))

        segments, _ = asr_model.transcribe(audio_file, language=obj.language)

        # `segments` is a generator; join all segment texts into the final transcript.
        transcribed_text = ''.join(segment.text for segment in segments)
    except Exception as exc:
        logging.exception(exc)
        response_body = {
            'if_success': False,
            'uuid': obj.uuid,
            'msg': f'{exc}'
        }
    else:
        response_body = {
            'if_success': True,
            'uuid': obj.uuid,
            'transcribed_text': transcribed_text
        }
    return response_body


@app.post('/transcribe-file')
async def transcribe_file_api(
        request: Request,
        uuid: str,
        audio_file: typing.Annotated[UploadFile, File()],
        language: typing.Literal['en', 'zh']
):
    try:
        # UploadFile exposes a file-like object that faster-whisper can read directly.
        segments, _ = asr_model.transcribe(audio_file.file, language=language)

        # `segments` is a generator; join all segment texts into the final transcript.
        transcribed_text = ''.join(segment.text for segment in segments)
    except Exception as exc:
        logging.exception(exc)
        response_body = {
            'if_success': False,
            'uuid': uuid,
            'msg': f'{exc}'
        }
    else:
        response_body = {
            'if_success': True,
            'uuid': uuid,
            'transcribed_text': transcribed_text
        }

    return response_body


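# Example call to /transcribe-file (a sketch; host, port, uuid, and file name are placeholders).
# uuid and language are query parameters; the audio is sent as multipart form data.
#
#   curl -X POST "http://localhost:8080/transcribe-file?uuid=demo-request-1&language=en" \
#        -F "audio_file=@sample.wav"
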
@app.websocket('/transcribe')
async def transcribe_ws_api(
        websocket: WebSocket
):
    await websocket.accept()

    while websocket.client_state == WebSocketState.CONNECTED:
        try:
            request_params = await websocket.receive_json()
        except WebSocketDisconnect:
            # The client closed the connection; stop the receive loop.
            break

        try:
            form = TranscribeRequestParams.model_validate(request_params)
        except ValidationError as exc:
            logging.exception(exc)
            await websocket.send_json({
                'if_success': False,
                'uuid': request_params.get('uuid', ''),
                'msg': f'{exc}'
            })
            continue

        try:
            audio_file = form.audio_file
            if form.using_file_content:
                # The audio is delivered inline as a base64-encoded string.
                audio_file = io.BytesIO(base64.b64decode(form.audio_file))

            segments, _ = asr_model.transcribe(audio_file, language=form.language)

            # `segments` is a generator; join all segment texts into the final transcript.
            transcribed_text = ''.join(segment.text for segment in segments)
        except Exception as exc:
            logging.exception(exc)
            response_body = {
                'if_success': False,
                'uuid': form.uuid,
                'msg': f'{exc}'
            }
        else:
            response_body = {
                'if_success': True,
                'uuid': form.uuid,
                'transcribed_text': transcribed_text
            }

        await websocket.send_json(response_body)


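# WebSocket usage sketch: a client connects to ws://<host>:<port>/transcribe, sends
# TranscribeRequestParams objects as JSON messages, and receives one JSON response per
# message containing if_success, uuid, and either transcribed_text or msg.
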
if __name__ == '__main__':
    uvicorn.run(
        app,
        host=os.environ.get('HOST', '0.0.0.0'),
        port=int(os.environ.get('PORT', 8080))
    )