import base64
import io
import logging
import os
import pathlib
import typing
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, Request, UploadFile, File, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator, ValidationInfo
from starlette.websockets import WebSocketDisconnect, WebSocketState

logger = logging.getLogger(__name__)


@asynccontextmanager
async def register_init(app: FastAPI):
    """Application lifespan: load the ASR model once before serving requests."""
    logger.info('Loading ASR model...')
    setup_asr_model()
    yield


def register_middleware(app: FastAPI):
    """Install the middleware stack."""
    # Gzip: Always at the top
    app.add_middleware(GZipMiddleware)
    # CORS: Always at the end
    app.add_middleware(
        CORSMiddleware,
        allow_origins=['*'],
        allow_credentials=True,
        allow_methods=['*'],
        allow_headers=['*'],
    )


def create_app() -> FastAPI:
    """Build the FastAPI application with lifespan and middleware wired up."""
    app = FastAPI(
        lifespan=register_init
    )
    register_middleware(app)
    return app


app = create_app()

# BUG FIX: an environment variable name containing '-' cannot be exported from
# a POSIX shell. Prefer WHISPER_MODEL_SIZE, but keep the old hyphenated key as
# a backward-compatible fallback.
model_size = os.environ.get(
    'WHISPER_MODEL_SIZE',
    os.environ.get('WHISPER-MODEL-SIZE', 'large-v3'),
)

# Lazily-initialized shared model; populated by setup_asr_model().
# Run on GPU with FP16
asr_model: typing.Optional[WhisperModel] = None


def setup_asr_model() -> WhisperModel:
    """Load the Whisper model once (CUDA, float16) and return the shared instance."""
    global asr_model
    if asr_model is None:
        logger.info('Loading ASR model...')
        asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
        logger.info('Load ASR model finished.')
    return asr_model


class TranscribeRequestParams(BaseModel):
    """Payload for the JSON and websocket transcription endpoints."""

    uuid: str = Field(title='Request Unique Id.')
    # Either a server-side filesystem path, or base64-encoded audio bytes
    # when using_file_content is true.
    audio_file: str
    language: typing.Literal['en', 'zh',]
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Ensure audio_file refers to an existing file when it is a path."""
        if not self.using_file_content and not pathlib.Path(self.audio_file).exists():
            # BUG FIX: raise ValueError (not FileNotFoundError) so pydantic v2
            # wraps it into the ValidationError that callers actually catch.
            raise ValueError('Audio file not exists.')
        # BUG FIX: an after-validator must return the model instance; the
        # original returned None on the success path.
        return self


def _transcribe_to_text(audio, language: str) -> str:
    """Run the shared ASR model on *audio* and return the full transcript.

    BUG FIX: the original handlers kept only the first segment and broke out
    of the loop, discarding the rest of the audio; all segment texts are now
    concatenated.
    """
    segments, _ = setup_asr_model().transcribe(audio, language=language)
    return ''.join(segment.text for segment in segments)


# NOTE: plain `def` (not `async def`) so FastAPI runs the handler in its
# threadpool -- model.transcribe() blocks, and would stall the event loop.
@app.post('/transcribe')
def transcribe_api(
        request: Request,
        obj: TranscribeRequestParams
):
    """Transcribe audio supplied as a server-side path or base64 content."""
    try:
        if obj.using_file_content:
            audio_file = io.BytesIO(base64.b64decode(obj.audio_file))
        else:
            audio_file = obj.audio_file
        transcribed_text = _transcribe_to_text(audio_file, obj.language)
    except Exception as exc:
        logger.exception(exc)
        return {
            "if_success": False,
            'uuid': obj.uuid,
            'msg': f'{exc}'
        }
    return {
        "if_success": True,
        'uuid': obj.uuid,
        'transcribed_text': transcribed_text
    }


# NOTE: plain `def` for the same threadpool reason as transcribe_api.
@app.post('/transcribe-file')
def transcribe_file_api(
        request: Request,
        uuid: str,
        audio_file: typing.Annotated[UploadFile, File()],
        language: typing.Literal['en', 'zh']
):
    """Transcribe an uploaded audio file."""
    try:
        transcribed_text = _transcribe_to_text(audio_file.file, language)
    except Exception as exc:
        logger.exception(exc)
        return {
            "if_success": False,
            'uuid': uuid,
            'msg': f'{exc}'
        }
    return {
        "if_success": True,
        'uuid': uuid,
        'transcribed_text': transcribed_text
    }


@app.websocket('/transcribe')
async def transcribe_ws_api(
        websocket: WebSocket
):
    """Websocket transcription: one JSON request in, one JSON response out."""
    await websocket.accept()
    while websocket.client_state == WebSocketState.CONNECTED:
        # BUG FIX: receive_json raises WebSocketDisconnect when the client
        # goes away; exit the loop cleanly instead of crashing the handler.
        try:
            request_params = await websocket.receive_json()
        except WebSocketDisconnect:
            break
        try:
            form = TranscribeRequestParams.model_validate(request_params)
        except ValidationError as exc:
            logger.exception(exc)
            await websocket.send_json({
                "if_success": False,
                'uuid': request_params.get('uuid', ''),
                'msg': f'{exc}'
            })
            continue
        try:
            if form.using_file_content:
                audio_file = io.BytesIO(base64.b64decode(form.audio_file))
            else:
                audio_file = form.audio_file
            # NOTE(review): transcription still blocks the event loop here;
            # consider offloading to a worker thread if latency matters.
            transcribed_text = _transcribe_to_text(audio_file, form.language)
        except Exception as exc:
            logger.exception(exc)
            response_body = {
                "if_success": False,
                'uuid': form.uuid,
                'msg': f'{exc}'
            }
        else:
            response_body = {
                "if_success": True,
                'uuid': form.uuid,
                'transcribed_text': transcribed_text
            }
        await websocket.send_json(response_body)


if __name__ == '__main__':
    uvicorn.run(
        app,
        host=os.environ.get('HOST', '0.0.0.0'),
        port=int(os.environ.get('PORT', 8080))
    )