TaiYouWeb committed on
Commit
56bdf87
1 Parent(s): 1f7df9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py CHANGED
@@ -13,6 +13,11 @@ import gradio as gr # 添加Gradio库
13
 
14
  from config import model_config
15
 
 
 
 
 
 
16
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
17
  model_dir = snapshot_download(model_config['model_dir'])
18
 
@@ -90,3 +95,86 @@ gr.Interface(
90
  outputs=outputs,
91
  title="ASR Transcription with FunASR"
92
  ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  from config import model_config
15
 
16
+ from fastapi import FastAPI, File, Form, UploadFile, HTTPException
17
+ from fastapi.responses import StreamingResponse, Response
18
+
19
+ import uvicorn
20
+
21
# Pick the inference device: the first CUDA GPU when PyTorch reports one
# is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Resolve the configured model entry through snapshot_download; the result is
# what gets passed as the model location further down.
# NOTE(review): presumably a local snapshot directory — confirm against the
# snapshot_download import at the top of the file.
model_dir = snapshot_download(model_config['model_dir'])
23
 
 
95
  outputs=outputs,
96
  title="ASR Transcription with FunASR"
97
  ).launch()
98
+
99
+
100
class SynthesizeResponse(Response):
    """Response subclass whose declared media type is ``text/plain``.

    It is passed as ``response_class`` to the ``/asr`` route.
    """

    media_type = "text/plain"
102
+
103
# FastAPI application object; the route decorators below register on it and
# it is the object handed to uvicorn in the ``__main__`` guard.
app = FastAPI()
104
+
105
# Models already built by ``_get_asr_model``, keyed by their construction
# settings. Requests that use the same settings share one instance instead of
# rebuilding the model on every call.
_ASR_MODEL_CACHE = {}
# Upper bound on cached configurations. The settings are client-supplied, so
# without a bound a client could grow the cache indefinitely.
_ASR_MODEL_CACHE_LIMIT = 4


def _get_asr_model(vad_model, vad_options, ncpu, batch_size):
    """Return the ``AutoModel`` for these settings, building it on first use."""
    cache_key = (
        vad_model,
        json.dumps(vad_options, sort_keys=True),
        ncpu,
        batch_size,
    )
    model = _ASR_MODEL_CACHE.get(cache_key)
    if model is None:
        if len(_ASR_MODEL_CACHE) >= _ASR_MODEL_CACHE_LIMIT:
            # Dicts preserve insertion order, so the first key is the oldest.
            del _ASR_MODEL_CACHE[next(iter(_ASR_MODEL_CACHE))]
        model = AutoModel(
            model=model_dir,
            trust_remote_code=False,
            remote_code="./model.py",
            vad_model=vad_model,
            vad_kwargs=vad_options,
            ncpu=ncpu,
            batch_size=batch_size,
            hub="ms",
            device=device,
        )
        _ASR_MODEL_CACHE[cache_key] = model
    return model


@app.post('/asr', response_class=SynthesizeResponse)
async def generate(
    file: UploadFile = File(...),
    vad_model: str = Form("fsmn-vad"),
    vad_kwargs: str = Form('{"max_single_segment_time": 30000}'),
    ncpu: int = Form(4),
    batch_size: int = Form(1),
    language: str = Form("auto"),
    use_itn: bool = Form(True),
    batch_size_s: int = Form(60),
    merge_vad: bool = Form(True),
    merge_length_s: int = Form(15),
    batch_size_threshold_s: int = Form(50),
    hotword: Optional[str] = Form(" "),
    spk_model: str = Form("cam++"),
    ban_emo_unk: bool = Form(False),
) -> StreamingResponse:
    """Transcribe an uploaded audio file and return the text as ``text/plain``.

    The upload is written to a temporary ``.wav`` file, passed to the ASR
    model by path, and the first result's ``"text"`` field is post-processed
    with ``rich_transcription_postprocess`` before being returned.

    Args:
        file: Uploaded audio file (multipart form field).
        vad_model, ncpu, batch_size: Forwarded unchanged to ``AutoModel``.
        vad_kwargs: JSON object, given as a string, forwarded to ``AutoModel``
            as ``vad_kwargs`` after decoding.
        language, use_itn, batch_size_s, merge_vad, merge_length_s,
        batch_size_threshold_s, hotword, spk_model, ban_emo_unk: Forwarded
            unchanged as keyword arguments to ``model.generate``.

    Returns:
        A ``StreamingResponse`` carrying the UTF-8 encoded transcription.

    Raises:
        HTTPException: 400 when ``vad_kwargs`` is not a JSON object; 500 when
            saving the upload, loading the model, or inference fails.
    """
    # Decode the JSON string into a dict. A malformed value is the client's
    # mistake, so report it as 400 rather than as a server error.
    try:
        vad_options = json.loads(vad_kwargs)
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=400, detail=f"vad_kwargs is not valid JSON: {exc}"
        ) from exc
    if not isinstance(vad_options, dict):
        raise HTTPException(
            status_code=400, detail="vad_kwargs must be a JSON object"
        )

    temp_file_path = None
    try:
        # Save the upload to a temporary file. This happens inside the guarded
        # region so the file is removed even if reading or writing fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # NOTE(review): model loading and ``model.generate`` are synchronous
        # calls made directly inside this coroutine — consider offloading them
        # if concurrent requests matter.
        model = _get_asr_model(vad_model, vad_options, ncpu, batch_size)

        # Run recognition, using the temporary file path as the input.
        res = model.generate(
            input=temp_file_path,
            cache={},
            language=language,
            use_itn=use_itn,
            batch_size_s=batch_size_s,
            merge_vad=merge_vad,
            merge_length_s=merge_length_s,
            batch_size_threshold_s=batch_size_threshold_s,
            hotword=hotword,
            spk_model=spk_model,
            ban_emo_unk=ban_emo_unk,
        )
        if not res:
            raise HTTPException(
                status_code=500, detail="ASR model returned no result"
            )

        # Post-process the raw transcription and return it.
        text = rich_transcription_postprocess(res[0]["text"])
        return StreamingResponse(io.BytesIO(text.encode('utf-8')), media_type="text/plain")
    except HTTPException:
        # Already carries the intended status code; do not re-wrap it.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temporary file once processing is finished.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except FileNotFoundError:
                pass
174
+
175
@app.get("/root")
async def read_root():
    """Return a fixed greeting payload for ``GET /root``."""
    greeting = {"message": "Hello World"}
    return greeting
178
+
179
if __name__ == "__main__":
    # Serve the FastAPI app with uvicorn on all interfaces when this file is
    # run as a script.
    # NOTE(review): 7860 is also Gradio's default port, and the ``.launch()``
    # call earlier in this file is given no arguments in the lines visible
    # here — confirm the two servers do not compete for the same port and
    # that ``.launch()`` returns so this point is reached.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )