Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,15 +1,19 @@
|
|
1 |
# coding=utf-8
|
2 |
import os
|
3 |
import re
|
|
|
4 |
import utils
|
5 |
import commons
|
6 |
import json
|
|
|
7 |
import gradio as gr
|
8 |
from models import SynthesizerTrn
|
9 |
from text import text_to_sequence
|
10 |
from torch import no_grad, LongTensor
|
11 |
import logging
|
12 |
logging.getLogger('numba').setLevel(logging.WARNING)
|
|
|
|
|
13 |
hps_ms = utils.get_hparams_from_file(r'config/config.json')
|
14 |
|
15 |
def get_text(text, hps):
|
@@ -22,10 +26,11 @@ def get_text(text, hps):
|
|
22 |
def create_tts_fn(net_g_ms, speaker_id):
|
23 |
def tts_fn(text, language, noise_scale, noise_scale_w, length_scale):
|
24 |
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
29 |
if language == 0:
|
30 |
text = f"[ZH]{text}[ZH]"
|
31 |
elif language == 1:
|
@@ -34,11 +39,11 @@ def create_tts_fn(net_g_ms, speaker_id):
|
|
34 |
text = f"{text}"
|
35 |
stn_tst, clean_text = get_text(text, hps_ms)
|
36 |
with no_grad():
|
37 |
-
x_tst = stn_tst.unsqueeze(0)
|
38 |
-
x_tst_lengths = LongTensor([stn_tst.size(0)])
|
39 |
-
sid = LongTensor([speaker_id])
|
40 |
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
|
41 |
-
length_scale=length_scale)[0][0, 0].data.float().numpy()
|
42 |
|
43 |
return "Success", (22050, audio)
|
44 |
return tts_fn
|
@@ -72,23 +77,29 @@ download_audio_js = """
|
|
72 |
"""
|
73 |
|
74 |
if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
models = []
|
76 |
with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
|
77 |
models_info = json.load(f)
|
78 |
for i, info in models_info.items():
|
|
|
|
|
|
|
|
|
|
|
79 |
net_g_ms = SynthesizerTrn(
|
80 |
len(hps_ms.symbols),
|
81 |
hps_ms.data.filter_length // 2 + 1,
|
82 |
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
83 |
n_speakers=hps_ms.data.n_speakers,
|
84 |
**hps_ms.model)
|
85 |
-
_ = net_g_ms.eval()
|
86 |
-
sid = info['sid']
|
87 |
-
name_en = info['name_en']
|
88 |
-
name_zh = info['name_zh']
|
89 |
-
title = info['title']
|
90 |
-
cover = f"pretrained_models/{i}/{info['cover']}"
|
91 |
utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None)
|
|
|
92 |
models.append((sid, name_en, name_zh, title, cover, net_g_ms, create_tts_fn(net_g_ms, sid)))
|
93 |
with gr.Blocks() as app:
|
94 |
gr.Markdown(
|
|
|
1 |
# coding=utf-8
|
2 |
import os
|
3 |
import re
|
4 |
+
import argparse
|
5 |
import utils
|
6 |
import commons
|
7 |
import json
|
8 |
+
import torch
|
9 |
import gradio as gr
|
10 |
from models import SynthesizerTrn
|
11 |
from text import text_to_sequence
|
12 |
from torch import no_grad, LongTensor
|
13 |
import logging
|
14 |
logging.getLogger('numba').setLevel(logging.WARNING)
|
15 |
+
limitation = os.getenv("SYSTEM") == "spaces" # limit text and audio length in huggingface spaces
|
16 |
+
|
17 |
hps_ms = utils.get_hparams_from_file(r'config/config.json')
|
18 |
|
19 |
def get_text(text, hps):
|
|
|
26 |
def create_tts_fn(net_g_ms, speaker_id):
|
27 |
def tts_fn(text, language, noise_scale, noise_scale_w, length_scale):
|
28 |
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
29 |
+
if limitation:
|
30 |
+
text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
|
31 |
+
max_len = 100
|
32 |
+
if text_len > max_len:
|
33 |
+
return "Error: Text is too long", None
|
34 |
if language == 0:
|
35 |
text = f"[ZH]{text}[ZH]"
|
36 |
elif language == 1:
|
|
|
39 |
text = f"{text}"
|
40 |
stn_tst, clean_text = get_text(text, hps_ms)
|
41 |
with no_grad():
|
42 |
+
x_tst = stn_tst.unsqueeze(0).to(device)
|
43 |
+
x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
|
44 |
+
sid = LongTensor([speaker_id]).to(device)
|
45 |
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
|
46 |
+
length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
|
47 |
|
48 |
return "Success", (22050, audio)
|
49 |
return tts_fn
|
|
|
77 |
"""
|
78 |
|
79 |
if __name__ == '__main__':
|
80 |
+
parser = argparse.ArgumentParser()
|
81 |
+
parser.add_argument('--device', type=str, default='cpu')
|
82 |
+
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
83 |
+
args = parser.parse_args()
|
84 |
+
device = torch.device(args.device)
|
85 |
+
|
86 |
models = []
|
87 |
with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
|
88 |
models_info = json.load(f)
|
89 |
for i, info in models_info.items():
|
90 |
+
sid = info['sid']
|
91 |
+
name_en = info['name_en']
|
92 |
+
name_zh = info['name_zh']
|
93 |
+
title = info['title']
|
94 |
+
cover = f"pretrained_models/{i}/{info['cover']}"
|
95 |
net_g_ms = SynthesizerTrn(
|
96 |
len(hps_ms.symbols),
|
97 |
hps_ms.data.filter_length // 2 + 1,
|
98 |
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
99 |
n_speakers=hps_ms.data.n_speakers,
|
100 |
**hps_ms.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None)
|
102 |
+
_ = net_g_ms.eval().to(device)
|
103 |
models.append((sid, name_en, name_zh, title, cover, net_g_ms, create_tts_fn(net_g_ms, sid)))
|
104 |
with gr.Blocks() as app:
|
105 |
gr.Markdown(
|