iflamed committed on
Commit
eb53ccb
1 Parent(s): a4ab4ea

add fastapi client

Browse files
README.md CHANGED
@@ -121,10 +121,13 @@ You can get familiar with CosyVoice following this recipe.
121
  The `main.py` file has added a `TTS` api with `CosyVoice-300M-SFT` model, you can update the code based on **Basic Usage** as above.
122
 
123
  ```sh
 
 
 
124
  # For development
125
- fastapi dev --port 3003
126
  # For production
127
- fastapi run --port 3003
128
  ```
129
 
130
  **Build for deployment**
 
121
  The `main.py` file has added a `TTS` api with `CosyVoice-300M-SFT` model, you can update the code based on **Basic Usage** as above.
122
 
123
  ```sh
124
+ cd runtime/python
125
+ # Set inference model
126
+ export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
127
  # For development
128
+ fastapi dev --port 6006 fastapi_server.py
129
  # For production
130
+ fastapi run --port 6006 fastapi_server.py
131
  ```
132
 
133
  **Build for deployment**
runtime/python/fastapi_client.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import requests
4
+
5
def saveResponse(path, response):
    """Write the raw body of an HTTP response to a file.

    Args:
        path: Destination file path; an existing file is overwritten.
        response: Object exposing the response body as bytes through its
            ``content`` attribute (e.g. a ``requests`` response).
    """
    # Binary mode: the body is written exactly as received, with no
    # text decoding or newline translation.
    with open(path, 'wb') as out_file:
        out_file.write(response.content)
10
+
11
def main():
    """Send one inference request to the API server and save the reply.

    Reads the module-level ``args`` namespace produced by the argument
    parser. ``args.mode`` selects the endpoint and the form fields:

    * ``sft``           -> ``/api/inference/sft`` with ``tts`` and ``role``
    * ``zero_shot``     -> ``/api/inference/zero-shot`` with ``tts``,
      ``prompt`` and the prompt audio file
    * ``cross_lingual`` -> ``/api/inference/cross-lingual`` with ``tts``
      and the prompt audio file
    * anything else     -> ``/api/inference/instruct`` with ``tts``,
      ``role`` and ``instruct``

    The response body is written to ``args.tts_wav``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Raised before anything is written, so an error body is never
            saved as if it were audio.
        OSError: If the prompt audio cannot be opened or the output file
            cannot be written.
    """
    api = args.api_base
    payload = {'tts': args.tts_text}
    needs_prompt_audio = False

    if args.mode == 'sft':
        url = api + "/api/inference/sft"
        payload['role'] = args.spk_id
    elif args.mode == 'zero_shot':
        url = api + "/api/inference/zero-shot"
        payload['prompt'] = args.prompt_text
        needs_prompt_audio = True
    elif args.mode == 'cross_lingual':
        url = api + "/api/inference/cross-lingual"
        needs_prompt_audio = True
    else:
        url = api + "/api/inference/instruct"
        payload['role'] = args.spk_id
        payload['instruct'] = args.instruct_text

    if needs_prompt_audio:
        # Keep the prompt file open only for the duration of the upload so
        # the handle is closed even if the request raises.
        with open(args.prompt_wav, 'rb') as prompt_audio:
            files = [('audio', ('prompt_audio.wav', prompt_audio,
                                'application/octet-stream'))]
            response = requests.request("POST", url, data=payload, files=files)
    else:
        response = requests.request("POST", url, data=payload)

    # Fail loudly on an error status instead of saving the error body.
    response.raise_for_status()
    saveResponse(args.tts_wav, response)
    # logging uses %-style lazy formatting; a "{}" placeholder is never filled.
    logging.info("Response save to %s", args.tts_wav)
48
+
49
if __name__ == "__main__":
    # Command-line options, declared as (flag, add_argument keyword args) in
    # the order they are registered with the parser.
    cli_options = [
        # Base URL of the inference server.
        ('--api_base', dict(type=str, default='http://127.0.0.1:6006')),
        ('--mode', dict(default='sft',
                        choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
                        help='request mode')),
        # Text to synthesize.
        ('--tts_text', dict(type=str,
                            default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')),
        # Value sent as the 'role' form field.
        ('--spk_id', dict(type=str, default='中文女')),
        # Value sent as the 'prompt' form field in zero_shot mode.
        ('--prompt_text', dict(type=str, default='希望你以后能够做的比我还好呦。')),
        # Audio file uploaded in zero_shot and cross_lingual modes.
        ('--prompt_wav', dict(type=str, default='../../zero_shot_prompt.wav')),
        # Value sent as the 'instruct' form field in instruct mode.
        ('--instruct_text', dict(type=str,
                                 default="Theo 'Crimson', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.")),
        # Path the response body is written to.
        ('--tts_wav', dict(type=str, default='demo.wav')),
    ]

    parser = argparse.ArgumentParser()
    for flag, option_kwargs in cli_options:
        parser.add_argument(flag, **option_kwargs)

    # main() reads this module-level namespace.
    args = parser.parse_args()
    # NOTE(review): these two values are not referenced by the client code
    # visible here; presumably prompt/target sample rates -- confirm before use.
    prompt_sr = 16000
    target_sr = 22050
    main()
runtime/python/fastapi_server.py CHANGED
@@ -1,3 +1,10 @@
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
  import io,time
 
1
+ # Set inference model
2
+ # export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
3
+ # For development
4
+ # fastapi dev --port 6006 fastapi_server.py
5
+ # For production deployment
6
+ # fastapi run --port 6006 fastapi_server.py
7
+
8
  import os
9
  import sys
10
  import io,time