Spaces:
Running
on
Zero
Running
on
Zero
update fastapi
Browse files
cosyvoice/cli/cosyvoice.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13 |
# limitations under the License.
|
14 |
import os
|
15 |
import time
|
|
|
16 |
from hyperpyyaml import load_hyperpyyaml
|
17 |
from modelscope import snapshot_download
|
18 |
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
@@ -52,7 +53,7 @@ class CosyVoice:
|
|
52 |
return spks
|
53 |
|
54 |
def inference_sft(self, tts_text, spk_id, stream=False):
|
55 |
-
for i in self.frontend.text_normalize(tts_text, split=True):
|
56 |
model_input = self.frontend.frontend_sft(i, spk_id)
|
57 |
start_time = time.time()
|
58 |
logging.info('synthesis text {}'.format(i))
|
@@ -64,7 +65,7 @@ class CosyVoice:
|
|
64 |
|
65 |
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
|
66 |
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
67 |
-
for i in self.frontend.text_normalize(tts_text, split=True):
|
68 |
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
|
69 |
start_time = time.time()
|
70 |
logging.info('synthesis text {}'.format(i))
|
@@ -77,7 +78,7 @@ class CosyVoice:
|
|
77 |
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
|
78 |
if self.frontend.instruct is True:
|
79 |
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
80 |
-
for i in self.frontend.text_normalize(tts_text, split=True):
|
81 |
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
|
82 |
start_time = time.time()
|
83 |
logging.info('synthesis text {}'.format(i))
|
@@ -91,7 +92,7 @@ class CosyVoice:
|
|
91 |
if self.frontend.instruct is False:
|
92 |
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
93 |
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
|
94 |
-
for i in self.frontend.text_normalize(tts_text, split=True):
|
95 |
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
96 |
start_time = time.time()
|
97 |
logging.info('synthesis text {}'.format(i))
|
|
|
13 |
# limitations under the License.
|
14 |
import os
|
15 |
import time
|
16 |
+
from tqdm import tqdm
|
17 |
from hyperpyyaml import load_hyperpyyaml
|
18 |
from modelscope import snapshot_download
|
19 |
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
|
|
53 |
return spks
|
54 |
|
55 |
def inference_sft(self, tts_text, spk_id, stream=False):
|
56 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
57 |
model_input = self.frontend.frontend_sft(i, spk_id)
|
58 |
start_time = time.time()
|
59 |
logging.info('synthesis text {}'.format(i))
|
|
|
65 |
|
66 |
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
|
67 |
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
68 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
69 |
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
|
70 |
start_time = time.time()
|
71 |
logging.info('synthesis text {}'.format(i))
|
|
|
78 |
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
|
79 |
if self.frontend.instruct is True:
|
80 |
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
81 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
82 |
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
|
83 |
start_time = time.time()
|
84 |
logging.info('synthesis text {}'.format(i))
|
|
|
92 |
if self.frontend.instruct is False:
|
93 |
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
94 |
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
|
95 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
96 |
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
97 |
start_time = time.time()
|
98 |
logging.info('synthesis text {}'.format(i))
|
cosyvoice/hifigan/generator.py
CHANGED
@@ -340,7 +340,7 @@ class HiFTGenerator(nn.Module):
|
|
340 |
s = self._f02source(f0)
|
341 |
|
342 |
# use cache_source to avoid glitch
|
343 |
-
if cache_source.shape[2]
|
344 |
s[:, :, :cache_source.shape[2]] = cache_source
|
345 |
|
346 |
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
|
|
340 |
s = self._f02source(f0)
|
341 |
|
342 |
# use cache_source to avoid glitch
|
343 |
+
if cache_source.shape[2] != 0:
|
344 |
s[:, :, :cache_source.shape[2]] = cache_source
|
345 |
|
346 |
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
examples/libritts/cosyvoice/run.sh
CHANGED
@@ -102,4 +102,10 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|
102 |
--deepspeed_config ./conf/ds_stage2.json \
|
103 |
--deepspeed.save_states model+optimizer
|
104 |
done
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
fi
|
|
|
102 |
--deepspeed_config ./conf/ds_stage2.json \
|
103 |
--deepspeed.save_states model+optimizer
|
104 |
done
|
105 |
+
fi
|
106 |
+
|
107 |
+
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
108 |
+
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
109 |
+
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
110 |
+
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
111 |
fi
|
examples/magicdata-read/cosyvoice/run.sh
CHANGED
@@ -102,4 +102,10 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|
102 |
--deepspeed_config ./conf/ds_stage2.json \
|
103 |
--deepspeed.save_states model+optimizer
|
104 |
done
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
fi
|
|
|
102 |
--deepspeed_config ./conf/ds_stage2.json \
|
103 |
--deepspeed.save_states model+optimizer
|
104 |
done
|
105 |
+
fi
|
106 |
+
|
107 |
+
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
108 |
+
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
109 |
+
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
110 |
+
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
111 |
fi
|
runtime/python/fastapi/client.py
CHANGED
@@ -1,56 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import argparse
|
2 |
import logging
|
3 |
import requests
|
|
|
|
|
|
|
4 |
|
5 |
-
def saveResponse(path, response):
|
6 |
-
# 以二进制写入模式打开文件
|
7 |
-
with open(path, 'wb') as file:
|
8 |
-
# 将响应的二进制内容写入文件
|
9 |
-
file.write(response.content)
|
10 |
|
11 |
def main():
|
12 |
-
|
13 |
if args.mode == 'sft':
|
14 |
-
|
15 |
-
|
16 |
-
'
|
17 |
-
'role': args.spk_id
|
18 |
}
|
19 |
-
response = requests.request("
|
20 |
-
saveResponse(args.tts_wav, response)
|
21 |
elif args.mode == 'zero_shot':
|
22 |
-
|
23 |
-
|
24 |
-
'
|
25 |
-
'prompt': args.prompt_text
|
26 |
}
|
27 |
-
files=[('
|
28 |
-
response = requests.request("
|
29 |
-
saveResponse(args.tts_wav, response)
|
30 |
elif args.mode == 'cross_lingual':
|
31 |
-
|
32 |
-
|
33 |
-
'tts': args.tts_text,
|
34 |
}
|
35 |
-
files=[('
|
36 |
-
response = requests.request("
|
37 |
-
saveResponse(args.tts_wav, response)
|
38 |
else:
|
39 |
-
url = api + "/api/inference/instruct"
|
40 |
payload = {
|
41 |
-
'
|
42 |
-
'
|
43 |
-
'
|
44 |
}
|
45 |
-
response = requests.request("
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
if __name__ == "__main__":
|
50 |
parser = argparse.ArgumentParser()
|
51 |
-
parser.add_argument('--
|
52 |
type=str,
|
53 |
-
default='
|
|
|
|
|
|
|
54 |
parser.add_argument('--mode',
|
55 |
default='sft',
|
56 |
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
import argparse
|
15 |
import logging
|
16 |
import requests
|
17 |
+
import torch
|
18 |
+
import torchaudio
|
19 |
+
import numpy as np
|
20 |
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def main():
|
23 |
+
url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode)
|
24 |
if args.mode == 'sft':
|
25 |
+
payload = {
|
26 |
+
'tts_text': args.tts_text,
|
27 |
+
'spk_id': args.spk_id
|
|
|
28 |
}
|
29 |
+
response = requests.request("GET", url, data=payload, stream=True)
|
|
|
30 |
elif args.mode == 'zero_shot':
|
31 |
+
payload = {
|
32 |
+
'tts_text': args.tts_text,
|
33 |
+
'prompt_text': args.prompt_text
|
|
|
34 |
}
|
35 |
+
files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
|
36 |
+
response = requests.request("GET", url, data=payload, files=files, stream=True)
|
|
|
37 |
elif args.mode == 'cross_lingual':
|
38 |
+
payload = {
|
39 |
+
'tts_text': args.tts_text,
|
|
|
40 |
}
|
41 |
+
files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
|
42 |
+
response = requests.request("GET", url, data=payload, files=files, stream=True)
|
|
|
43 |
else:
|
|
|
44 |
payload = {
|
45 |
+
'tts_text': args.tts_text,
|
46 |
+
'spk_id': args.spk_id,
|
47 |
+
'instruct_text': args.instruct_text
|
48 |
}
|
49 |
+
response = requests.request("GET", url, data=payload, stream=True)
|
50 |
+
tts_audio = b''
|
51 |
+
for r in response.iter_content(chunk_size=16000):
|
52 |
+
tts_audio += r
|
53 |
+
tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
|
54 |
+
logging.info('save response to {}'.format(args.tts_wav))
|
55 |
+
torchaudio.save(args.tts_wav, tts_speech, target_sr)
|
56 |
+
logging.info('get response')
|
57 |
|
58 |
if __name__ == "__main__":
|
59 |
parser = argparse.ArgumentParser()
|
60 |
+
parser.add_argument('--host',
|
61 |
type=str,
|
62 |
+
default='0.0.0.0')
|
63 |
+
parser.add_argument('--port',
|
64 |
+
type=int,
|
65 |
+
default='50000')
|
66 |
parser.add_argument('--mode',
|
67 |
default='sft',
|
68 |
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
|
runtime/python/fastapi/server.py
CHANGED
@@ -1,119 +1,77 @@
|
|
1 |
-
#
|
2 |
-
#
|
3 |
-
#
|
4 |
-
#
|
5 |
-
#
|
6 |
-
#
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
import os
|
9 |
import sys
|
10 |
-
import io,time
|
11 |
-
from fastapi import FastAPI, Response, File, UploadFile, Form
|
12 |
-
from fastapi.responses import HTMLResponse
|
13 |
-
from fastapi.middleware.cors import CORSMiddleware #引入 CORS中间件模块
|
14 |
-
from contextlib import asynccontextmanager
|
15 |
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
16 |
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
17 |
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
18 |
-
|
19 |
-
from cosyvoice.utils.file_utils import load_wav
|
20 |
-
import numpy as np
|
21 |
-
import torch
|
22 |
-
import torchaudio
|
23 |
import logging
|
24 |
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
@asynccontextmanager
|
30 |
-
async def lifespan(app: FastAPI):
|
31 |
-
model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
|
32 |
-
if model_dir:
|
33 |
-
logging.info("MODEL_DIR is {}", model_dir)
|
34 |
-
app.cosyvoice = CosyVoice(model_dir)
|
35 |
-
# sft usage
|
36 |
-
logging.info("Avaliable speakers {}", app.cosyvoice.list_avaliable_spks())
|
37 |
-
else:
|
38 |
-
raise LaunchFailed("MODEL_DIR environment must set")
|
39 |
-
yield
|
40 |
-
|
41 |
-
app = FastAPI(lifespan=lifespan)
|
42 |
-
|
43 |
-
#设置允许访问的域名
|
44 |
-
origins = ["*"] #"*",即为所有,也可以改为允许的特定ip。
|
45 |
app.add_middleware(
|
46 |
-
CORSMiddleware,
|
47 |
-
allow_origins=
|
48 |
allow_credentials=True,
|
49 |
-
allow_methods=["*"],
|
50 |
-
allow_headers=["*"])
|
51 |
-
|
52 |
-
def buildResponse(output):
|
53 |
-
buffer = io.BytesIO()
|
54 |
-
torchaudio.save(buffer, output, 22050, format="wav")
|
55 |
-
buffer.seek(0)
|
56 |
-
return Response(content=buffer.read(-1), media_type="audio/wav")
|
57 |
-
|
58 |
-
@app.post("/api/inference/sft")
|
59 |
-
@app.get("/api/inference/sft")
|
60 |
-
async def sft(tts: str = Form(), role: str = Form()):
|
61 |
-
start = time.process_time()
|
62 |
-
output = app.cosyvoice.inference_sft(tts, role)
|
63 |
-
end = time.process_time()
|
64 |
-
logging.info("infer time is {} seconds", end-start)
|
65 |
-
return buildResponse(output['tts_speech'])
|
66 |
-
|
67 |
-
@app.post("/api/inference/zero-shot")
|
68 |
-
async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
|
69 |
-
start = time.process_time()
|
70 |
-
prompt_speech = load_wav(audio.file, 16000)
|
71 |
-
prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
|
72 |
-
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
73 |
-
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
|
80 |
-
@app.
|
81 |
-
async def
|
82 |
-
|
83 |
-
|
84 |
-
prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
|
85 |
-
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
86 |
-
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
92 |
|
93 |
-
@app.
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
end = time.process_time()
|
99 |
-
logging.info("infer time is {} seconds", end-start)
|
100 |
-
return buildResponse(output['tts_speech'])
|
101 |
|
102 |
-
@app.get("/
|
103 |
-
async def
|
104 |
-
|
|
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
</html>
|
119 |
-
"""
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
import os
|
15 |
import sys
|
|
|
|
|
|
|
|
|
|
|
16 |
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
17 |
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
18 |
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
19 |
+
import argparse
|
|
|
|
|
|
|
|
|
20 |
import logging
|
21 |
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
22 |
+
from fastapi import FastAPI, UploadFile, Form, File
|
23 |
+
from fastapi.responses import StreamingResponse
|
24 |
+
from fastapi.middleware.cors import CORSMiddleware
|
25 |
+
import uvicorn
|
26 |
+
import numpy as np
|
27 |
+
from cosyvoice.cli.cosyvoice import CosyVoice
|
28 |
+
from cosyvoice.utils.file_utils import load_wav
|
29 |
|
30 |
+
app = FastAPI()
|
31 |
+
# allow cross-origin (CORS) requests from any origin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
app.add_middleware(
|
33 |
+
CORSMiddleware,
|
34 |
+
allow_origins=["*"],
|
35 |
allow_credentials=True,
|
36 |
+
allow_methods=["*"],
|
37 |
+
allow_headers=["*"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
+
def generate_data(model_output):
|
40 |
+
for i in model_output:
|
41 |
+
tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
|
42 |
+
yield tts_audio
|
43 |
|
44 |
+
@app.get("/inference_sft")
|
45 |
+
async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
|
46 |
+
model_output = cosyvoice.inference_sft(tts_text, spk_id)
|
47 |
+
return StreamingResponse(generate_data(model_output))
|
|
|
|
|
|
|
48 |
|
49 |
+
@app.get("/inference_zero_shot")
|
50 |
+
async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
|
51 |
+
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
52 |
+
model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
|
53 |
+
return StreamingResponse(generate_data(model_output))
|
54 |
|
55 |
+
@app.get("/inference_cross_lingual")
|
56 |
+
async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
|
57 |
+
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
58 |
+
model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
|
59 |
+
return StreamingResponse(generate_data(model_output))
|
|
|
|
|
|
|
60 |
|
61 |
+
@app.get("/inference_instruct")
|
62 |
+
async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
|
63 |
+
model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
|
64 |
+
return StreamingResponse(generate_data(model_output))
|
65 |
|
66 |
+
if __name__=='__main__':
|
67 |
+
parser = argparse.ArgumentParser()
|
68 |
+
parser.add_argument('--port',
|
69 |
+
type=int,
|
70 |
+
default=50000)
|
71 |
+
parser.add_argument('--model_dir',
|
72 |
+
type=str,
|
73 |
+
default='iic/CosyVoice-300M',
|
74 |
+
help='local path or modelscope repo id')
|
75 |
+
args = parser.parse_args()
|
76 |
+
cosyvoice = CosyVoice(args.model_dir)
|
77 |
+
uvicorn.run(app, host="127.0.0.1", port=args.port)
|
|
|
|