indic-asr / app.py
Darshan
use different setup
1f89c40
raw
history blame
3.65 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import nemo.collections.asr as nemo_asr
import shutil
import os
from tempfile import NamedTemporaryFile
from typing import Dict
from pydantic import BaseModel
import uvicorn
# Dictionary mapping language codes to model names
LANGUAGE_MODELS = {
"hi": "ai4bharat/indicconformer_stt_hi_hybrid_ctc_rnnt_large",
"bn": "ai4bharat/indicconformer_stt_bn_hybrid_ctc_rnnt_large",
"ta": "ai4bharat/indicconformer_stt_ta_hybrid_ctc_rnnt_large",
# Add more languages and their corresponding models as needed
}
class TranscriptionResponse(BaseModel):
text: str
language: str
app = FastAPI(
title="Indian Languages ASR API",
description="API for automatic speech recognition in Indian languages",
version="1.0.0",
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Cache for loaded models
model_cache = {}
def get_model(language: str):
"""
Get or load the ASR model for the specified language
"""
if language not in LANGUAGE_MODELS:
raise HTTPException(
status_code=400,
detail=f"Unsupported language: {language}. Supported languages are: {list(LANGUAGE_MODELS.keys())}",
)
if language not in model_cache:
try:
model = nemo_asr.models.ASRModel.from_pretrained(LANGUAGE_MODELS[language])
model_cache[language] = model
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error loading model for language {language}: {str(e)}",
)
return model_cache[language]
@app.post("/transcribe/", response_model=TranscriptionResponse)
async def transcribe_audio(
language: str,
file: UploadFile = File(...),
):
"""
Transcribe audio file in the specified Indian language
Parameters:
- language: Language code (e.g., 'hi' for Hindi, 'bn' for Bengali)
- file: Audio file in WAV format
Returns:
- Transcription text and language
"""
# Validate file format
if not file.filename.endswith(".wav"):
raise HTTPException(status_code=400, detail="Only WAV files are supported")
# Get the appropriate model
model = get_model(language)
# Save uploaded file temporarily
with NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
try:
# Copy uploaded file to temporary file
shutil.copyfileobj(file.file, temp_file)
temp_file.flush()
# Perform transcription
transcriptions = model.transcribe([temp_file.name])
if not transcriptions or len(transcriptions) == 0:
raise HTTPException(status_code=500, detail="Transcription failed")
return TranscriptionResponse(text=transcriptions[0], language=language)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error during transcription: {str(e)}"
)
finally:
# Clean up temporary file
os.unlink(temp_file.name)
@app.get("/languages/")
async def get_supported_languages() -> Dict[str, str]:
"""
Get list of supported languages and their model names
"""
return LANGUAGE_MODELS
@app.get("/health/")
async def health_check():
"""
Health check endpoint
"""
return {"status": "healthy"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)