File size: 5,216 Bytes
b2458f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import logging
import re
import gradio as gr
import numpy
import torch
import utils
from infer import infer, get_net_g
# Raise the threshold of these third-party loggers to WARNING.
for _quiet_logger_name in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_quiet_logger_name).setLevel(logging.WARNING)
del _quiet_logger_name

# Root logger configuration for the whole application.
logging.basicConfig(
    format="| %(name)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Model state shared between main() (which fills it in) and tts_fn()
# (which reads it); both stay None until main() has run.
net_g = None
hps = None

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint path handed to get_net_g() in main().
model_path = "models/G_1000.pth"

# Sample rate (Hz) reported with the generated audio and used to size the
# inter-segment silence in tts_fn().
# NOTE(review): presumably this must match the sampling rate the model was
# trained with -- confirm against configs/config.json.
sampling_rate = 22050
def split_sentence(sentence: str) -> list:
    """Split mixed-script text into consecutive single-language segments.

    Every character falls into one of three classes:

    * English -- an ASCII letter (``a-z`` / ``A-Z``);
    * Chinese -- any non-ASCII character;
    * neutral -- every other ASCII character (digits, whitespace,
      punctuation).

    A new segment starts only when a character of the *other* language is
    met; neutral characters stay with the segment they occur in, and leading
    neutral characters belong to the first segment.

    Args:
        sentence: Text to split; may be empty.

    Returns:
        A list of ``(text, language)`` tuples in input order, where
        ``language`` is ``'ZH'`` or ``'EN'``. Joining the ``text`` parts
        reproduces ``sentence``. An empty input yields ``[]``; input made up
        only of neutral characters yields a single ``'ZH'`` segment.
    """
    if not sentence:
        return []

    segments = []
    language = None  # language of the segment being built; None until known
    start = 0
    for idx, char in enumerate(sentence):
        if not char.isascii():
            char_language = 'ZH'
        elif char.isalpha():  # ASCII + alphabetic == [a-zA-Z]
            char_language = 'EN'
        else:
            continue  # neutral character: never changes the language
        if char_language == language:
            continue
        if language is not None:
            # Language switch: close the running segment just before idx.
            segments.append((sentence[start:idx], language))
            start = idx
        language = char_language

    # If no character ever decided the language (digits/punctuation only),
    # tag the text 'ZH' rather than with an empty language string, which the
    # synthesis step cannot interpret.
    segments.append((sentence[start:], language or 'ZH'))
    return segments
def tts_fn(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
):
    """Synthesize ``text`` segment by segment and return status plus audio.

    The text is split with ``split_sentence``; each segment is passed to
    ``infer`` and followed by half a second of silence. The joined waveform
    is peak-normalized and converted to 16-bit PCM.

    Args:
        text: Text to synthesize.
        speaker: Speaker identifier forwarded to ``infer`` as ``sid``.
        sdp_ratio: Forwarded unchanged to ``infer``.
        noise_scale: Forwarded unchanged to ``infer``.
        noise_scale_w: Forwarded unchanged to ``infer``.
        length_scale: Forwarded unchanged to ``infer``.
        language: UI language choice; ``'普通话'`` selects ``'ZH'``, anything
            else selects ``'SH'``. English segments are always synthesized
            as ``'EN'``.

    Returns:
        ``("Success", (sampling_rate, samples))`` where ``samples`` is an
        int16 numpy array, or ``("Error: input text is empty", None)`` when
        ``text`` is empty or whitespace-only.
    """
    # Nothing to synthesize: report it instead of failing later on the
    # reduction over an empty waveform.
    if not text or not text.strip():
        return "Error: input text is empty", None

    language = 'ZH' if language == '普通话' else 'SH'
    # Half a second of silence appended after every segment.
    silence = numpy.zeros(sampling_rate // 2, dtype=numpy.float32)

    # Collect the pieces and join them once instead of re-copying the whole
    # waveform on every iteration.
    chunks = []
    for sentence, sentence_language in split_sentence(text):
        # NOTE(review): infer() is assumed to return a 1-D numpy array of
        # float samples -- confirm against infer.py.
        sub_audio_data = infer(
            sentence,
            sdp_ratio,
            noise_scale,
            noise_scale_w,
            length_scale,
            sid=speaker,
            # Only English segments override the language chosen in the UI.
            language="EN" if sentence_language == "EN" else language,
            hps=hps,
            net_g=net_g,
            device=device,
        )
        chunks.append(sub_audio_data)
        chunks.append(silence)
    audio_data = numpy.concatenate(chunks)

    # Peak-normalize; skip the division for an all-zero waveform so it does
    # not turn into NaN before the integer cast.
    peak = numpy.abs(audio_data).max()
    if peak > 0:
        audio_data = audio_data / peak
    audio_data = (audio_data * 32767).astype(numpy.int16)
    return "Success", (sampling_rate, audio_data)
def main():
    """Load the model, build the Gradio interface and start the web server.

    Fills the module-level ``hps`` and ``net_g`` globals (read by
    ``tts_fn``), lays out the input widgets, wires the submit button to
    ``tts_fn`` and launches the app on port 7860.
    """
    global hps, net_g

    # NOTE(review): logging.basicConfig() was already called at import time
    # in this module; without force=True a second call normally leaves the
    # existing root configuration untouched -- confirm whether DEBUG output
    # was actually intended here.
    logging.basicConfig(level=logging.DEBUG)

    hps = utils.get_hparams_from_file("configs/config.json")
    net_g = get_net_g(model_path=model_path, device=device, hps=hps)

    speaker_names = list(hps.data.spk2id.keys())
    language_choices = ["普通话", "上海话"]

    # Default text shown in the input box: alternating Chinese and English lines.
    default_text = "\n".join([
        "站一个制高点看上海,",
        "Looking at Shanghai from a commanding height,",
        "上海的弄堂是壮观的景象。",
        "The alleys in Shanghai are a great sight.",
        "它是这城市背景一样的东西。",
        "It is something with the same background as this city."
    ])

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: text input and synthesis parameters.
            with gr.Column():
                text = gr.TextArea(label="输入文本内容", value=default_text)
                sdp_ratio = gr.Slider(
                    label="SDP/DP混合比", minimum=0, maximum=1, step=0.1, value=0.2
                )
                noise_scale = gr.Slider(
                    label="感情", minimum=0.1, maximum=2, step=0.1, value=0.6
                )
                noise_scale_w = gr.Slider(
                    label="音素长度", minimum=0.1, maximum=2, step=0.1, value=0.8
                )
                length_scale = gr.Slider(
                    label="语速", minimum=0.1, maximum=2, step=0.1, value=1.0
                )
            # Right column: speaker/language selection, trigger and outputs.
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        speaker = gr.Dropdown(
                            label="选择说话人",
                            choices=speaker_names,
                            value=speaker_names[0],
                        )
                    with gr.Column():
                        language = gr.Dropdown(
                            label="选择语言",
                            choices=language_choices,
                            value=language_choices[0],
                        )
                submit_btn = gr.Button("生成音频", variant="primary")
                text_output = gr.Textbox(label="状态")
                audio_output = gr.Audio(label="音频")

        # The order of the inputs must match tts_fn's positional parameters.
        click_inputs = [
            text,
            speaker,
            sdp_ratio,
            noise_scale,
            noise_scale_w,
            length_scale,
            language,
        ]
        submit_btn.click(
            tts_fn,
            inputs=click_inputs,
            outputs=[text_output, audio_output],
        )

    # Listen on all interfaces, port 7860, without a public share link.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()
|