|
from fastapi import FastAPI, File, UploadFile, HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import nemo.collections.asr as nemo_asr |
|
import shutil |
|
import os |
|
from tempfile import NamedTemporaryFile |
|
from typing import Dict |
|
from pydantic import BaseModel |
|
import uvicorn |
|
|
|
|
|
# Maps a language code (e.g. "hi" for Hindi, "bn" for Bengali) to the
# pretrained model identifier that get_model passes to
# nemo_asr.models.ASRModel.from_pretrained. Requests for a code that is not a
# key of this mapping are rejected with HTTP 400.
LANGUAGE_MODELS: Dict[str, str] = {
    "hi": "ai4bharat/indicconformer_stt_hi_hybrid_ctc_rnnt_large",
    "bn": "ai4bharat/indicconformer_stt_bn_hybrid_ctc_rnnt_large",
    "ta": "ai4bharat/indicconformer_stt_ta_hybrid_ctc_rnnt_large",
}
|
|
|
|
|
class TranscriptionResponse(BaseModel):
    """Response body of the /transcribe/ endpoint."""

    # Transcription produced for the uploaded audio file.
    text: str
    # Language code the transcription was requested for.
    language: str
|
|
|
|
|
# Application object that every route below is registered on. The title,
# description and version are metadata handed to the FastAPI constructor.
app = FastAPI(
    title="Indian Languages ASR API",
    description="API for automatic speech recognition in Indian languages",
    version="1.0.0",
)
|
|
|
|
|
# Enable CORS for every origin, method and header.
# NOTE(review): wildcard origins are combined with allow_credentials=True.
# FastAPI's CORS documentation says the wildcard must not be used together
# with credentials -- confirm whether credentialed cross-origin requests are
# actually needed; if not, set allow_credentials=False, otherwise list the
# allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# In-process cache of loaded ASR models keyed by language code. get_model
# fills it on first use of a language so each model is loaded only once.
model_cache: dict = {}
|
|
|
|
|
def get_model(language: str):
    """
    Return the ASR model for ``language``, loading and caching it on first use.

    Parameters:
    - language: Language code; must be a key of LANGUAGE_MODELS.

    Returns:
    - The model object stored in model_cache for this language.

    Raises:
    - HTTPException 400: the language code is not in LANGUAGE_MODELS.
    - HTTPException 500: loading the pretrained model raised an exception.
    """
    # Guard: reject unknown language codes before touching the cache.
    if language not in LANGUAGE_MODELS:
        supported = list(LANGUAGE_MODELS.keys())
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language: {language}. Supported languages are: {supported}",
        )

    # Fast path: the model was already loaded by an earlier call.
    if language in model_cache:
        return model_cache[language]

    checkpoint = LANGUAGE_MODELS[language]
    try:
        loaded = nemo_asr.models.ASRModel.from_pretrained(checkpoint)
        model_cache[language] = loaded
    except Exception as load_error:
        # Any loader failure is reported to the client as a server error.
        raise HTTPException(
            status_code=500,
            detail=f"Error loading model for language {language}: {str(load_error)}",
        )

    return model_cache[language]
|
|
|
|
|
@app.post("/transcribe/", response_model=TranscriptionResponse)
async def transcribe_audio(
    language: str,
    file: UploadFile = File(...),
):
    """
    Transcribe an uploaded WAV file in the specified Indian language

    Parameters:
    - language: Language code (e.g., 'hi' for Hindi, 'bn' for Bengali);
      must be a key of LANGUAGE_MODELS
    - file: Audio file whose name ends in ".wav" (checked case-insensitively)

    Returns:
    - TranscriptionResponse holding the first transcription and the language

    Raises:
    - HTTPException 400: the upload has no ".wav" filename, or the language is
      unsupported (raised by get_model)
    - HTTPException 500: the model could not be loaded, the model returned no
      transcription, or any other error occurred while transcribing
    """
    # The upload's filename may be missing (None); treat that as "not a WAV"
    # instead of crashing with AttributeError. Compare case-insensitively so
    # "CLIP.WAV" is accepted as well.
    filename = file.filename or ""
    if not filename.lower().endswith(".wav"):
        raise HTTPException(status_code=400, detail="Only WAV files are supported")

    model = get_model(language)

    temp_path = None
    try:
        # Copy the upload to disk and CLOSE the file before handing its path
        # to the model: the data is then fully flushed, and the path can be
        # reopened and deleted on platforms (e.g. Windows) that refuse to
        # reopen or unlink a file that is still open.
        with NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_path = temp_file.name
            shutil.copyfileobj(file.file, temp_file)

        # NOTE(review): this is a synchronous call inside an async handler, so
        # the event loop is blocked while it runs. Offloading it to a worker
        # thread would first require confirming the model tolerates
        # concurrent transcribe() calls.
        transcriptions = model.transcribe([temp_path])

        if not transcriptions:
            raise HTTPException(status_code=500, detail="Transcription failed")

        return TranscriptionResponse(text=transcriptions[0], language=language)
    except HTTPException:
        # Deliberate HTTP errors (such as "Transcription failed" above) must
        # reach the client unchanged rather than being re-wrapped below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error during transcription: {str(e)}"
        ) from e
    finally:
        # delete=False means the temp file is ours to remove; skip if it was
        # never created.
        if temp_path is not None:
            os.unlink(temp_path)
|
|
|
|
|
@app.get("/languages/")
async def get_supported_languages() -> Dict[str, str]:
    """
    Return the supported language codes mapped to their model names
    """
    # The module-level mapping is the single source of truth for this listing.
    supported = LANGUAGE_MODELS
    return supported
|
|
|
|
|
@app.get("/health/")
async def health_check():
    """
    Report that the service is up
    """
    # Static payload: no dependencies are probed.
    status = "healthy"
    return {"status": status}
|
|
|
|
|
# When this file is executed as a script, serve the app with uvicorn on all
# network interfaces (0.0.0.0), port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|