Darshan committed on
Commit
1f89c40
1 Parent(s): 05a7f27

use different setup

Browse files
Files changed (3) hide show
  1. Dockerfile +22 -16
  2. app.py +111 -34
  3. requirements.txt +24 -8
# Dockerfile — FastAPI ASR service (cleaned from diff; post-commit version)
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install system dependencies
# libsndfile1 is required by soundfile/librosa for reading WAV audio
RUN apt-get update && apt-get install -y \
    git \
    build-essential \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first to leverage Docker layer cache
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Clone and install NeMo (AI4Bharat fork) in editable mode
RUN git clone https://github.com/AI4Bharat/NeMo.git && \
    cd NeMo && \
    pip install -e .

# Copy application code
# BUGFIX: the commit ships app.py (see files-changed list), not main.py —
# copying main.py would make the build/run fail with a missing module.
COPY app.py .

# Create directory for temporary files
RUN mkdir -p /tmp/audio_files

# Expose the application port
EXPOSE 8000

# Command to run the application (module name matches the copied app.py)
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
"""FastAPI service exposing AI4Bharat NeMo ASR models for Indian languages.

Cleaned from the commit diff (post-commit version of app.py).
"""

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import nemo.collections.asr as nemo_asr
import shutil
import os
from tempfile import NamedTemporaryFile
from typing import Dict
from pydantic import BaseModel
import uvicorn

# Dictionary mapping language codes to Hugging Face model names
LANGUAGE_MODELS = {
    "hi": "ai4bharat/indicconformer_stt_hi_hybrid_ctc_rnnt_large",
    "bn": "ai4bharat/indicconformer_stt_bn_hybrid_ctc_rnnt_large",
    "ta": "ai4bharat/indicconformer_stt_ta_hybrid_ctc_rnnt_large",
    # Add more languages and their corresponding models as needed
}


class TranscriptionResponse(BaseModel):
    # text: the transcription produced for the uploaded audio
    text: str
    # language: the language code the transcription was requested for
    language: str


app = FastAPI(
    title="Indian Languages ASR API",
    description="API for automatic speech recognition in Indian languages",
    version="1.0.0",
)

# Add CORS middleware (open policy — presumably a public demo; tighten for prod)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Cache for loaded models — one entry per language, populated lazily
model_cache = {}


def get_model(language: str):
    """Return the ASR model for *language*, loading and caching it on first use.

    Raises:
        HTTPException(400): unknown language code.
        HTTPException(500): the pretrained model failed to load.
    """
    if language not in LANGUAGE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language: {language}. Supported languages are: {list(LANGUAGE_MODELS.keys())}",
        )

    if language not in model_cache:
        try:
            model = nemo_asr.models.ASRModel.from_pretrained(LANGUAGE_MODELS[language])
            model_cache[language] = model
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error loading model for language {language}: {str(e)}",
            )

    return model_cache[language]


@app.post("/transcribe/", response_model=TranscriptionResponse)
async def transcribe_audio(
    language: str,
    file: UploadFile = File(...),
):
    """
    Transcribe audio file in the specified Indian language

    Parameters:
    - language: Language code (e.g., 'hi' for Hindi, 'bn' for Bengali)
    - file: Audio file in WAV format

    Returns:
    - Transcription text and language
    """
    # Validate file format (extension check only — content is not sniffed)
    if not file.filename.endswith(".wav"):
        raise HTTPException(status_code=400, detail="Only WAV files are supported")

    # Get the appropriate model (may raise 400/500 before any temp file exists)
    model = get_model(language)

    # Save uploaded file temporarily; delete=False so NeMo can reopen it by path
    with NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        try:
            # Copy uploaded file to temporary file
            shutil.copyfileobj(file.file, temp_file)
            temp_file.flush()

            # Perform transcription
            transcriptions = model.transcribe([temp_file.name])

            if not transcriptions or len(transcriptions) == 0:
                raise HTTPException(status_code=500, detail="Transcription failed")

            return TranscriptionResponse(text=transcriptions[0], language=language)

        except HTTPException:
            # BUGFIX: propagate deliberate HTTP errors (e.g. the 500 raised
            # just above) unchanged — the original code let the generic
            # handler below re-wrap them as "Error during transcription: 500: ...".
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Error during transcription: {str(e)}"
            )
        finally:
            # Clean up temporary file on both success and failure paths
            os.unlink(temp_file.name)


@app.get("/languages/")
async def get_supported_languages() -> Dict[str, str]:
    """
    Get list of supported languages and their model names
    """
    return LANGUAGE_MODELS


@app.get("/health/")
async def health_check():
    """
    Health check endpoint
    """
    return {"status": "healthy"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
# requirements.txt (cleaned from diff; post-commit version)
fastapi==0.104.1
uvicorn==0.24.0
python-multipart==0.0.6
pydantic==2.4.2
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
packaging==23.2
huggingface_hub==0.23.2
# FIX: numpy was listed twice with split constraints (>=1.20.0 and <1.24.0);
# merged into a single specifier with the same effective range.
numpy>=1.20.0,<1.24.0
soundfile>=0.12.1
librosa>=0.10.1
omegaconf>=2.3.0
hydra-core>=1.3.2
pytorch-lightning>=2.1.0
webdataset>=0.1.62
transformers>=4.36.0
sacremoses>=0.0.53
youtokentome>=1.0.6
einops>=0.6.1
contextlib2>=21.6.0
inflect>=7.0.0
typing_extensions>=4.8.0