SayaSS commited on
Commit
8e15dd6
·
1 Parent(s): 1500087

add symbol input

Browse files
Files changed (2) hide show
  1. app.py +101 -23
  2. pretrained_models/info.json +2 -2
app.py CHANGED
@@ -8,7 +8,7 @@ import json
8
  import torch
9
  import gradio as gr
10
  from models import SynthesizerTrn
11
- from text import text_to_sequence
12
  from torch import no_grad, LongTensor
13
  import gradio.processing_utils as gr_processing_utils
14
  import logging
@@ -28,28 +28,29 @@ def audio_postprocess(self, y):
28
 
29
  gr.Audio.postprocess = audio_postprocess
30
 
31
- def get_text(text, hps):
32
- text_norm, clean_text = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
33
  if hps.data.add_blank:
34
  text_norm = commons.intersperse(text_norm, 0)
35
  text_norm = LongTensor(text_norm)
36
  return text_norm, clean_text
37
 
38
  def create_tts_fn(net_g_ms, speaker_id):
39
- def tts_fn(text, language, noise_scale, noise_scale_w, length_scale):
40
  text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
41
  if limitation:
42
  text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
43
  max_len = 100
44
  if text_len > max_len:
45
  return "Error: Text is too long", None
46
- if language == 0:
47
- text = f"[ZH]{text}[ZH]"
48
- elif language == 1:
49
- text = f"[JA]{text}[JA]"
50
- else:
51
- text = f"{text}"
52
- stn_tst, clean_text = get_text(text, hps_ms)
 
53
  with no_grad():
54
  x_tst = stn_tst.unsqueeze(0).to(device)
55
  x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
@@ -60,11 +61,24 @@ def create_tts_fn(net_g_ms, speaker_id):
60
  return "Success", (22050, audio)
61
  return tts_fn
62
 
 
 
 
 
 
 
 
 
 
 
 
63
  def change_lang(language):
64
  if language == 0:
65
- return 0.6, 0.668, 1.2
 
 
66
  else:
67
- return 0.6, 0.668, 1
68
 
69
  download_audio_js = """
70
  () =>{{
@@ -114,12 +128,12 @@ if __name__ == '__main__':
114
  **hps_ms.model)
115
  utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None)
116
  _ = net_g_ms.eval().to(device)
117
- models.append((sid, name_en, name_zh, title, cover, example, language, net_g_ms, create_tts_fn(net_g_ms, sid)))
118
  with gr.Blocks() as app:
119
  gr.Markdown(
120
  "# <center> vits-models\n"
121
  "## <center> Please do not generate content that could infringe upon the rights or cause harm to individuals or organizations.\n"
122
- "## <center> 请不要生成会对个人以及组织造成侵害的内容\n"
123
  "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=sayashi.vits-models)\n\n"
124
  "[Open In Colab]"
125
  "(https://colab.research.google.com/drive/10QOk9NPgoKZUXkIhhuVaZ7SYra1MPMKH?usp=share_link)"
@@ -129,7 +143,7 @@ if __name__ == '__main__':
129
 
130
  with gr.Tabs():
131
  with gr.TabItem("EN"):
132
- for (sid, name_en, name_zh, title, cover, example, language, net_g_ms, tts_fn) in models:
133
  with gr.TabItem(name_en):
134
  with gr.Row():
135
  gr.Markdown(
@@ -143,7 +157,14 @@ if __name__ == '__main__':
143
  input_text = gr.Textbox(label="Text (100 words limitation)", lines=5, value=example, elem_id=f"input-text-en-{name_en.replace(' ','')}")
144
  lang = gr.Dropdown(label="Language", choices=["Chinese", "Japanese", "Mix(wrap the Chinese text with [ZH][ZH], wrap the Japanese text with [JA][JA])"],
145
  type="index", value=language)
146
- btn = gr.Button(value="Generate")
 
 
 
 
 
 
 
147
  with gr.Row():
148
  ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
149
  nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
@@ -152,11 +173,36 @@ if __name__ == '__main__':
152
  o1 = gr.Textbox(label="Output Message")
153
  o2 = gr.Audio(label="Output Audio", elem_id=f"tts-audio-en-{name_en.replace(' ','')}")
154
  download = gr.Button("Download Audio")
155
- btn.click(tts_fn, inputs=[input_text, lang, ns, nsw, ls], outputs=[o1, o2])
156
- download.click(None, [], [], _js=download_audio_js.format(audio_id=f"en-{name_en.replace(' ','')}"))
157
- lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  with gr.TabItem("中文"):
159
- for (sid, name_en, name_zh, title, cover, example, language, net_g_ms, tts_fn) in models:
160
  with gr.TabItem(name_zh):
161
  with gr.Row():
162
  gr.Markdown(
@@ -170,7 +216,14 @@ if __name__ == '__main__':
170
  input_text = gr.Textbox(label="文本 (100字上限)", lines=5, value=example, elem_id=f"input-text-zh-{name_zh}")
171
  lang = gr.Dropdown(label="语言", choices=["中文", "日语", "中日混合(中文用[ZH][ZH]包裹起来,日文用[JA][JA]包裹起来)"],
172
  type="index", value="中文"if language == "Chinese" else "日语")
173
- btn = gr.Button(value="生成")
 
 
 
 
 
 
 
174
  with gr.Row():
175
  ns = gr.Slider(label="控制感情变化程度", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
176
  nsw = gr.Slider(label="控制音素发音长度", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
@@ -179,7 +232,32 @@ if __name__ == '__main__':
179
  o1 = gr.Textbox(label="输出信息")
180
  o2 = gr.Audio(label="输出音频", elem_id=f"tts-audio-zh-{name_zh}")
181
  download = gr.Button("下载音频")
182
- btn.click(tts_fn, inputs=[input_text, lang, ns, nsw, ls], outputs=[o1, o2])
183
  download.click(None, [], [], _js=download_audio_js.format(audio_id=f"zh-{name_zh}"))
184
  lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  app.queue(concurrency_count=1).launch(show_api=False, share=args.share)
 
8
  import torch
9
  import gradio as gr
10
  from models import SynthesizerTrn
11
+ from text import text_to_sequence, _clean_text
12
  from torch import no_grad, LongTensor
13
  import gradio.processing_utils as gr_processing_utils
14
  import logging
 
28
 
29
  gr.Audio.postprocess = audio_postprocess
30
 
31
+ def get_text(text, hps, is_symbol):
32
+ text_norm, clean_text = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
33
  if hps.data.add_blank:
34
  text_norm = commons.intersperse(text_norm, 0)
35
  text_norm = LongTensor(text_norm)
36
  return text_norm, clean_text
37
 
38
  def create_tts_fn(net_g_ms, speaker_id):
39
+ def tts_fn(text, language, noise_scale, noise_scale_w, length_scale, is_symbol):
40
  text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
41
  if limitation:
42
  text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
43
  max_len = 100
44
  if text_len > max_len:
45
  return "Error: Text is too long", None
46
+ if not is_symbol:
47
+ if language == 0:
48
+ text = f"[ZH]{text}[ZH]"
49
+ elif language == 1:
50
+ text = f"[JA]{text}[JA]"
51
+ else:
52
+ text = f"{text}"
53
+ stn_tst, clean_text = get_text(text, hps_ms, is_symbol)
54
  with no_grad():
55
  x_tst = stn_tst.unsqueeze(0).to(device)
56
  x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
 
61
  return "Success", (22050, audio)
62
  return tts_fn
63
 
64
+ def create_to_symbol_fn(hps):
65
+ def to_symbol_fn(is_symbol_input, input_text, temp_text, temp_lang):
66
+ if temp_lang == 'Chinese':
67
+ clean_text = f'[ZH]{input_text}[ZH]'
68
+ elif temp_lang == "Japanese":
69
+ clean_text = f'[JA]{input_text}[JA]'
70
+ else:
71
+ clean_text = input_text
72
+ return (_clean_text(clean_text, hps.data.text_cleaners), input_text) if is_symbol_input else (temp_text, temp_text)
73
+
74
+ return to_symbol_fn
75
  def change_lang(language):
76
  if language == 0:
77
+ return 0.6, 0.668, 1.2, "Chinese"
78
+ elif language == 1:
79
+ return 0.6, 0.668, 1, "Japanese"
80
  else:
81
+ return 0.6, 0.668, 1, "Mix"
82
 
83
  download_audio_js = """
84
  () =>{{
 
128
  **hps_ms.model)
129
  utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None)
130
  _ = net_g_ms.eval().to(device)
131
+ models.append((sid, name_en, name_zh, title, cover, example, language, net_g_ms, create_tts_fn(net_g_ms, sid), create_to_symbol_fn(hps_ms)))
132
  with gr.Blocks() as app:
133
  gr.Markdown(
134
  "# <center> vits-models\n"
135
  "## <center> Please do not generate content that could infringe upon the rights or cause harm to individuals or organizations.\n"
136
+ "## <center> ·请不要生成会对个人以及组织造成侵害的内容\n"
137
  "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=sayashi.vits-models)\n\n"
138
  "[Open In Colab]"
139
  "(https://colab.research.google.com/drive/10QOk9NPgoKZUXkIhhuVaZ7SYra1MPMKH?usp=share_link)"
 
143
 
144
  with gr.Tabs():
145
  with gr.TabItem("EN"):
146
+ for (sid, name_en, name_zh, title, cover, example, language, net_g_ms, tts_fn, to_symbol_fn) in models:
147
  with gr.TabItem(name_en):
148
  with gr.Row():
149
  gr.Markdown(
 
157
  input_text = gr.Textbox(label="Text (100 words limitation)", lines=5, value=example, elem_id=f"input-text-en-{name_en.replace(' ','')}")
158
  lang = gr.Dropdown(label="Language", choices=["Chinese", "Japanese", "Mix(wrap the Chinese text with [ZH][ZH], wrap the Japanese text with [JA][JA])"],
159
  type="index", value=language)
160
+ temp_lang = gr.Variable(value=language)
161
+ with gr.Accordion(label="Advanced Options", open=False):
162
+ temp_text_var = gr.Variable()
163
+ symbol_input = gr.Checkbox(value=False, label="Symbol input")
164
+ symbol_list = gr.Dataset(label="Symbol list", components=[input_text],
165
+ samples=[[x] for x in hps_ms.symbols])
166
+ symbol_list_json = gr.Json(value=hps_ms.symbols, visible=False)
167
+ btn = gr.Button(value="Generate", variant="primary")
168
  with gr.Row():
169
  ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
170
  nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
 
173
  o1 = gr.Textbox(label="Output Message")
174
  o2 = gr.Audio(label="Output Audio", elem_id=f"tts-audio-en-{name_en.replace(' ','')}")
175
  download = gr.Button("Download Audio")
176
+ btn.click(tts_fn, inputs=[input_text, lang, ns, nsw, ls, symbol_input], outputs=[o1, o2])
177
+ download.click(None, [], [], _js=download_audio_js.format(audio_id=f"en-{name_en.replace(' ', '')}"))
178
+ lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls, temp_lang])
179
+ symbol_input.change(
180
+ to_symbol_fn,
181
+ [symbol_input, input_text, temp_text_var, temp_lang],
182
+ [input_text, temp_text_var]
183
+ )
184
+ symbol_list.click(None, [symbol_list, symbol_list_json], [input_text],
185
+ _js=f"""
186
+ (i,symbols) => {{
187
+ let root = document.querySelector("body > gradio-app");
188
+ if (root.shadowRoot != null)
189
+ root = root.shadowRoot;
190
+ let text_input = root.querySelector("#input-text-en-{name_en.replace(' ', '')}").querySelector("textarea");
191
+ let startPos = text_input.selectionStart;
192
+ let endPos = text_input.selectionEnd;
193
+ let oldTxt = text_input.value;
194
+ let result = oldTxt.substring(0, startPos) + symbols[i] + oldTxt.substring(endPos);
195
+ text_input.value = result;
196
+ let x = window.scrollX, y = window.scrollY;
197
+ text_input.focus();
198
+ text_input.selectionStart = startPos + symbols[i].length;
199
+ text_input.selectionEnd = startPos + symbols[i].length;
200
+ text_input.blur();
201
+ window.scrollTo(x, y);
202
+ return text_input.value;
203
+ }}""")
204
  with gr.TabItem("中文"):
205
+ for (sid, name_en, name_zh, title, cover, example, language, net_g_ms, tts_fn, to_symbol_fn) in models:
206
  with gr.TabItem(name_zh):
207
  with gr.Row():
208
  gr.Markdown(
 
216
  input_text = gr.Textbox(label="文本 (100字上限)", lines=5, value=example, elem_id=f"input-text-zh-{name_zh}")
217
  lang = gr.Dropdown(label="语言", choices=["中文", "日语", "中日混合(中文用[ZH][ZH]包裹起来,日文用[JA][JA]包裹起来)"],
218
  type="index", value="中文"if language == "Chinese" else "日语")
219
+ temp_lang = gr.Variable(value=language)
220
+ with gr.Accordion(label="高级选项", open=False):
221
+ temp_text_var = gr.Variable()
222
+ symbol_input = gr.Checkbox(value=False, label="符号输入")
223
+ symbol_list = gr.Dataset(label="符号列表", components=[input_text],
224
+ samples=[[x] for x in hps_ms.symbols])
225
+ symbol_list_json = gr.Json(value=hps_ms.symbols, visible=False)
226
+ btn = gr.Button(value="生成", variant="primary")
227
  with gr.Row():
228
  ns = gr.Slider(label="控制感情变化程度", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
229
  nsw = gr.Slider(label="控制音素发音长度", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
 
232
  o1 = gr.Textbox(label="输出信息")
233
  o2 = gr.Audio(label="输出音频", elem_id=f"tts-audio-zh-{name_zh}")
234
  download = gr.Button("下载音频")
235
+ btn.click(tts_fn, inputs=[input_text, lang, ns, nsw, ls, symbol_input], outputs=[o1, o2])
236
  download.click(None, [], [], _js=download_audio_js.format(audio_id=f"zh-{name_zh}"))
237
  lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
238
+ symbol_input.change(
239
+ to_symbol_fn,
240
+ [symbol_input, input_text, temp_text_var, temp_lang],
241
+ [input_text, temp_text_var]
242
+ )
243
+ symbol_list.click(None, [symbol_list, symbol_list_json], [input_text],
244
+ _js=f"""
245
+ (i,symbols) => {{
246
+ let root = document.querySelector("body > gradio-app");
247
+ if (root.shadowRoot != null)
248
+ root = root.shadowRoot;
249
+ let text_input = root.querySelector("#input-text-zh-{name_zh}").querySelector("textarea");
250
+ let startPos = text_input.selectionStart;
251
+ let endPos = text_input.selectionEnd;
252
+ let oldTxt = text_input.value;
253
+ let result = oldTxt.substring(0, startPos) + symbols[i] + oldTxt.substring(endPos);
254
+ text_input.value = result;
255
+ let x = window.scrollX, y = window.scrollY;
256
+ text_input.focus();
257
+ text_input.selectionStart = startPos + symbols[i].length;
258
+ text_input.selectionEnd = startPos + symbols[i].length;
259
+ text_input.blur();
260
+ window.scrollTo(x, y);
261
+ return text_input.value;
262
+ }}""")
263
  app.queue(concurrency_count=1).launch(show_api=False, share=args.share)
pretrained_models/info.json CHANGED
@@ -20,8 +20,8 @@
20
  "type": "single"
21
  },
22
  "nahida-jp": {
23
- "name_en": "nahida(jp)",
24
- "name_zh": "纳西妲(日语)",
25
  "title": "Genshin Impact-ナヒーダ",
26
  "cover": "cover.png",
27
  "sid": 0,
 
20
  "type": "single"
21
  },
22
  "nahida-jp": {
23
+ "name_en": "nahida-jp",
24
+ "name_zh": "纳西妲-日语",
25
  "title": "Genshin Impact-ナヒーダ",
26
  "cover": "cover.png",
27
  "sid": 0,