Spaces:
Running
Running
from infer import OnnxInferenceSession | |
from text import cleaned_text_to_sequence, get_bert | |
from text.cleaner import clean_text | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
import asyncio | |
from pathlib import Path | |
# Cached OnnxInferenceSession; stays None until get_onnx_session() builds it.
OnnxSession = None

# Assets fetched by download_models(). Each entry names a Hugging Face Hub
# repository, the files to pull from it, and the local directory to place
# them in (files already present locally are skipped).
models = [
    dict(
        local_path="./bert/bert-large-cantonese",
        repo_id="hon9kon9ize/bert-large-cantonese",
        files=["pytorch_model.bin"],
    ),
    dict(
        local_path="./bert/deberta-v3-large",
        repo_id="microsoft/deberta-v3-large",
        files=["spm.model", "pytorch_model.bin"],
    ),
    dict(
        local_path="./onnx",
        repo_id="hon9kon9ize/bert-vits-zoengjyutgaai-onnx",
        files=[
            "BertVits2.2PT.json",
            "BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
            "BertVits2.2PT/BertVits2.2PT_emb.onnx",
            "BertVits2.2PT/BertVits2.2PT_dp.onnx",
            "BertVits2.2PT/BertVits2.2PT_sdp.onnx",
            "BertVits2.2PT/BertVits2.2PT_flow.onnx",
            "BertVits2.2PT/BertVits2.2PT_dec.onnx",
        ],
    ),
]
def get_onnx_session():
    """Return the shared OnnxInferenceSession, constructing it on the first call.

    The instance is memoised in the module-level ``OnnxSession`` global, so
    later calls reuse it instead of constructing a new session.
    """
    global OnnxSession
    if OnnxSession is None:
        # One ONNX file per model component, keyed by the names the
        # session constructor receives.
        component_paths = {
            "enc": "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
            "emb_g": "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx",
            "dp": "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx",
            "sdp": "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx",
            "flow": "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx",
            "dec": "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx",
        }
        # Request only the CPU execution provider.
        OnnxSession = OnnxInferenceSession(
            component_paths,
            Providers=["CPUExecutionProvider"],
        )
    return OnnxSession
def download_model_files(repo_id, files, local_path):
    """Download each of *files* from Hub repo *repo_id* into *local_path*.

    A file is skipped when ``local_path/<file>`` already exists. Downloads
    are written as real files (``local_dir_use_symlinks=False``).
    """
    target_dir = Path(local_path)
    for filename in files:
        if target_dir.joinpath(filename).exists():
            continue
        hf_hub_download(
            repo_id,
            filename,
            local_dir=local_path,
            local_dir_use_symlinks=False,
        )
def download_models():
    """Make sure every file listed in the module-level ``models`` table is on disk."""
    for entry in models:
        repo_id = entry["repo_id"]
        wanted_files = entry["files"]
        destination = entry["local_path"]
        download_model_files(repo_id, wanted_files, destination)
def intersperse(lst, item):
    """Return a new list with *item* placed before, between and after every element of *lst*.

    >>> intersperse([1, 2], 0)
    [0, 1, 0, 2, 0]
    """
    interleaved = [item]
    for element in lst:
        interleaved += [element, item]
    return interleaved
def get_text(text, language_str, style_text=None, style_weight=0.7):
    """Convert raw text into the feature arrays consumed by the ONNX session.

    Args:
        text: Raw input text.
        language_str: Either "EN" or "YUE"; anything else is rejected before
            any text cleaning or BERT work is done.
        style_text: Optional style-reference text forwarded to ``get_bert``;
            an empty string is treated as ``None``.
        style_weight: Weight forwarded to ``get_bert`` together with
            ``style_text``.

    Returns:
        ``(en_bert, yue_bert, phone, tone, language)`` as numpy arrays. The
        BERT arrays are transposed (``.T``), so the phone axis (the last axis
        of the BERT output) comes first, assuming 2-D features. Only the BERT
        array for *language_str* is real. The other is filled with unseeded
        random noise, so results vary between calls.

    Raises:
        ValueError: if ``language_str`` is unsupported, or if the BERT
            feature length does not match the interspersed phone sequence.
    """
    # Validate up front so an unsupported language fails fast with a clear
    # message instead of after (or inside) the expensive cleaning/BERT steps.
    if language_str not in ("EN", "YUE"):
        raise ValueError("language_str should be EN or YUE")

    style_text = None if style_text == "" else style_text
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    # Insert a blank (id 0) before, between and after all symbols. That
    # doubles each word's phone count, and the extra leading blank is
    # attributed to the first word.
    phone = intersperse(phone, 0)
    tone = intersperse(tone, 0)
    language = intersperse(language, 0)
    word2ph = [count * 2 for count in word2ph]
    word2ph[0] += 1

    bert_ori = get_bert(
        norm_text, word2ph, language_str, "cpu", style_text, style_weight
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if bert_ori.shape[-1] != len(phone):
        raise ValueError(f"Bert seq len {bert_ori.shape[-1]} != {len(phone)}")

    # Only the selected language gets real BERT features; the other slot is
    # filled with random noise of shape (1024, len(phone)).
    noise = np.random.randn(1024, len(phone))
    if language_str == "EN":
        en_bert, yue_bert = bert_ori, noise
    else:
        en_bert, yue_bert = noise, bert_ori

    return (
        np.asarray(en_bert.T),
        np.asarray(yue_bert.T),
        np.asarray(phone),
        np.asarray(tone),
        np.asarray(language),
    )
# Text-to-speech function | |
async def text_to_speech(text, sid=0, language="YUE"):
    """Synthesise speech for *text* with the ONNX Bert-VITS2 session.

    Args:
        text: Input text. ``None`` or whitespace-only input is rejected.
        sid: Speaker id, passed to the session as a one-element array.
        language: Language code forwarded to ``get_text`` ("YUE" or "EN").

    Returns:
        ``audio[0][0]`` of the session output, or ``None`` when *text* is
        empty. In that case a Gradio warning is raised in the UI instead.
    """
    if text is None or not text.strip():
        # gr.Warning is called for its side effect (a UI notification). It is
        # not an audio value, so return None rather than a tuple around it.
        gr.Warning("Please enter text to convert.")
        return None

    # Load (or reuse) the session only once there is something to synthesise.
    session = get_onnx_session()
    en_bert, yue_bert, phones, tones, lang_ids = get_text(text, language)
    speaker = np.array([sid])
    audio = session(phones, tones, lang_ids, en_bert, yue_bert, speaker)
    return audio[0][0]
# Create Gradio application | |
import gradio as gr | |
# Gradio interface function | |
def tts_interface(text):
    """Gradio callback: synthesise *text* as Cantonese ("YUE") speech with speaker 0.

    Returns:
        ``(sample_rate, waveform)`` for ``gr.Audio``, or ``None`` when no audio
        was produced (e.g. empty input), which leaves the output cleared.
    """
    audio = asyncio.run(text_to_speech(text, 0, "YUE"))
    if audio is None:
        # Returning (44100, None) would hand gr.Audio an invalid value.
        return None
    # NOTE(review): 44100 is assumed to match the model's sampling rate
    # (onnx/BertVits2.2PT.json) — confirm if the model changes.
    return 44100, audio
async def create_demo():
    """Build the Gradio Interface for the Cantonese TTS demo.

    Declared ``async`` so the caller drives it with ``asyncio.run``; it does
    not await anything itself.

    Returns:
        A ``gr.Interface`` wrapping ``tts_interface``: a 5-line text box in,
        an audio player out.
    """
    # User-facing description (Cantonese). It says this is a Cantonese speech
    # generator based on Bert-VITS2. It also notes the model supports
    # Cantonese and English, but this Space does not implement mixed
    # Chinese/English generation.
    description = """廣東話語音生成器,基於Bert-VITS2模型
注意:model 本身支持廣東話同英文,但呢個 space 未實現中英夾雜生成。
"""
    demo = gr.Interface(
        fn=tts_interface,
        inputs=[
            gr.Textbox(label="Input Text", lines=5),
        ],
        outputs=[
            gr.Audio(label="Generated Audio"),
        ],
        title="Cantonese TTS Text-to-Speech",
        description=description,
        analytics_enabled=False,
        # Gradio takes "never" / "auto" / "manual" here. A boolean False is
        # deprecated and only mapped to "never" with a warning.
        allow_flagging="never",
    )
    return demo
# Run the application | |
if __name__ == "__main__":
    # Fetch any missing model files before the UI starts, so the first
    # request is not stalled by downloads.
    download_models()
    # create_demo is a coroutine function; run it to completion to get the app.
    demo_coroutine = create_demo()
    demo = asyncio.run(demo_coroutine)
    demo.launch()