linxianzhong0128 committed
Commit 7088d16
1 Parent(s): e8bdcf1

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +30 -0
  2. .gitignore +15 -0
  3. .gitmodules +7 -0
  4. .ipynb_checkpoints/webui-checkpoint.py +1364 -0
  5. ASR/FunASR.py +54 -0
  6. ASR/README.md +77 -0
  7. ASR/Whisper.py +129 -0
  8. ASR/__init__.py +4 -0
  9. ASR/requirements_funasr.txt +3 -0
  10. AutoDL部署.md +234 -0
  11. FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/.mdl +0 -0
  12. FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/.msc +0 -0
  13. FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/README.md +272 -0
  14. FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/config.yaml +46 -0
  15. FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/configuration.json +13 -0
  16. FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/example/punc_example.txt +3 -0
  17. FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/fig/struct.png +0 -0
  18. FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.pt +3 -0
  19. FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/tokens.json +0 -0
  20. FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/.mdl +0 -0
  21. FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/.msc +0 -0
  22. FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/README.md +296 -0
  23. FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/am.mvn +8 -0
  24. FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/config.yaml +56 -0
  25. FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/configuration.json +13 -0
  26. FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/fig/struct.png +0 -0
  27. FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.pt +3 -0
  28. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/.mdl +0 -0
  29. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/.msc +0 -0
  30. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md +357 -0
  31. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn +8 -0
  32. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml +159 -0
  33. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/configuration.json +14 -0
  34. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/hotword.txt +1 -0
  35. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/fig/res.png +0 -0
  36. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/fig/seaco.png +0 -0
  37. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.pt +3 -0
  38. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/seg_dict +0 -0
  39. FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.json +0 -0
  40. GPT_SoVITS/AR/__init__.py +0 -0
  41. GPT_SoVITS/AR/data/__init__.py +0 -0
  42. GPT_SoVITS/AR/data/bucket_sampler.py +162 -0
  43. GPT_SoVITS/AR/data/data_module.py +74 -0
  44. GPT_SoVITS/AR/data/dataset.py +320 -0
  45. GPT_SoVITS/AR/models/__init__.py +0 -0
  46. GPT_SoVITS/AR/models/t2s_lightning_module.py +140 -0
  47. GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py +106 -0
  48. GPT_SoVITS/AR/models/t2s_model.py +327 -0
  49. GPT_SoVITS/AR/models/t2s_model_onnx.py +337 -0
  50. GPT_SoVITS/AR/models/utils.py +160 -0
.gitattributes CHANGED
@@ -33,3 +33,33 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Musetalk/data/video/man_musev.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ Musetalk/data/video/monalisa_musev.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ Musetalk/data/video/seaside4_musev.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ Musetalk/data/video/sit_musev.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ Musetalk/data/video/sun_musev.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ Musetalk/data/video/yongen_musev.mp4 filter=lfs diff=lfs merge=lfs -text
42
+ NeRF/gridencoder/build/lib.linux-x86_64-3.10/_gridencoder.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
43
+ NeRF/gridencoder/build/temp.linux-x86_64-3.10/root/Linly-Talker/NeRF/gridencoder/src/gridencoder.o filter=lfs diff=lfs merge=lfs -text
44
+ docs/WebUI.png filter=lfs diff=lfs merge=lfs -text
45
+ examples/source_image/art_16.png filter=lfs diff=lfs merge=lfs -text
46
+ examples/source_image/art_17.png filter=lfs diff=lfs merge=lfs -text
47
+ examples/source_image/art_3.png filter=lfs diff=lfs merge=lfs -text
48
+ examples/source_image/art_4.png filter=lfs diff=lfs merge=lfs -text
49
+ examples/source_image/art_5.png filter=lfs diff=lfs merge=lfs -text
50
+ examples/source_image/art_8.png filter=lfs diff=lfs merge=lfs -text
51
+ examples/source_image/art_9.png filter=lfs diff=lfs merge=lfs -text
52
+ inputs/boy.png filter=lfs diff=lfs merge=lfs -text
53
+ pytorch3d/.github/bundle_adjust.gif filter=lfs diff=lfs merge=lfs -text
54
+ pytorch3d/.github/camera_position_teapot.gif filter=lfs diff=lfs merge=lfs -text
55
+ pytorch3d/.github/fit_nerf.gif filter=lfs diff=lfs merge=lfs -text
56
+ pytorch3d/.github/fit_textured_volume.gif filter=lfs diff=lfs merge=lfs -text
57
+ pytorch3d/.github/implicitron_config.gif filter=lfs diff=lfs merge=lfs -text
58
+ pytorch3d/.github/nerf_project_logo.gif filter=lfs diff=lfs merge=lfs -text
59
+ pytorch3d/build/lib.linux-x86_64-3.10/pytorch3d/_C.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
60
+ pytorch3d/build/temp.linux-x86_64-3.10/.ninja_deps filter=lfs diff=lfs merge=lfs -text
61
+ pytorch3d/build/temp.linux-x86_64-3.10/root/Linly-Talker/pytorch3d/pytorch3d/csrc/knn/knn.o filter=lfs diff=lfs merge=lfs -text
62
+ pytorch3d/build/temp.linux-x86_64-3.10/root/Linly-Talker/pytorch3d/pytorch3d/csrc/pulsar/cuda/renderer.construct.gpu.o filter=lfs diff=lfs merge=lfs -text
63
+ pytorch3d/build/temp.linux-x86_64-3.10/root/Linly-Talker/pytorch3d/pytorch3d/csrc/pulsar/cuda/renderer.forward.gpu.o filter=lfs diff=lfs merge=lfs -text
64
+ pytorch3d/docs/notes/assets/batch_modes.gif filter=lfs diff=lfs merge=lfs -text
65
+ pytorch3d/docs/notes/assets/meshrcnn.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,15 @@
1
+ .DS_Store
2
+ checkpoints/
3
+ gfpgan/
4
+ __pycache__/
5
+ *.pyc
6
+ Linly-AI
7
+ Qwen
8
+ checkpoints
9
+ temp
10
+ *.wav
11
+ *.vtt
12
+ *.srt
13
+ results/example_answer.mp4
14
+ request-Linly-api.py
15
+ results
.gitmodules ADDED
@@ -0,0 +1,7 @@
1
+ [submodule "MuseV"]
2
+ path = MuseV
3
+ url = https://github.com/TMElyralab/MuseV.git
4
+
5
+ [submodule "ChatTTS"]
6
+ path = ChatTTS
7
+ url = https://github.com/2noise/ChatTTS.git
.ipynb_checkpoints/webui-checkpoint.py ADDED
@@ -0,0 +1,1364 @@
1
+ import os
2
+ import random
3
+ import gradio as gr
4
+ import time
5
+ import torch
6
+ import gc
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
+ from zhconv import convert
10
+ from LLM import LLM
11
+ from TTS import EdgeTTS
12
+ from src.cost_time import calculate_time
13
+
14
+ from configs import *
15
+ os.environ["GRADIO_TEMP_DIR"]= './temp'
16
+ os.environ["WEBUI"] = "true"
17
+ def get_title(title = 'Linly 智能对话系统 (Linly-Talker)'):
18
+ description = f"""
19
+ <p style="text-align: center; font-weight: bold;">
20
+ <span style="font-size: 28px;">{title}</span>
21
+ <br>
22
+ <span style="font-size: 18px;" id="paper-info">
23
+ [<a href="https://zhuanlan.zhihu.com/p/671006998" target="_blank">知乎</a>]
24
+ [<a href="https://www.bilibili.com/video/BV1rN4y1a76x/" target="_blank">bilibili</a>]
25
+ [<a href="https://github.com/Kedreamix/Linly-Talker" target="_blank">GitHub</a>]
26
+ [<a href="https://kedreamix.github.io/" target="_blank">个人主页</a>]
27
+ </span>
28
+ <br>
29
+ <span>Linly-Talker是一款创新的数字人对话系统,它融合了最新的人工智能技术,包括大型语言模型(LLM)🤖、自动语音识别(ASR)🎙️、文本到语音转换(TTS)🗣️和语音克隆技术🎤。</span>
30
+ </p>
31
+ """
32
+ return description
33
+
34
+
35
+ # 设置默认system
36
+ default_system = '你是一个很有帮助的助手'
37
+ # 设置默认的prompt
38
+ prefix_prompt = '''请用少于25个字回答以下问题\n\n'''
39
+
40
+ edgetts = EdgeTTS()
41
+
42
+ # 设定默认参数值,可修改
43
+ blink_every = True
44
+ size_of_image = 256
45
+ preprocess_type = 'crop'
46
+ facerender = 'facevid2vid'
47
+ enhancer = False
48
+ is_still_mode = False
49
+ exp_weight = 1
50
+ use_ref_video = False
51
+ ref_video = None
52
+ ref_info = 'pose'
53
+ use_idle_mode = False
54
+ length_of_audio = 5
55
+
56
+ @calculate_time
57
+ def Asr(audio):
58
+ try:
59
+ question = asr.transcribe(audio)
60
+ question = convert(question, 'zh-cn')
61
+ except Exception as e:
62
+ print("ASR Error: ", e)
63
+ question = 'Gradio存在一些bug,麦克风模式有时候可能音频还未传入,请重新点击一下语音识别即可'
64
+ gr.Warning(question)
65
+ return question
66
+
67
+ def clear_memory():
68
+ """
69
+ 清理PyTorch的显存和系统内存缓存。
70
+ """
71
+ # 1. 清理缓存的变量
72
+ gc.collect() # 触发Python垃圾回收
73
+ torch.cuda.empty_cache() # 清理PyTorch的显存缓存
74
+ torch.cuda.ipc_collect() # 清理PyTorch的跨进程通信缓存
75
+
76
+ # 2. 打印显存使用情况(可选)
77
+ print(f"Memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
78
+ print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1024 ** 2):.2f} MB")
79
+ print(f"Cached memory: {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB")
80
+ print(f"Max cached memory: {torch.cuda.max_memory_reserved() / (1024 ** 2):.2f} MB")
81
+
82
+ @calculate_time
83
+ def TTS_response(text,
84
+ voice, rate, volume, pitch,
85
+ am, voc, lang, male,
86
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut,
87
+ question_audio, question, use_mic_voice,
88
+ tts_method = 'PaddleTTS', save_path = 'answer.wav'):
89
+ # print(text, voice, rate, volume, pitch, am, voc, lang, male, tts_method, save_path)
90
+ if tts_method == 'Edge-TTS':
91
+ if not edgetts.network:
92
+ gr.Warning("请检查网络或者使用其他模型,例如PaddleTTS")
93
+ return None, None
94
+ try:
95
+ edgetts.predict(text, voice, rate, volume, pitch , 'answer.wav', 'answer.vtt')
96
+ except:
97
+ os.system(f'edge-tts --text "{text}" --voice {voice} --write-media answer.wav --write-subtitles answer.vtt')
98
+ return 'answer.wav', 'answer.vtt'
99
+ elif tts_method == 'PaddleTTS':
100
+ tts.predict(text, am, voc, lang = lang, male=male, save_path = save_path)
101
+ return save_path, None
102
+ elif tts_method == 'GPT-SoVITS克隆声音':
103
+ if use_mic_voice:
104
+ try:
105
+ vits.predict(ref_wav_path = question_audio,
106
+ prompt_text = question,
107
+ prompt_language = "中文",
108
+ text = text, # 回答
109
+ text_language = "中文",
110
+ how_to_cut = "凑四句一切",
111
+ save_path = 'answer.wav')
112
+ return 'answer.wav', None
113
+ except Exception as e:
114
+ gr.Warning("无克隆环境或者无克隆模型权重,无法克隆声音", e)
115
+ return None, None
116
+ else:
117
+ try:
118
+ vits.predict(ref_wav_path = inp_ref,
119
+ prompt_text = prompt_text,
120
+ prompt_language = prompt_language,
121
+ text = text, # 回答
122
+ text_language = text_language,
123
+ how_to_cut = how_to_cut,
124
+ save_path = 'answer.wav')
125
+ return 'answer.wav', None
126
+ except Exception as e:
127
+ gr.Warning("无克隆环境或者无克隆模型权重,无法克隆声音", e)
128
+ return None, None
129
+ return None, None
130
+ @calculate_time
131
+ def LLM_response(question_audio, question,
132
+ voice = 'zh-CN-XiaoxiaoNeural', rate = 0, volume = 0, pitch = 0,
133
+ am='fastspeech2', voc='pwgan',lang='zh', male=False,
134
+ inp_ref = None, prompt_text = "", prompt_language = "", text_language = "", how_to_cut = "", use_mic_voice = False,
135
+ tts_method = 'Edge-TTS'):
136
+ if len(question) == 0:
137
+ gr.Warning("请输入问题")
138
+ return None, None, None
139
+ answer = llm.generate(question, default_system)
140
+ print(answer)
141
+ driven_audio, driven_vtt = TTS_response(answer, voice, rate, volume, pitch,
142
+ am, voc, lang, male,
143
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, question_audio, question, use_mic_voice,
144
+ tts_method)
145
+ return driven_audio, driven_vtt, answer
146
+
147
+ @calculate_time
148
+ def Talker_response(question_audio = None, method = 'SadTalker', text = '',
149
+ voice = 'zh-CN-XiaoxiaoNeural', rate = 0, volume = 100, pitch = 0,
150
+ am = 'fastspeech2', voc = 'pwgan', lang = 'zh', male = False,
151
+ inp_ref = None, prompt_text = "", prompt_language = "", text_language = "", how_to_cut = "", use_mic_voice = False,
152
+ tts_method = 'Edge-TTS',batch_size = 2, character = '女性角色',
153
+ progress=gr.Progress(track_tqdm=True)):
154
+ default_voice = None
155
+ if character == '女性角色':
156
+ # 女性角色
157
+ source_image, pic_path = r'inputs/girl.png', r'inputs/girl.png'
158
+ crop_pic_path = "./inputs/first_frame_dir_girl/girl.png"
159
+ first_coeff_path = "./inputs/first_frame_dir_girl/girl.mat"
160
+ crop_info = ((403, 403), (19, 30, 502, 513), [40.05956541381802, 40.17324339233366, 443.7892505041507, 443.9029284826663])
161
+ default_voice = 'zh-CN-XiaoxiaoNeural'
162
+ elif character == '男性角色':
163
+ # 男性角色
164
+ source_image = r'./inputs/boy.png'
165
+ pic_path = "./inputs/boy.png"
166
+ crop_pic_path = "./inputs/first_frame_dir_boy/boy.png"
167
+ first_coeff_path = "./inputs/first_frame_dir_boy/boy.mat"
168
+ crop_info = ((876, 747), (0, 0, 886, 838), [10.382158280494476, 0, 886, 747.7078990925525])
169
+ default_voice = 'zh-CN-YunyangNeural'
170
+ else:
171
+ gr.Warning('未知角色')
172
+ return None
173
+
174
+ voice = default_voice if not voice else voice
175
+
176
+ if not voice:
177
+ gr.Warning('请选择声音')
178
+
179
+ driven_audio, driven_vtt, _ = LLM_response(question_audio, text,
180
+ voice, rate, volume, pitch,
181
+ am, voc, lang, male,
182
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
183
+ tts_method)
184
+ if driven_audio is None:
185
+ gr.Warning("音频没有正常生成,请检查TTS是否正确")
186
+ return None
187
+ if method == 'SadTalker':
188
+ pose_style = random.randint(0, 45)
189
+ video = talker.test(pic_path,
190
+ crop_pic_path,
191
+ first_coeff_path,
192
+ crop_info,
193
+ source_image,
194
+ driven_audio,
195
+ preprocess_type,
196
+ is_still_mode,
197
+ enhancer,
198
+ batch_size,
199
+ size_of_image,
200
+ pose_style,
201
+ facerender,
202
+ exp_weight,
203
+ use_ref_video,
204
+ ref_video,
205
+ ref_info,
206
+ use_idle_mode,
207
+ length_of_audio,
208
+ blink_every,
209
+ fps=20)
210
+ elif method == 'Wav2Lip':
211
+ video = talker.predict(crop_pic_path, driven_audio, batch_size, enhancer)
212
+ elif method == 'NeRFTalk':
213
+ video = talker.predict(driven_audio)
214
+ else:
215
+ gr.Warning("不支持的方法:" + method)
216
+ return None
217
+ if driven_vtt:
218
+ return video, driven_vtt
219
+ else:
220
+ return video
221
+
222
+ @calculate_time
223
+ def Talker_response_img(question_audio, method, text, voice, rate, volume, pitch,
224
+ am, voc, lang, male,
225
+ inp_ref , prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
226
+ tts_method,
227
+ source_image,
228
+ preprocess_type,
229
+ is_still_mode,
230
+ enhancer,
231
+ batch_size,
232
+ size_of_image,
233
+ pose_style,
234
+ facerender,
235
+ exp_weight,
236
+ blink_every,
237
+ fps, progress=gr.Progress(track_tqdm=True)
238
+ ):
239
+ if enhancer:
240
+ gr.Warning("记得请先安装GFPGAN库,pip install gfpgan, 已安装可忽略")
241
+ if not voice:
242
+ gr.Warning("请先选择声音")
243
+ driven_audio, driven_vtt, _ = LLM_response(question_audio, text, voice, rate, volume, pitch,
244
+ am, voc, lang, male,
245
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
246
+ tts_method = tts_method)
247
+ if driven_audio is None:
248
+ gr.Warning("音频没有正常生成,请检查TTS是否正确")
249
+ return None
250
+ if method == 'SadTalker':
251
+ video = talker.test2(source_image,
252
+ driven_audio,
253
+ preprocess_type,
254
+ is_still_mode,
255
+ enhancer,
256
+ batch_size,
257
+ size_of_image,
258
+ pose_style,
259
+ facerender,
260
+ exp_weight,
261
+ use_ref_video,
262
+ ref_video,
263
+ ref_info,
264
+ use_idle_mode,
265
+ length_of_audio,
266
+ blink_every,
267
+ fps=fps)
268
+ elif method == 'Wav2Lip':
269
+ video = talker.predict(source_image, driven_audio, batch_size)
270
+ elif method == 'NeRFTalk':
271
+ video = talker.predict(driven_audio)
272
+ else:
273
+ return None
274
+ if driven_vtt:
275
+ return video, driven_vtt
276
+ else:
277
+ return video
278
+
279
+ @calculate_time
280
+ def Talker_Say(preprocess_type,
281
+ is_still_mode,
282
+ enhancer,
283
+ batch_size,
284
+ size_of_image,
285
+ pose_style,
286
+ facerender,
287
+ exp_weight,
288
+ blink_every,
289
+ fps,source_image = None, source_video = None, question_audio = None, method = 'SadTalker', text = '',
290
+ voice = 'zh-CN-XiaoxiaoNeural', rate = 0, volume = 100, pitch = 0,
291
+ am = 'fastspeech2', voc = 'pwgan', lang = 'zh', male = False,
292
+ inp_ref = None, prompt_text = "", prompt_language = "", text_language = "", how_to_cut = "", use_mic_voice = False,
293
+ tts_method = 'Edge-TTS', character = '女性角色',
294
+ progress=gr.Progress(track_tqdm=True)):
295
+ if source_video:
296
+ source_image = source_video
297
+ default_voice = None
298
+
299
+ voice = default_voice if not voice else voice
300
+
301
+ if not voice:
302
+ gr.Warning('请选择声音')
303
+
304
+ driven_audio, driven_vtt = TTS_response(text, voice, rate, volume, pitch,
305
+ am, voc, lang, male,
306
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, question_audio, text, use_mic_voice,
307
+ tts_method)
308
+ if driven_audio is None:
309
+ gr.Warning("音频没有正常生成,请检查TTS是否正确")
310
+ return None
311
+ if method == 'SadTalker':
312
+ pose_style = random.randint(0, 45)
313
+ video = talker.test2(source_image,
314
+ driven_audio,
315
+ preprocess_type,
316
+ is_still_mode,
317
+ enhancer,
318
+ batch_size,
319
+ size_of_image,
320
+ pose_style,
321
+ facerender,
322
+ exp_weight,
323
+ use_ref_video,
324
+ ref_video,
325
+ ref_info,
326
+ use_idle_mode,
327
+ length_of_audio,
328
+ blink_every,
329
+ fps=fps)
330
+ elif method == 'Wav2Lip':
331
+ video = talker.predict(source_image, driven_audio, batch_size, enhancer)
332
+ elif method == 'NeRFTalk':
333
+ video = talker.predict(driven_audio)
334
+ else:
335
+ gr.Warning("不支持的方法:" + method)
336
+ return None
337
+ if driven_vtt:
338
+ return video, driven_vtt
339
+ else:
340
+ return video
341
+
342
+ def chat_response(system, message, history):
343
+ # response = llm.generate(message)
344
+ response, history = llm.chat(system, message, history)
345
+ print(history)
346
+ # 流式输出
347
+ for i in range(len(response)):
348
+ time.sleep(0.01)
349
+ yield "", history[:-1] + [(message, response[:i+1])]
350
+ return "", history
351
+
352
+ def modify_system_session(system: str) -> str:
353
+ if system is None or len(system) == 0:
354
+ system = default_system
355
+ llm.clear_history()
356
+ return system, system, []
357
+
358
+ def clear_session():
359
+ # clear history
360
+ llm.clear_history()
361
+ return '', []
362
+
363
+
364
+ def human_response(source_image, history, question_audio, talker_method, voice, rate, volume, pitch,
365
+ am, voc, lang, male,
366
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
367
+ tts_method, character,
368
+ preprocess_type, is_still_mode, enhancer, batch_size, size_of_image,
369
+ pose_style, facerender, exp_weight, blink_every, fps = 20, progress=gr.Progress(track_tqdm=True)):
370
+ response = history[-1][1]
371
+ question = history[-1][0]
372
+ # driven_audio, video_vtt = 'answer.wav', 'answer.vtt'
373
+ if character == '女性角色':
374
+ # 女性角色
375
+ source_image, pic_path = r'./inputs/girl.png', r"./inputs/girl.png"
376
+ crop_pic_path = "./inputs/first_frame_dir_girl/girl.png"
377
+ first_coeff_path = "./inputs/first_frame_dir_girl/girl.mat"
378
+ crop_info = ((403, 403), (19, 30, 502, 513), [40.05956541381802, 40.17324339233366, 443.7892505041507, 443.9029284826663])
379
+ default_voice = 'zh-CN-XiaoxiaoNeural'
380
+ elif character == '男性角色':
381
+ # 男性角色
382
+ source_image = r'./inputs/boy.png'
383
+ pic_path = "./inputs/boy.png"
384
+ crop_pic_path = "./inputs/first_frame_dir_boy/boy.png"
385
+ first_coeff_path = "./inputs/first_frame_dir_boy/boy.mat"
386
+ crop_info = ((876, 747), (0, 0, 886, 838), [10.382158280494476, 0, 886, 747.7078990925525])
387
+ default_voice = 'zh-CN-YunyangNeural'
388
+ elif character == '自定义角色':
389
+ if source_image is None:
390
+ gr.Error("自定义角色需要上传正确的图片")
391
+ return None
392
+ default_voice = 'zh-CN-XiaoxiaoNeural'
393
+ voice = default_voice if not voice else voice
394
+ # tts.predict(response, voice, rate, volume, pitch, driven_audio, video_vtt)
395
+ driven_audio, driven_vtt = TTS_response(response, voice, rate, volume, pitch,
396
+ am, voc, lang, male,
397
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, question_audio, question, use_mic_voice,
398
+ tts_method)
399
+ if driven_audio is None:
400
+ gr.Warning("音频没有正常生成,请检查TTS是否正确")
401
+ return None
402
+ if talker_method == 'SadTalker':
403
+ pose_style = random.randint(0, 45)
404
+ video = talker.test(pic_path,
405
+ crop_pic_path,
406
+ first_coeff_path,
407
+ crop_info,
408
+ source_image,
409
+ driven_audio,
410
+ preprocess_type,
411
+ is_still_mode,
412
+ enhancer,
413
+ batch_size,
414
+ size_of_image,
415
+ pose_style,
416
+ facerender,
417
+ exp_weight,
418
+ use_ref_video,
419
+ ref_video,
420
+ ref_info,
421
+ use_idle_mode,
422
+ length_of_audio,
423
+ blink_every,
424
+ fps=fps)
425
+ elif talker_method == 'Wav2Lip':
426
+ video = talker.predict(crop_pic_path, driven_audio, batch_size, enhancer)
427
+ elif talker_method == 'NeRFTalk':
428
+ video = talker.predict(driven_audio)
429
+ else:
430
+ gr.Warning("不支持的方法:" + talker_method)
431
+ return None
432
+ if driven_vtt:
433
+ return video, driven_vtt
434
+ else:
435
+ return video
436
+
437
+
438
+ @calculate_time
439
+ def MuseTalker_response(source_video, bbox_shift, question_audio = None, text = '',
440
+ voice = 'zh-CN-XiaoxiaoNeural', rate = 0, volume = 100, pitch = 0,
441
+ am = 'fastspeech2', voc = 'pwgan', lang = 'zh', male = False,
442
+ inp_ref = None, prompt_text = "", prompt_language = "", text_language = "", how_to_cut = "", use_mic_voice = False,
443
+ tts_method = 'Edge-TTS', batch_size = 4,
444
+ progress=gr.Progress(track_tqdm=True)):
445
+ default_voice = None
446
+ voice = default_voice if not voice else voice
447
+
448
+ if not voice:
449
+ gr.Warning('请选择声音')
450
+
451
+ driven_audio, driven_vtt, _ = LLM_response(question_audio, text,
452
+ voice, rate, volume, pitch,
453
+ am, voc, lang, male,
454
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
455
+ tts_method)
456
+ print(driven_audio, driven_vtt)
457
+ video = musetalker.inference_noprepare(driven_audio,
458
+ source_video,
459
+ bbox_shift,
460
+ batch_size,
461
+ fps = 25)
462
+
463
+ if driven_vtt:
464
+ return (video, driven_vtt)
465
+ else:
466
+ return video
467
+ GPT_SoVITS_ckpt = "GPT_SoVITS/pretrained_models"
468
+ def load_vits_model(gpt_path, sovits_path, progress=gr.Progress(track_tqdm=True)):
469
+ global vits
470
+ print("模型加载中...", gpt_path, sovits_path)
471
+ all_gpt_path, all_sovits_path = os.path.join(GPT_SoVITS_ckpt, gpt_path), os.path.join(GPT_SoVITS_ckpt, sovits_path)
472
+ vits.load_model(all_gpt_path, all_sovits_path)
473
+ gr.Info("模型加载成功")
474
+ return gpt_path, sovits_path
475
+
476
+ def list_models(dir, endwith = ".pth"):
477
+ list_folder = os.listdir(dir)
478
+ list_folder = [i for i in list_folder if i.endswith(endwith)]
479
+ return list_folder
480
+
481
+ def character_change(character):
482
+ if character == '女性角色':
483
+ # 女性角色
484
+ source_image = r'./inputs/girl.png'
485
+ elif character == '男性角色':
486
+ # 男性角色
487
+ source_image = r'./inputs/boy.png'
488
+ elif character == '自定义角色':
489
+ # gr.Warnings("自定义角色暂未更新,请继续关注后续,可通过自由上传图片模式进行自定义角色")
490
+ source_image = None
491
+ return source_image
492
+
493
+ def webui_setting(talk = False):
494
+ if not talk:
495
+ with gr.Tabs():
496
+ with gr.TabItem('数字人形象设定'):
497
+ source_image = gr.Image(label="Source image", type="filepath")
498
+ else:
499
+ source_image = None
500
+ with gr.Tabs("TTS Method"):
501
+ with gr.Accordion("TTS Method语音方法调节 ", open=True):
502
+ with gr.Tab("Edge-TTS"):
503
+ voice = gr.Dropdown(edgetts.SUPPORTED_VOICE,
504
+ value='zh-CN-XiaoxiaoNeural',
505
+ label="Voice 声音选择")
506
+ rate = gr.Slider(minimum=-100,
507
+ maximum=100,
508
+ value=0,
509
+ step=1.0,
510
+ label='Rate 速率')
511
+ volume = gr.Slider(minimum=0,
512
+ maximum=100,
513
+ value=100,
514
+ step=1,
515
+ label='Volume 音量')
516
+ pitch = gr.Slider(minimum=-100,
517
+ maximum=100,
518
+ value=0,
519
+ step=1,
520
+ label='Pitch 音调')
521
+ with gr.Tab("PaddleTTS"):
522
+ am = gr.Dropdown(["FastSpeech2"], label="声学模型选择", value = 'FastSpeech2')
523
+ voc = gr.Dropdown(["PWGan", "HifiGan"], label="声码器选择", value = 'PWGan')
524
+ lang = gr.Dropdown(["zh", "en", "mix", "canton"], label="语言选择", value = 'zh')
525
+ male = gr.Checkbox(label="男声(Male)", value=False)
526
+ with gr.Tab('GPT-SoVITS'):
527
+ with gr.Row():
528
+ gpt_path = gr.FileExplorer(root = GPT_SoVITS_ckpt, glob = "*.ckpt", value = "s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", file_count='single', label="GPT模型路径")
529
+ sovits_path = gr.FileExplorer(root = GPT_SoVITS_ckpt, glob = "*.pth", value = "s2G488k.pth", file_count='single', label="SoVITS模型路径")
530
+ # gpt_path = gr.Dropdown(choices=list_models(GPT_SoVITS_ckpt, 'ckpt'))
531
+ # sovits_path = gr.Dropdown(choices=list_models(GPT_SoVITS_ckpt, 'pth'))
532
+ # gpt_path = gr.Textbox(label="GPT模型路径",
533
+ # value="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
534
+ # sovits_path = gr.Textbox(label="SoVITS模型路径",
535
+ # value="GPT_SoVITS/pretrained_models/s2G488k.pth")
536
+ button = gr.Button("加载模型")
537
+ button.click(fn = load_vits_model,
538
+ inputs=[gpt_path, sovits_path],
539
+ outputs=[gpt_path, sovits_path])
540
+
541
+ with gr.Row():
542
+ inp_ref = gr.Audio(label="请上传3~10秒内参考音频,超过会报错!", sources=["microphone", "upload"], type="filepath")
543
+ use_mic_voice = gr.Checkbox(label="使用语音问答的麦克风")
544
+ prompt_text = gr.Textbox(label="参考音频的文本", value="")
545
+ prompt_language = gr.Dropdown(
546
+ label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
547
+ )
548
+ asr_button = gr.Button("语音识别 - 克隆参考音频")
549
+ asr_button.click(fn=Asr,inputs=[inp_ref],outputs=[prompt_text])
550
+ with gr.Row():
551
+ text_language = gr.Dropdown(
552
+ label="需要合成的语种", choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], value="中文"
553
+ )
554
+
555
+ how_to_cut = gr.Dropdown(
556
+ label="怎么切",
557
+ choices=["不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切" ],
558
+ value="凑四句一切",
559
+ interactive=True,
560
+ )
561
+
562
+ with gr.Column(variant='panel'):
563
+ batch_size = gr.Slider(minimum=1,
564
+ maximum=10,
565
+ value=2,
566
+ step=1,
567
+ label='Talker Batch size')
568
+
569
+ character = gr.Radio(['女性角色',
570
+ '男性角色',
571
+ '自定义角色'],
572
+ label="角色选择", value='自定义角色')
573
+ character.change(fn = character_change, inputs=[character], outputs = [source_image])
574
+ tts_method = gr.Radio(['Edge-TTS', 'PaddleTTS', 'GPT-SoVITS克隆声音', 'Coming Soon!!!'], label="Text To Speech Method",
575
+ value = 'Edge-TTS')
576
+ tts_method.change(fn = tts_model_change, inputs=[tts_method], outputs = [tts_method])
577
+ asr_method = gr.Radio(choices = ['Whisper-tiny', 'Whisper-base', 'FunASR', 'Coming Soon!!!'], value='Whisper-base', label = '语音识别模型选择')
578
+ asr_method.change(fn = asr_model_change, inputs=[asr_method], outputs = [asr_method])
579
+ talker_method = gr.Radio(choices = ['SadTalker', 'Wav2Lip', 'NeRFTalk', 'Coming Soon!!!'],
580
+ value = 'SadTalker', label = '数字人模型选择')
581
+ talker_method.change(fn = talker_model_change, inputs=[talker_method], outputs = [talker_method])
582
+ llm_method = gr.Dropdown(choices = ['Qwen', 'Qwen2', 'Linly', 'Gemini', 'ChatGLM', 'ChatGPT', 'GPT4Free', '直接回复 Direct Reply', 'Coming Soon!!!'], value = '直接回复 Direct Reply', label = 'LLM 模型选择')
583
+ llm_method.change(fn = llm_model_change, inputs=[llm_method], outputs = [llm_method])
584
+ return (source_image, voice, rate, volume, pitch,
585
+ am, voc, lang, male,
586
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
587
+ tts_method, batch_size, character, talker_method, asr_method, llm_method)
588
+
589
+
590
+ def exmaple_setting(asr, text, character, talk , tts, voice, llm):
591
+ # 默认text的Example
592
+ examples = [
593
+ ['Whisper-base', '应对压力最有效的方法是什么?', '女性角色', 'SadTalker', 'Edge-TTS', 'zh-CN-XiaoxiaoNeural', '直接回复 Direct Reply'],
594
+ ['Whisper-tiny', '应对压力最有效的方法是什么?', '女性角色', 'SadTalker', 'PaddleTTS', 'None', '直接回复 Direct Reply'],
595
+ ['Whisper-base', '应对压力最有效的方法是什么?', '女性角色', 'SadTalker', 'Edge-TTS', 'zh-CN-XiaoxiaoNeural', 'Qwen'],
596
+ ['FunASR', '如何进行时间管理?','男性角色', 'SadTalker', 'Edge-TTS', 'zh-CN-YunyangNeural', 'Qwen'],
597
+ ['Whisper-tiny', '为什么有些人选择使用纸质地图或寻求方向,而不是依赖GPS设备或智能手机应用程序?','女性角色', 'Wav2Lip', 'PaddleTTS', 'None', 'Qwen'],
598
+ ]
599
+
600
+ with gr.Row(variant='panel'):
601
+ with gr.Column(variant='panel'):
602
+ gr.Markdown("## Test Examples")
603
+ gr.Examples(
604
+ examples = examples,
605
+ inputs = [asr, text, character, talk , tts, voice, llm],
606
+ )
607
+ def app():
608
+ with gr.Blocks(analytics_enabled=False, title = 'Linly-Talker') as inference:
609
+ gr.HTML(get_title("Linly 智能对话系统 (Linly-Talker) 文本/语音对话"))
610
+ with gr.Row(equal_height=False):
611
+ with gr.Column(variant='panel'):
612
+ (source_image, voice, rate, volume, pitch,
613
+ am, voc, lang, male,
614
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
615
+ tts_method, batch_size, character, talker_method, asr_method, llm_method)= webui_setting()
616
+
617
+
618
+ with gr.Column(variant='panel'):
619
+ with gr.Tabs():
620
+ with gr.TabItem('对话'):
621
+ with gr.Group():
622
+ question_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label = '语音对话')
623
+ input_text = gr.Textbox(label="输入文字/问题", lines=3)
624
+ asr_text = gr.Button('语音识别(语音对话后点击)')
625
+ asr_text.click(fn=Asr,inputs=[question_audio],outputs=[input_text])
626
+ # with gr.TabItem('SadTalker数字人参数设置'):
627
+ # with gr.Accordion("Advanced Settings",
628
+ # open=False):
629
+ # gr.Markdown("SadTalker: need help? please visit our [[best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md)] for more detials")
630
+ # with gr.Column(variant='panel'):
631
+ # # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
632
+ # # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
633
+ # with gr.Row():
634
+ # pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0) #
635
+ # exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1) #
636
+ # blink_every = gr.Checkbox(label="use eye blink", value=True)
637
+
638
+ # with gr.Row():
639
+ # size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model? 256 is faster") #
640
+ # preprocess_type = gr.Radio(['crop', 'resize','full'], value='full', label='preprocess', info="How to handle input image?")
641
+
642
+ # with gr.Row():
643
+ # is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
644
+ # facerender = gr.Radio(['facevid2vid'], value='facevid2vid', label='facerender', info="which face render?")
645
+
646
+ # with gr.Row():
647
+ # # batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=1)
648
+ # fps = gr.Slider(label='fps in generation', step=1, maximum=30, value =20)
649
+ # enhancer = gr.Checkbox(label="GFPGAN as Face enhancer(slow)")
650
+ with gr.Tabs():
651
+ with gr.TabItem('数字人问答'):
652
+ gen_video = gr.Video(label="生成视频", format="mp4", autoplay=False)
653
+ video_button = gr.Button("🎬 生成数字人视频", variant='primary')
654
+ video_button.click(fn=Talker_response,inputs=[question_audio, talker_method, input_text, voice, rate, volume, pitch,
655
+ am, voc, lang, male,
656
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
657
+ tts_method, batch_size, character],outputs=[gen_video])
658
+ exmaple_setting(asr_method, input_text, character, talker_method, tts_method, voice, llm_method)
659
+ return inference
660
+
661
+ def app_multi():
662
+ with gr.Blocks(analytics_enabled=False, title = 'Linly-Talker') as inference:
663
+ gr.HTML(get_title("Linly 智能对话系统 (Linly-Talker) 多轮GPT对话"))
664
+ with gr.Row():
665
+ with gr.Column():
666
+ (source_image, voice, rate, volume, pitch,
667
+ am, voc, lang, male,
668
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
669
+ tts_method, batch_size, character, talker_method, asr_method, llm_method)= webui_setting()
670
+ video = gr.Video(label = '数字人问答', scale = 0.5)
671
+ video_button = gr.Button("🎬 生成数字人视频(对话后)", variant = 'primary')
672
+
673
+ with gr.Column():
674
+ with gr.Tabs(elem_id="sadtalker_checkbox"):
675
+ with gr.TabItem('SadTalker数字人参数设置'):
676
+ with gr.Accordion("Advanced Settings",
677
+ open=False):
678
+ gr.Markdown("SadTalker: need help? please visit our [[best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md)] for more detials")
679
+ with gr.Column(variant='panel'):
680
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
681
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
682
+ with gr.Row():
683
+ pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0) #
684
+ exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1) #
685
+ blink_every = gr.Checkbox(label="use eye blink", value=True)
686
+
687
+ with gr.Row():
688
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model? 256 is faster") #
689
+ preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
690
+
691
+ with gr.Row():
692
+ is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
693
+ facerender = gr.Radio(['facevid2vid'], value='facevid2vid', label='facerender', info="which face render?")
694
+
695
+ with gr.Row():
696
+ fps = gr.Slider(label='fps in generation', step=1, maximum=30, value =20)
697
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer(slow)")
698
+ with gr.Row():
699
+ with gr.Column(scale=3):
700
+ system_input = gr.Textbox(value=default_system, lines=1, label='System (设定角色)')
701
+ with gr.Column(scale=1):
702
+ modify_system = gr.Button("🛠️ 设置system并清除历史对话", scale=2)
703
+ system_state = gr.Textbox(value=default_system, visible=False)
704
+
705
+ chatbot = gr.Chatbot(height=400, show_copy_button=True)
706
+ with gr.Group():
707
+ question_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label='语音对话', autoplay=False)
708
+ asr_text = gr.Button('🎤 语音识别(语音对话后点击)')
709
+
710
+ # 创建一个文本框组件,用于输入 prompt。
711
+ msg = gr.Textbox(label="Prompt/问题")
712
+ asr_text.click(fn=Asr,inputs=[question_audio],outputs=[msg])
713
+
714
+ with gr.Row():
715
+ clear_history = gr.Button("🧹 清除历史对话")
716
+ sumbit = gr.Button("🚀 发送", variant = 'primary')
717
+
718
+ # 设置按钮的点击事件。当点击时,调用上面定义的 函数,并传入用户的消息和聊天历史记录,然后更新文本框和聊天机器人组件。
719
+ sumbit.click(chat_response, inputs=[system_input, msg, chatbot],
720
+ outputs=[msg, chatbot])
721
+
722
+ # 点击后清空后端存储的聊天记录
723
+ clear_history.click(fn = clear_session, outputs = [msg, chatbot])
724
+
725
+ # 设置system并清除历史对话
726
+ modify_system.click(fn=modify_system_session,
727
+ inputs=[system_input],
728
+ outputs=[system_state, system_input, chatbot])
729
+
730
+ video_button.click(fn = human_response, inputs = [source_image, chatbot, question_audio, talker_method, voice, rate, volume, pitch,
731
+ am, voc, lang, male,
732
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
733
+ tts_method, character,preprocess_type,
734
+ is_still_mode, enhancer, batch_size, size_of_image,
735
+ pose_style, facerender, exp_weight, blink_every, fps], outputs = [video])
736
+
737
+ exmaple_setting(asr_method, msg, character, talker_method, tts_method, voice, llm_method)
738
+ return inference
739
+
740
+ def app_img():
741
+ with gr.Blocks(analytics_enabled=False, title = 'Linly-Talker') as inference:
742
+ gr.HTML(get_title("Linly 智能对话系统 (Linly-Talker) 个性化角色互动"))
743
+ with gr.Row(equal_height=False):
744
+ with gr.Column(variant='panel'):
745
+ (source_image, voice, rate, volume, pitch,
746
+ am, voc, lang, male,
747
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
748
+ tts_method, batch_size, character, talker_method, asr_method, llm_method)= webui_setting()
749
+
750
+ # driven_audio = 'answer.wav'
751
+ with gr.Column(variant='panel'):
752
+ with gr.Tabs():
753
+ with gr.TabItem('对话'):
754
+ with gr.Group():
755
+ question_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label = '语音对话')
756
+ input_text = gr.Textbox(label="输入文字/问题", lines=3)
757
+ asr_text = gr.Button('语音识别(语音对话后点击)')
758
+ asr_text.click(fn=Asr,inputs=[question_audio],outputs=[input_text])
759
+ with gr.Tabs(elem_id="text_examples"):
760
+ gr.Markdown("## Text Examples")
761
+ examples = [
762
+ ['应对压力最有效的方法是什么?'],
763
+ ['如何进行时间管理?'],
764
+ ['为什么有些人选择使用纸质地图或寻求方向,而不是依赖GPS设备或智能手机应用程序?'],
765
+ ]
766
+ gr.Examples(
767
+ examples = examples,
768
+ inputs = [input_text],
769
+ )
770
+ with gr.Tabs(elem_id="sadtalker_checkbox"):
771
+ with gr.TabItem('SadTalker数字人参数设置'):
772
+ with gr.Accordion("Advanced Settings",
773
+ open=False):
774
+ gr.Markdown("SadTalker: need help? please visit our [[best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md)] for more detials")
775
+ with gr.Column(variant='panel'):
776
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
777
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
778
+ with gr.Row():
779
+ pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0) #
780
+ exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1) #
781
+ blink_every = gr.Checkbox(label="use eye blink", value=True)
782
+
783
+ with gr.Row():
784
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model? 256 is faster") #
785
+ preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
786
+
787
+ with gr.Row():
788
+ is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
789
+ facerender = gr.Radio(['facevid2vid'], value='facevid2vid', label='facerender', info="which face render?")
790
+
791
+ with gr.Row():
792
+ fps = gr.Slider(label='fps in generation', step=1, maximum=30, value =20)
793
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer(slow)")
794
+
795
+ with gr.Tabs(elem_id="sadtalker_genearted"):
796
+ gen_video = gr.Video(label="数字人视频", format="mp4")
797
+
798
+ submit = gr.Button('🎬 生成数字人视频', elem_id="sadtalker_generate", variant='primary')
799
+ submit.click(
800
+ fn=Talker_response_img,
801
+ inputs=[question_audio,
802
+ talker_method,
803
+ input_text,
804
+ voice, rate, volume, pitch,
805
+ am, voc, lang, male,
806
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
807
+ tts_method,
808
+ source_image,
809
+ preprocess_type,
810
+ is_still_mode,
811
+ enhancer,
812
+ batch_size,
813
+ size_of_image,
814
+ pose_style,
815
+ facerender,
816
+ exp_weight,
817
+ blink_every,
818
+ fps],
819
+ outputs=[gen_video]
820
+ )
821
+
822
+ with gr.Row():
823
+ examples = [
824
+ [
825
+ 'examples/source_image/full_body_2.png', 'SadTalker',
826
+ 'crop',
827
+ False,
828
+ False
829
+ ],
830
+ [
831
+ 'examples/source_image/full_body_1.png', 'SadTalker',
832
+ 'full',
833
+ True,
834
+ False
835
+ ],
836
+ [
837
+ 'examples/source_image/full4.jpeg', 'SadTalker',
838
+ 'crop',
839
+ False,
840
+ True
841
+ ],
842
+ ]
843
+ gr.Examples(examples=examples,
844
+ inputs=[
845
+ source_image, talker_method,
846
+ preprocess_type,
847
+ is_still_mode,
848
+ enhancer],
849
+ outputs=[gen_video],
850
+ # cache_examples=True,
851
+ )
852
+ return inference
853
+
854
+ def app_vits():
855
+ with gr.Blocks(analytics_enabled=False, title = 'Linly-Talker') as inference:
856
+ gr.HTML(get_title("Linly 智能对话系统 (Linly-Talker) 语音克隆"))
857
+ with gr.Row(equal_height=False):
858
+ with gr.Column(variant='panel'):
859
+ (source_image, voice, rate, volume, pitch,
860
+ am, voc, lang, male,
861
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
862
+ tts_method, batch_size, character, talker_method, asr_method, llm_method)= webui_setting()
863
+ with gr.Column(variant='panel'):
864
+ with gr.Tabs():
865
+ with gr.TabItem('对话'):
866
+ with gr.Group():
867
+ question_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label = '语音对话')
868
+ input_text = gr.Textbox(label="输入文字/问题", lines=3)
869
+ asr_text = gr.Button('语音识别(语音对话后点击)')
870
+ asr_text.click(fn=Asr,inputs=[question_audio],outputs=[input_text])
871
+ with gr.Tabs():
872
+ with gr.TabItem('数字人问答'):
873
+ gen_video = gr.Video(label="数字人视频", format="mp4", autoplay=False)
874
+ video_button = gr.Button("🎬 生成数字人视频", variant='primary')
875
+ video_button.click(fn=Talker_response,inputs=[question_audio, talker_method, input_text, voice, rate, volume, pitch, am, voc, lang, male,
876
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
877
+ tts_method, batch_size, character],outputs=[gen_video])
878
+ exmaple_setting(asr_method, input_text, character, talker_method, tts_method, voice, llm_method)
879
+ return inference
880
+
881
+ def app_talk():
882
+ with gr.Blocks(analytics_enabled=False, title = 'Linly-Talker') as inference:
883
+ gr.HTML(get_title("Linly 智能对话系统 (Linly-Talker) 数字人播报"))
884
+ with gr.Row(equal_height=False):
885
+ with gr.Column(variant='panel'):
886
+ with gr.Tabs():
887
+ with gr.Tab("图片人物"):
888
+ source_image = gr.Image(label='Source image', type = 'filepath')
889
+
890
+ with gr.Tab("视频人物"):
891
+ source_video = gr.Video(label="Source video")
892
+
893
+ (_, voice, rate, volume, pitch,
894
+ am, voc, lang, male,
895
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
896
+ tts_method, batch_size, character, talker_method, asr_method, llm_method)= webui_setting()
897
+
898
+ with gr.Column(variant='panel'):
899
+ with gr.Tabs():
900
+ with gr.TabItem('对话'):
901
+ with gr.Group():
902
+ question_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label = '语音对话')
903
+ input_text = gr.Textbox(label="输入文字/问题", lines=3)
904
+ asr_text = gr.Button('语音识别(语音对话后点击)')
905
+ asr_text.click(fn=Asr,inputs=[question_audio],outputs=[input_text])
906
+ with gr.Tabs():
907
+ with gr.TabItem('SadTalker数字人参数设置'):
908
+ with gr.Accordion("Advanced Settings",
909
+ open=False):
910
+ gr.Markdown("SadTalker: need help? please visit our [[best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md)] for more detials")
911
+ with gr.Column(variant='panel'):
912
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
913
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
914
+ with gr.Row():
915
+ pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0) #
916
+ exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1) #
917
+ blink_every = gr.Checkbox(label="use eye blink", value=True)
918
+
919
+ with gr.Row():
920
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model? 256 is faster") #
921
+ preprocess_type = gr.Radio(['crop', 'resize','full'], value='full', label='preprocess', info="How to handle input image?")
922
+
923
+ with gr.Row():
924
+ is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
925
+ facerender = gr.Radio(['facevid2vid'], value='facevid2vid', label='facerender', info="which face render?")
926
+
927
+ with gr.Row():
928
+ # batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=1)
929
+ fps = gr.Slider(label='fps in generation', step=1, maximum=30, value =20)
930
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer(slow)")
931
+
932
+ with gr.Tabs():
933
+ gen_video = gr.Video(label="数字人视频", format="mp4")
934
+
935
+ video_button = gr.Button('🎬 生成数字人视频', elem_id="sadtalker_generate", variant='primary')
936
+
937
+ video_button.click(fn=Talker_Say,inputs=[preprocess_type, is_still_mode, enhancer, batch_size, size_of_image,
938
+ pose_style, facerender, exp_weight, blink_every, fps,
939
+ source_image, source_video, question_audio, talker_method, input_text, voice, rate, volume, pitch, am, voc, lang, male,
940
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
941
+ tts_method, character],outputs=[gen_video])
942
+
943
+ with gr.Row():
944
+ with gr.Column(variant='panel'):
945
+ gr.Markdown("## Test Examples")
946
+ gr.Examples(
947
+ examples = [
948
+ [
949
+ 'examples/source_image/full_body_2.png',
950
+ '应对压力最有效的方法是什么?',
951
+ ],
952
+ [
953
+ 'examples/source_image/full_body_1.png',
954
+ '如何进行时间管理?',
955
+ ],
956
+ [
957
+ 'examples/source_image/full3.png',
958
+ '为什么有些人选择使用纸质地图或寻求方向,而不是依赖GPS设备或智能手机应用程序?',
959
+ ],
960
+ ],
961
+ fn = Talker_Say,
962
+ inputs = [source_image, input_text],
963
+ )
964
+ return inference
965
+
966
+ def load_musetalk_model():
967
+ gr.Warning("若显存不足,可能会导致模型加载失败,可以尝试使用其他摸型或者换其他设备尝试。")
968
+ gr.Info("MuseTalk模型导入中...")
969
+ musetalker.init_model()
970
+ gr.Info("MuseTalk模型导入成功")
971
+ return "MuseTalk模型导入成功"
972
+ def musetalk_prepare_material(source_video, bbox_shift):
973
+ if musetalker.load is False:
974
+ gr.Warning("请先加载MuseTalk模型后重新上传文件")
975
+ return source_video, None
976
+ return musetalker.prepare_material(source_video, bbox_shift)
977
+ def app_muse():
978
+ with gr.Blocks(analytics_enabled=False, title = 'Linly-Talker') as inference:
979
+ gr.HTML(get_title("Linly 智能对话系统 (Linly-Talker) MuseTalker数字人实时对话"))
980
+ with gr.Row(equal_height=False):
981
+ with gr.Column(variant='panel'):
982
+ with gr.TabItem('MuseV Video'):
983
+ gr.Markdown("MuseV: need help? please visit MuseVDemo to generate Video https://huggingface.co/spaces/AnchorFake/MuseVDemo")
984
+ with gr.Row():
985
+ source_video = gr.Video(label="Reference Video",sources=['upload'])
986
+ gr.Markdown("BBox_shift 推荐值下限,在生成初始结果后生成相应的 bbox 范围。如果结果不理想,可以根据该参考值进行调整。\n一般来说,在我们的实验观察中,我们发现正值(向下半部分移动)通常会增加嘴巴的张开度,而负值(向上半部分移动)通常会减少嘴巴的张开度。然而,需要注意的是,这并不是绝对的规则,用户可能需要根据他们的具体需求和期望效果来调整该参数。")
987
+ with gr.Row():
988
+ bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
989
+ bbox_shift_scale = gr.Textbox(label="bbox_shift_scale",
990
+ value="",interactive=False)
991
+ load_musetalk = gr.Button("加载MuseTalk模型(传入视频前先加载)", variant='primary')
992
+ load_musetalk.click(fn=load_musetalk_model, outputs=bbox_shift_scale)
993
+
994
+ # (_, voice, rate, volume, pitch,
995
+ # am, voc, lang, male,
996
+ # inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
997
+ # tts_method, batch_size, character, talker_method, asr_method, llm_method)= webui_setting()
998
+ with gr.Tabs("TTS Method"):
999
+ with gr.Accordion("TTS Method语音方法调节 ", open=True):
1000
+ with gr.Tab("Edge-TTS"):
1001
+ voice = gr.Dropdown(edgetts.SUPPORTED_VOICE,
1002
+ value='zh-CN-XiaoxiaoNeural',
1003
+ label="Voice 声音选择")
1004
+ rate = gr.Slider(minimum=-100,
1005
+ maximum=100,
1006
+ value=0,
1007
+ step=1.0,
1008
+ label='Rate 速率')
1009
+ volume = gr.Slider(minimum=0,
1010
+ maximum=100,
1011
+ value=100,
1012
+ step=1,
1013
+ label='Volume 音量')
1014
+ pitch = gr.Slider(minimum=-100,
1015
+ maximum=100,
1016
+ value=0,
1017
+ step=1,
1018
+ label='Pitch 音调')
1019
+ with gr.Tab("PaddleTTS"):
1020
+ am = gr.Dropdown(["FastSpeech2"], label="声学模型选择", value = 'FastSpeech2')
1021
+ voc = gr.Dropdown(["PWGan", "HifiGan"], label="声码器选择", value = 'PWGan')
1022
+ lang = gr.Dropdown(["zh", "en", "mix", "canton"], label="语言选择", value = 'zh')
1023
+ male = gr.Checkbox(label="男声(Male)", value=False)
1024
+ with gr.Tab('GPT-SoVITS'):
1025
+ with gr.Row():
1026
+ gpt_path = gr.FileExplorer(root = GPT_SoVITS_ckpt, glob = "*.ckpt", value = "s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", file_count='single', label="GPT模型路径")
1027
+ sovits_path = gr.FileExplorer(root = GPT_SoVITS_ckpt, glob = "*.pth", value = "s2G488k.pth", file_count='single', label="SoVITS模型路径")
1028
+ # gpt_path = gr.Dropdown(choices=list_models(GPT_SoVITS_ckpt, 'ckpt'))
1029
+ # sovits_path = gr.Dropdown(choices=list_models(GPT_SoVITS_ckpt, 'pth'))
1030
+ # gpt_path = gr.Textbox(label="GPT模型路径",
1031
+ # value="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
1032
+ # sovits_path = gr.Textbox(label="SoVITS模型路径",
1033
+ # value="GPT_SoVITS/pretrained_models/s2G488k.pth")
1034
+ button = gr.Button("加载模型")
1035
+ button.click(fn = load_vits_model,
1036
+ inputs=[gpt_path, sovits_path],
1037
+ outputs=[gpt_path, sovits_path])
1038
+
1039
+ with gr.Row():
1040
+ inp_ref = gr.Audio(label="请上传3~10秒内参考音频,超过会报错!", sources=["microphone", "upload"], type="filepath")
1041
+ use_mic_voice = gr.Checkbox(label="使用语音问答的麦克风")
1042
+ prompt_text = gr.Textbox(label="参考音频的文本", value="")
1043
+ prompt_language = gr.Dropdown(
1044
+ label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
1045
+ )
1046
+ asr_button = gr.Button("语音识别 - 克隆参考音频")
1047
+ asr_button.click(fn=Asr,inputs=[inp_ref],outputs=[prompt_text])
1048
+ with gr.Row():
1049
+ text_language = gr.Dropdown(
1050
+ label="需要合成的语种", choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], value="中文"
1051
+ )
1052
+
1053
+ how_to_cut = gr.Dropdown(
1054
+ label="怎么切",
1055
+ choices=["不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切" ],
1056
+ value="凑四句一切",
1057
+ interactive=True,
1058
+ )
1059
+
1060
+ with gr.Column(variant='panel'):
1061
+ batch_size = gr.Slider(minimum=1,
1062
+ maximum=10,
1063
+ value=2,
1064
+ step=1,
1065
+ label='Talker Batch size')
1066
+
1067
+ tts_method = gr.Radio(['Edge-TTS', 'PaddleTTS', 'GPT-SoVITS克隆声音', 'Coming Soon!!!'], label="Text To Speech Method",
1068
+ value = 'Edge-TTS')
1069
+ tts_method.change(fn = tts_model_change, inputs=[tts_method], outputs = [tts_method])
1070
+ asr_method = gr.Radio(choices = ['Whisper-tiny', 'Whisper-base', 'FunASR', 'Coming Soon!!!'], value='Whisper-base', label = '语音识别模型选择')
1071
+ asr_method.change(fn = asr_model_change, inputs=[asr_method], outputs = [asr_method])
1072
+ llm_method = gr.Dropdown(choices = ['Qwen', 'Qwen2', 'Linly', 'Gemini', 'ChatGLM', 'ChatGPT', 'GPT4Free', '直接回复 Direct Reply', 'Coming Soon!!!'], value = '直接回复 Direct Reply', label = 'LLM 模型选择')
1073
+ llm_method.change(fn = llm_model_change, inputs=[llm_method], outputs = [llm_method])
1074
+
1075
+ source_video.change(fn=musetalk_prepare_material, inputs=[source_video, bbox_shift], outputs=[source_video, bbox_shift_scale])
1076
+
1077
+ with gr.Column(variant='panel'):
1078
+ with gr.Tabs():
1079
+ with gr.TabItem('对话'):
1080
+ with gr.Group():
1081
+ question_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label = '语音对话')
1082
+ input_text = gr.Textbox(label="输入文字/问题", lines=3)
1083
+ asr_text = gr.Button('语音识别(语音对话后点击)')
1084
+ asr_text.click(fn=Asr,inputs=[question_audio],outputs=[input_text])
1085
+
1086
+ with gr.TabItem("MuseTalk Video"):
1087
+ gen_video = gr.Video(label="数字人视频", format="mp4")
1088
+ submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
1089
+ examples = [os.path.join('Musetalk/data/video', video) for video in os.listdir("Musetalk/data/video")]
1090
+ # ['Musetalk/data/video/yongen_musev.mp4', 'Musetalk/data/video/musk_musev.mp4', 'Musetalk/data/video/monalisa_musev.mp4', 'Musetalk/data/video/sun_musev.mp4', 'Musetalk/data/video/seaside4_musev.mp4', 'Musetalk/data/video/sit_musev.mp4', 'Musetalk/data/video/man_musev.mp4']
1091
+
1092
+ gr.Markdown("## MuseV Video Examples")
1093
+ gr.Examples(
1094
+ examples=[
1095
+ ['Musetalk/data/video/yongen_musev.mp4', 5],
1096
+ ['Musetalk/data/video/musk_musev.mp4', 5],
1097
+ ['Musetalk/data/video/monalisa_musev.mp4', 5],
1098
+ ['Musetalk/data/video/sun_musev.mp4', 5],
1099
+ ['Musetalk/data/video/seaside4_musev.mp4', 5],
1100
+ ['Musetalk/data/video/sit_musev.mp4', 5],
1101
+ ['Musetalk/data/video/man_musev.mp4', 5]
1102
+ ],
1103
+ inputs =[source_video, bbox_shift],
1104
+ )
1105
+
1106
+ submit.click(
1107
+ fn=MuseTalker_response,
1108
+ inputs=[source_video, bbox_shift, question_audio, input_text, voice, rate, volume, pitch, am, voc, lang, male,
1109
+ inp_ref, prompt_text, prompt_language, text_language, how_to_cut, use_mic_voice,
1110
+ tts_method, batch_size],
1111
+ outputs=[gen_video]
1112
+ )
1113
+ return inference
1114
+ def asr_model_change(model_name, progress=gr.Progress(track_tqdm=True)):
1115
+ global asr
1116
+
1117
+ # 清理显存,在加载新的模型之前释放不必要的显存
1118
+ clear_memory()
1119
+
1120
+ if model_name == "Whisper-tiny":
1121
+ try:
1122
+ if os.path.exists('Whisper/tiny.pt'):
1123
+ asr = WhisperASR('Whisper/tiny.pt')
1124
+ else:
1125
+ asr = WhisperASR('tiny')
1126
+ gr.Info("Whisper-tiny模型导入成功")
1127
+ except Exception as e:
1128
+ gr.Warning(f"Whisper-tiny模型下载失败 {e}")
1129
+ elif model_name == "Whisper-base":
1130
+ try:
1131
+ if os.path.exists('Whisper/base.pt'):
1132
+ asr = WhisperASR('Whisper/base.pt')
1133
+ else:
1134
+ asr = WhisperASR('base')
1135
+ gr.Info("Whisper-base模型导入成功")
1136
+ except Exception as e:
1137
+ gr.Warning(f"Whisper-base模型下载失败 {e}")
1138
+ elif model_name == 'FunASR':
1139
+ try:
1140
+ from ASR import FunASR
1141
+ asr = FunASR()
1142
+ gr.Info("FunASR模型导入成功")
1143
+ except Exception as e:
1144
+ gr.Warning(f"FunASR模型下载失败 {e}")
1145
+ else:
1146
+ gr.Warning("未知ASR模型,可提issue和PR 或者 建议更新模型")
1147
+ return model_name
1148
+
1149
+ def llm_model_change(model_name, progress=gr.Progress(track_tqdm=True)):
1150
+ global llm
1151
+ gemini_apikey = ""
1152
+ openai_apikey = ""
1153
+ proxy_url = None
1154
+
1155
+ # 清理显存,在加载新的模型之前释放不必要的显存
1156
+ clear_memory()
1157
+
1158
+ if model_name == 'Linly':
1159
+ try:
1160
+ llm = llm_class.init_model('Linly', 'Linly-AI/Chinese-LLaMA-2-7B-hf', prefix_prompt=prefix_prompt)
1161
+ gr.Info("Linly模型导入成功")
1162
+ except Exception as e:
1163
+ gr.Warning(f"Linly模型下载失败 {e}")
1164
+ elif model_name == 'Qwen':
1165
+ try:
1166
+ llm = llm_class.init_model('Qwen', 'Qwen/Qwen-1_8B-Chat', prefix_prompt=prefix_prompt)
1167
+ gr.Info("Qwen模型导入成功")
1168
+ except Exception as e:
1169
+ gr.Warning(f"Qwen模型下载失败 {e}")
1170
+ elif model_name == 'Qwen2':
1171
+ try:
1172
+ llm = llm_class.init_model('Qwen2', 'Qwen/Qwen1.5-0.5B-Chat', prefix_prompt=prefix_prompt)
1173
+ gr.Info("Qwen2模型导入成功")
1174
+ except Exception as e:
1175
+ gr.Warning(f"Qwen2模型下载失败 {e}")
1176
+ elif model_name == 'Gemini':
1177
+ if gemini_apikey:
1178
+ llm = llm_class.init_model('Gemini', 'gemini-pro', gemini_apikey, proxy_url)
1179
+ gr.Info("Gemini模型导入成功")
1180
+ else:
1181
+ gr.Warning("请填写Gemini的api_key")
1182
+ elif model_name == 'ChatGLM':
1183
+ try:
1184
+ llm = llm_class.init_model('ChatGLM', 'THUDM/chatglm3-6b', prefix_prompt=prefix_prompt)
1185
+ gr.Info("ChatGLM模型导入成功")
1186
+ except Exception as e:
1187
+ gr.Warning(f"ChatGLM模型导入失败 {e}")
1188
+ elif model_name == 'ChatGPT':
1189
+ if openai_apikey:
1190
+ llm = llm_class.init_model('ChatGPT', api_key=openai_apikey, proxy_url=proxy_url, prefix_prompt=prefix_prompt)
1191
+ else:
1192
+ gr.Warning("请填写OpenAI的api_key")
1193
+ elif model_name == '直接回复 Direct Reply':
1194
+ llm = llm_class.init_model(model_name)
1195
+ gr.Info("直接回复,不实用LLM模型")
1196
+ elif model_name == 'GPT4Free':
1197
+ try:
1198
+ llm = llm_class.init_model('GPT4Free', prefix_prompt=prefix_prompt)
1199
+ gr.Info("GPT4Free模型导入成功, 请注意GPT4Free可能不稳定")
1200
+ except Exception as e:
1201
+ gr.Warning(f"GPT4Free模型下载失败 {e}")
1202
+ else:
1203
+ gr.Warning("未知LLM模型,可提issue和PR 或者 建议更新模型")
1204
+ return model_name
1205
+
1206
+ def talker_model_change(model_name, progress=gr.Progress(track_tqdm=True)):
1207
+ global talker
1208
+
1209
+ # 清理显存,在加载新的模型之前释放不必要的显存
1210
+ clear_memory()
1211
+
1212
+ if model_name not in ['SadTalker', 'Wav2Lip', 'NeRFTalk']:
1213
+ gr.Warning("其他模型还未集成,请等待")
1214
+ if model_name == 'SadTalker':
1215
+ try:
1216
+ from TFG import SadTalker
1217
+ talker = SadTalker(lazy_load=True)
1218
+ gr.Info("SadTalker模型导入成功")
1219
+ except Exception as e:
1220
+ gr.Warning("SadTalker模型加载失败", e)
1221
+ elif model_name == 'Wav2Lip':
1222
+ try:
1223
+ from TFG import Wav2Lip
1224
+ clear_memory()
1225
+ talker = Wav2Lip("checkpoints/wav2lip_gan.pth")
1226
+ gr.Info("Wav2Lip模型导入成功")
1227
+ except Exception as e:
1228
+ gr.Warning("Wav2Lip模型加载失败", e)
1229
+ elif model_name == 'NeRFTalk':
1230
+ try:
1231
+ from TFG import ERNeRF
1232
+ talker = ERNeRF()
1233
+ talker.init_model('checkpoints/Obama_ave.pth', 'checkpoints/Obama.json')
1234
+ gr.Info("NeRFTalk模型导入成功")
1235
+ gr.Warning("NeRFTalk模型是针对单个人进行训练的,内置了奥班马Obama的模型,上传图片无效")
1236
+ except Exception as e:
1237
+ gr.Warning("NeRFTalk模型加载失败", e)
1238
+ else:
1239
+ gr.Warning("未知TFG模型,可提issue和PR 或者 建议更新模型")
1240
+ return model_name
1241
+
1242
+ def tts_model_change(model_name, progress=gr.Progress(track_tqdm=True)):
1243
+ global tts
1244
+
1245
+ # 清理显存,在加载新的模型之前释放不必要的显存
1246
+ clear_memory()
1247
+
1248
+ if model_name == 'Edge-TTS':
1249
+ # tts = EdgeTTS()
1250
+ if edgetts.network:
1251
+ gr.Info("EdgeTTS模型导入成功")
1252
+ else:
1253
+ gr.Warning("EdgeTTS模型加载失败,请检查网络是否正常连接,否则无法使用")
1254
+ elif model_name == 'PaddleTTS':
1255
+ try:
1256
+ from TTS import PaddleTTS
1257
+ tts = PaddleTTS()
1258
+ gr.Info("PaddleTTS模型导入成功")
1259
+ except Exception as e:
1260
+ gr.Warning(f"PaddleTTS模型下载失败 {e}")
1261
+ elif model_name == 'GPT-SoVITS克隆声音':
1262
+ try:
1263
+ gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
1264
+ sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
1265
+ vits.load_model(gpt_path, sovits_path)
1266
+ gr.Info("模型加载成功")
1267
+ except Exception as e:
1268
+ gr.Warning(f"模型加载失败 {e}")
1269
+ gr.Warning("注意注意⚠️:GPT-SoVITS要上传参考音频进行克隆,请点击TTS Method语音方法调节操作")
1270
+ else:
1271
+ gr.Warning("未知TTS模型,可提issue和PR 或者 建议更新模型")
1272
+ return model_name
1273
+
1274
+ def success_print(text):
1275
+ print(f"\033[1;32;40m{text}\033[0m")
1276
+
1277
+ def error_print(text):
1278
+ print(f"\033[1;31;40m{text}\033[0m")
1279
+
1280
+ if __name__ == "__main__":
1281
+ llm_class = LLM(mode='offline')
1282
+ llm = llm_class.init_model('直接回复 Direct Reply')
1283
+ success_print("默认不使用LLM模型,直接回复问题,同时减少显存占用!")
1284
+
1285
+ try:
1286
+ from VITS import *
1287
+ vits = GPT_SoVITS()
1288
+ success_print("Success!!! GPT-SoVITS模块加载成功,语音克隆默认使用GPT-SoVITS模型")
1289
+ # gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
1290
+ # sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
1291
+ # vits.load_model(gpt_path, sovits_path)
1292
+ except Exception as e:
1293
+ error_print(f"GPT-SoVITS Error: {e}")
1294
+ error_print("如果使用VITS,请先下载GPT-SoVITS模型和安装环境")
1295
+
1296
+ try:
1297
+ from TFG import SadTalker
1298
+ talker = SadTalker(lazy_load=True)
1299
+ success_print("Success!!! SadTalker模块加载成功,默认使用SadTalker模型")
1300
+ except Exception as e:
1301
+ error_print(f"SadTalker Error: {e}")
1302
+ error_print("如果使用SadTalker,请先下载SadTalker模型")
1303
+
1304
+ try:
1305
+ from ASR import WhisperASR
1306
+ if os.path.exists('Whisper/base.pt'):
1307
+ asr = WhisperASR('Whisper/base.pt')
1308
+ else:
1309
+ asr = WhisperASR('base')
1310
+ success_print("Success!!! WhisperASR模块加载成功,默认使用Whisper-base模型")
1311
+ except Exception as e:
1312
+ error_print(f"ASR Error: {e}")
1313
+ error_print("如果使用FunASR,请先下载WhisperASR模型和安装环境")
1314
+
1315
+ # 判断显存是否8g,若小于8g不建议使用MuseTalk功能
1316
+ # Check if GPU is available and has at least 8GB of memory
1317
+ if torch.cuda.is_available():
1318
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) # Convert bytes to GB
1319
+ if gpu_memory < 8:
1320
+ error_print("警告: 您的显卡显存小于8GB,不建议使用MuseTalk功能")
1321
+
1322
+ try:
1323
+ from TFG import MuseTalk_RealTime
1324
+ musetalker = MuseTalk_RealTime()
1325
+ success_print("Success!!! MuseTalk模块加载成功")
1326
+ except Exception as e:
1327
+ error_print(f"MuseTalk Error: {e}")
1328
+ error_print("如果使用MuseTalk,请先下载MuseTalk模型")
1329
+
1330
+ tts = edgetts
1331
+ if not tts.network:
1332
+ error_print("EdgeTTS模块加载失败,请检查网络是否正常连接,否则无法使用")
1333
+
1334
+ gr.close_all()
1335
+ # demo_app = app()
1336
+ demo_img = app_img()
1337
+ demo_multi = app_multi()
1338
+ # demo_vits = app_vits()
1339
+ # demo_talk = app_talk()
1340
+ demo_muse = app_muse()
1341
+ demo = gr.TabbedInterface(interface_list = [
1342
+ # demo_app,
1343
+ demo_img,
1344
+ demo_multi,
1345
+ # demo_vits,
1346
+ # demo_talk,
1347
+ demo_muse,
1348
+ ],
1349
+ tab_names = [
1350
+ "个性化角色互动",
1351
+ "数字人多轮智能对话",
1352
+ "MuseTalk数字人实时对话"
1353
+ ],
1354
+ title = "Linly-Talker WebUI")
1355
+ demo.queue()
1356
+ demo.launch(server_name=ip, # 本地端口localhost:127.0.0.1 全局端口转发:"0.0.0.0"
1357
+ server_port=port,
1358
+ # 似乎在Gradio4.0以上版本可以不使用证书也可以进行麦克风对话
1359
+ # ssl_certfile=ssl_certfile,
1360
+ # ssl_keyfile=ssl_keyfile,
1361
+ # ssl_verify=False,
1362
+ # share=True,
1363
+ debug=True,
1364
+ )
ASR/FunASR.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Reference: https://github.com/alibaba-damo-academy/FunASR
3
+ pip install funasr
4
+ pip install modelscope
5
+ pip install -U rotary_embedding_torch
6
+ '''
7
+ try:
8
+ from funasr import AutoModel
9
+ except:
10
+ print("如果想使用FunASR,请先安装funasr,若使用Whisper,请忽略此条信息")
11
+ import os
12
+ import sys
13
+ sys.path.append('./')
14
+ from src.cost_time import calculate_time
15
+
16
+ class FunASR:
17
+ def __init__(self) -> None:
18
+ # 定义模型的自定义路径
19
+ model_path = "FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
20
+ vad_model_path = "FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch"
21
+ punc_model_path = "FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
22
+
23
+ # 检查文件是否存在于 FunASR 目录下
24
+ model_exists = os.path.exists(model_path)
25
+ vad_model_exists = os.path.exists(vad_model_path)
26
+ punc_model_exists = os.path.exists(punc_model_path)
27
+ # Modelscope AutoDownload
28
+ self.model = AutoModel(
29
+ model=model_path if model_exists else "paraformer-zh",
30
+ vad_model=vad_model_path if vad_model_exists else "fsmn-vad",
31
+ punc_model=punc_model_path if punc_model_exists else "ct-punc-c",
32
+ )
33
+ # 自定义路径
34
+ # self.model = AutoModel(model="FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", # model_revision="v2.0.4",
35
+ # vad_model="FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch", # vad_model_revision="v2.0.4",
36
+ # punc_model="FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", # punc_model_revision="v2.0.4",
37
+ # # spk_model="cam++", spk_model_revision="v2.0.2",
38
+ # )
39
+ @calculate_time
40
+ def transcribe(self, audio_file):
41
+ res = self.model.generate(input=audio_file,
42
+ batch_size_s=300)
43
+ print(res)
44
+ return res[0]['text']
45
+
46
+
47
+ if __name__ == "__main__":
48
+ import os
49
+ # 创建ASR对象并进行语音识别
50
+ audio_file = "output.wav" # 音频文件路径
51
+ if not os.path.exists(audio_file):
52
+ os.system('edge-tts --text "hello" --write-media output.wav')
53
+ asr = FunASR()
54
+ print(asr.transcribe(audio_file))
ASR/README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## ASR 同数字人沟通的桥梁
2
+
3
+ ### Whisper OpenAI
4
+
5
+ Whisper 是一个自动语音识别 (ASR) 系统,它使用从网络上收集的 680,000 小时多语言和多任务监督数据进行训练。使用如此庞大且多样化的数据集可以提高对口音、背景噪音和技术语言的鲁棒性。此外,它还支持多种语言的转录,以及将这些语言翻译成英语。
6
+
7
+ 使用方法很简单,我们只要安装以下库,后续模型会自动下载
8
+
9
+ ```bash
10
+ pip install -U openai-whisper
11
+ ```
12
+
13
+ 借鉴OpenAI的Whisper实现了ASR的语音识别,具体使用方法参考 [https://github.com/openai/whisper](https://github.com/openai/whisper)
14
+
15
+ ```python
16
+ '''
17
+ https://github.com/openai/whisper
18
+ pip install -U openai-whisper
19
+ '''
20
+ import whisper
21
+
22
+ class WhisperASR:
23
+ def __init__(self, model_path):
24
+ self.LANGUAGES = {
25
+ "en": "english",
26
+ "zh": "chinese",
27
+ }
28
+ self.model = whisper.load_model(model_path)
29
+
30
+ def transcribe(self, audio_file):
31
+ result = self.model.transcribe(audio_file)
32
+ return result["text"]
33
+ ```
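+ 
+ 一个最小的调用示意(假设当前目录下有一段音频 output.wav;传入 'base' 等模型名时 whisper 会自动下载权重,也可以像仓库中那样传入本地的 Whisper/base.pt 路径):
+ 
+ ```python
+ # 使用示意:WhisperASR 即上文定义的类
+ asr = WhisperASR('base')             # 或 WhisperASR('Whisper/base.pt')
+ print(asr.transcribe("output.wav"))  # 打印识别出的文本
+ ```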
34
+
35
+
36
+
37
+ ### FunASR Alibaba
38
+
39
+ 阿里的`FunASR`语音识别效果同样相当不错,而且推理速度比Whisper更快,更接近实时效果,因此也将FunASR加入了项目,在ASR文件夹下的FunASR文件中即可体验,参考 [https://github.com/alibaba-damo-academy/FunASR](https://github.com/alibaba-damo-academy/FunASR)
40
+
41
+ 需要注意的是,在第一次运行的时候,需要安装以下库。
42
+
43
+ ```bash
44
+ pip install funasr
45
+ pip install modelscope
46
+ pip install -U rotary_embedding_torch
47
+ ```
48
+
49
+ ```python
50
+ '''
51
+ Reference: https://github.com/alibaba-damo-academy/FunASR
52
+ pip install funasr
53
+ pip install modelscope
54
+ pip install -U rotary_embedding_torch
55
+ '''
56
+ try:
57
+ from funasr import AutoModel
58
+ except:
59
+ print("如果想使用FunASR,请先安装funasr,若使用Whisper,请忽略此条信息")
60
+
61
+ class FunASR:
62
+ def __init__(self) -> None:
63
+ self.model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
64
+ vad_model="fsmn-vad", vad_model_revision="v2.0.4",
65
+ punc_model="ct-punc-c", punc_model_revision="v2.0.4",
66
+ # spk_model="cam++", spk_model_revision="v2.0.2",
67
+ )
68
+
69
+ def transcribe(self, audio_file):
70
+ res = self.model.generate(input=audio_file,
71
+ batch_size_s=300)
72
+ print(res)
73
+ return res[0]['text']
74
+ ```
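+ 
+ 同样给出一个最小的调用示意(假设当前目录下有一段音频 output.wav;首次运行时 FunASR 会自动从 modelscope 下载上述模型):
+ 
+ ```python
+ # 使用示意:FunASR 即上文定义的类
+ asr = FunASR()
+ print(asr.transcribe("output.wav"))  # 打印带标点的识别文本
+ ```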
75
+
76
+
77
+
ASR/Whisper.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ https://github.com/openai/whisper
3
+ pip install -U openai-whisper
4
+ '''
5
+ import whisper
6
+ import sys
7
+ sys.path.append('./')
8
+ from src.cost_time import calculate_time
9
+
10
+ class WhisperASR:
11
+ def __init__(self, model_path):
12
+ self.LANGUAGES = {
13
+ "en": "english",
14
+ "zh": "chinese",
15
+ "de": "german",
16
+ "es": "spanish",
17
+ "ru": "russian",
18
+ "ko": "korean",
19
+ "fr": "french",
20
+ "ja": "japanese",
21
+ "pt": "portuguese",
22
+ "tr": "turkish",
23
+ "pl": "polish",
24
+ "ca": "catalan",
25
+ "nl": "dutch",
26
+ "ar": "arabic",
27
+ "sv": "swedish",
28
+ "it": "italian",
29
+ "id": "indonesian",
30
+ "hi": "hindi",
31
+ "fi": "finnish",
32
+ "vi": "vietnamese",
33
+ "he": "hebrew",
34
+ "uk": "ukrainian",
35
+ "el": "greek",
36
+ "ms": "malay",
37
+ "cs": "czech",
38
+ "ro": "romanian",
39
+ "da": "danish",
40
+ "hu": "hungarian",
41
+ "ta": "tamil",
42
+ "no": "norwegian",
43
+ "th": "thai",
44
+ "ur": "urdu",
45
+ "hr": "croatian",
46
+ "bg": "bulgarian",
47
+ "lt": "lithuanian",
48
+ "la": "latin",
49
+ "mi": "maori",
50
+ "ml": "malayalam",
51
+ "cy": "welsh",
52
+ "sk": "slovak",
53
+ "te": "telugu",
54
+ "fa": "persian",
55
+ "lv": "latvian",
56
+ "bn": "bengali",
57
+ "sr": "serbian",
58
+ "az": "azerbaijani",
59
+ "sl": "slovenian",
60
+ "kn": "kannada",
61
+ "et": "estonian",
62
+ "mk": "macedonian",
63
+ "br": "breton",
64
+ "eu": "basque",
65
+ "is": "icelandic",
66
+ "hy": "armenian",
67
+ "ne": "nepali",
68
+ "mn": "mongolian",
69
+ "bs": "bosnian",
70
+ "kk": "kazakh",
71
+ "sq": "albanian",
72
+ "sw": "swahili",
73
+ "gl": "galician",
74
+ "mr": "marathi",
75
+ "pa": "punjabi",
76
+ "si": "sinhala",
77
+ "km": "khmer",
78
+ "sn": "shona",
79
+ "yo": "yoruba",
80
+ "so": "somali",
81
+ "af": "afrikaans",
82
+ "oc": "occitan",
83
+ "ka": "georgian",
84
+ "be": "belarusian",
85
+ "tg": "tajik",
86
+ "sd": "sindhi",
87
+ "gu": "gujarati",
88
+ "am": "amharic",
89
+ "yi": "yiddish",
90
+ "lo": "lao",
91
+ "uz": "uzbek",
92
+ "fo": "faroese",
93
+ "ht": "haitian creole",
94
+ "ps": "pashto",
95
+ "tk": "turkmen",
96
+ "nn": "nynorsk",
97
+ "mt": "maltese",
98
+ "sa": "sanskrit",
99
+ "lb": "luxembourgish",
100
+ "my": "myanmar",
101
+ "bo": "tibetan",
102
+ "tl": "tagalog",
103
+ "mg": "malagasy",
104
+ "as": "assamese",
105
+ "tt": "tatar",
106
+ "haw": "hawaiian",
107
+ "ln": "lingala",
108
+ "ha": "hausa",
109
+ "ba": "bashkir",
110
+ "jw": "javanese",
111
+ "su": "sundanese",
112
+ }
113
+ self.model = whisper.load_model(model_path)
114
+
115
+ @calculate_time
116
+ def transcribe(self, audio_file):
117
+ result = self.model.transcribe(audio_file)
118
+ return result["text"]
119
+
120
+
121
+ if __name__ == "__main__":
122
+ import os
123
+ # 创建ASR对象并进行语音识别
124
+ model_path = "./Whisper/tiny.pt" # 模型路径
125
+ audio_file = "output.wav" # 音频文件路径
126
+ if not os.path.exists(audio_file):
127
+ os.system('edge-tts --text "hello" --write-media output.wav')
128
+ asr = WhisperASR(model_path)
129
+ print(asr.transcribe(audio_file))
ASR/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .Whisper import WhisperASR
2
+ from .FunASR import FunASR
3
+
4
+ __all__ = ['WhisperASR', 'FunASR']
ASR/requirements_funasr.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ funasr
2
+ modelscope
3
+ # rotary_embedding_torch
AutoDL部署.md ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 在AutoDL平台部署Linly-Talker (0基础小白超详细教程)
2
+
3
+ <!-- TOC -->
4
+
5
+ - [在AutoDL平台部署Linly-Talker 0基础小白超详细教程](#%E5%9C%A8autodl%E5%B9%B3%E5%8F%B0%E9%83%A8%E7%BD%B2linly-talker-0%E5%9F%BA%E7%A1%80%E5%B0%8F%E7%99%BD%E8%B6%85%E8%AF%A6%E7%BB%86%E6%95%99%E7%A8%8B)
6
+ - [快速上手直接使用镜像以下安装操作全免](#%E5%BF%AB%E9%80%9F%E4%B8%8A%E6%89%8B%E7%9B%B4%E6%8E%A5%E4%BD%BF%E7%94%A8%E9%95%9C%E5%83%8F%E4%BB%A5%E4%B8%8B%E5%AE%89%E8%A3%85%E6%93%8D%E4%BD%9C%E5%85%A8%E5%85%8D)
7
+ - [一、注册AutoDL](#%E4%B8%80%E6%B3%A8%E5%86%8Cautodl)
8
+ - [二、创建实例](#%E4%BA%8C%E5%88%9B%E5%BB%BA%E5%AE%9E%E4%BE%8B)
9
+ - [登录AutoDL,进入算力市场,选择机器](#%E7%99%BB%E5%BD%95autodl%E8%BF%9B%E5%85%A5%E7%AE%97%E5%8A%9B%E5%B8%82%E5%9C%BA%E9%80%89%E6%8B%A9%E6%9C%BA%E5%99%A8)
10
+ - [配置基础镜像](#%E9%85%8D%E7%BD%AE%E5%9F%BA%E7%A1%80%E9%95%9C%E5%83%8F)
11
+ - [无卡模式开机](#%E6%97%A0%E5%8D%A1%E6%A8%A1%E5%BC%8F%E5%BC%80%E6%9C%BA)
12
+ - [三、部署环境](#%E4%B8%89%E9%83%A8%E7%BD%B2%E7%8E%AF%E5%A2%83)
13
+ - [进入终端](#%E8%BF%9B%E5%85%A5%E7%BB%88%E7%AB%AF)
14
+ - [下载代码文件](#%E4%B8%8B%E8%BD%BD%E4%BB%A3%E7%A0%81%E6%96%87%E4%BB%B6)
15
+ - [下载模型文件](#%E4%B8%8B%E8%BD%BD%E6%A8%A1%E5%9E%8B%E6%96%87%E4%BB%B6)
16
+ - [四、Linly-Talker项目](#%E5%9B%9Blinly-talker%E9%A1%B9%E7%9B%AE)
17
+ - [环境安装](#%E7%8E%AF%E5%A2%83%E5%AE%89%E8%A3%85)
18
+ - [端口设置](#%E7%AB%AF%E5%8F%A3%E8%AE%BE%E7%BD%AE)
19
+ - [有卡开机](#%E6%9C%89%E5%8D%A1%E5%BC%80%E6%9C%BA)
20
+ - [运行网页版对话webui](#%E8%BF%90%E8%A1%8C%E7%BD%91%E9%A1%B5%E7%89%88%E5%AF%B9%E8%AF%9Dwebui)
21
+ - [端口映射](#%E7%AB%AF%E5%8F%A3%E6%98%A0%E5%B0%84)
22
+ - [体验Linly-Talker(成功)](#%E4%BD%93%E9%AA%8Clinly-talker%E6%88%90%E5%8A%9F)
23
+
24
+ <!-- /TOC -->
25
+
26
+
27
+
28
+ ## 快速上手直接使用镜像(以下安装操作全免)
29
+
30
+ 若使用我设定好的镜像,则无需安装任何环境,直接运行webui.py或app_talk.py即可体验,可直接跳到4.4节
31
+
32
+ 访问后在自定义设置里面打开端口,默认是6006端口,直接使用运行即可!
33
+
34
+ ```bash
35
+ python webui.py
36
+ python app_talk.py
37
+ ```
38
+
39
+ 环境模型都安装好了,直接使用即可,镜像地址在:[https://www.codewithgpu.com/i/Kedreamix/Linly-Talker/Kedreamix-Linly-Talker](https://www.codewithgpu.com/i/Kedreamix/Linly-Talker/Kedreamix-Linly-Talker),感谢大家的支持
40
+
41
+
42
+
43
+ ## 一、注册AutoDL
44
+
45
+ [AutoDL官网](https://www.autodl.com/home) 注册好账户并充值,然后自行选择机器,如果只是正常跑一下,我觉得5元已经够了
46
+
47
+ ![注册AutoDL](https://pic1.zhimg.com/v2-f56bc692a0d22fb1ae749b7697ff5d0f.png)
48
+
49
+ ## 二、创建实例
50
+
51
+ ### 2.1 登录AutoDL,进入算力市场,选择机器
52
+
53
+ 这一部分实际上我觉得12g都OK的,无非是速度问题而已
54
+
55
+ ![选择RTX 3090机器](https://picx.zhimg.com/v2-824956d591eead5d3ed4de87c59258a6.png)
56
+
57
+
58
+
59
+ ### 2.2 配置基础镜像
60
+
61
+ 选择镜像时,最好选择PyTorch 2.0以上的版本,以便体验声音克隆功能,其他无所谓
62
+
63
+ ![配置基础镜像](https://pic1.zhimg.com/v2-8d064d809e15673dc6f2be8f2ef83ae7.png)
64
+
65
+
66
+
67
+ ### 2.3 无卡模式开机
68
+
69
+ 创建成功后为了省钱先关机,然后使用无卡模式开机。
70
+ 无卡模式一个小时只需要0.1元,比较适合部署环境。
71
+
72
+ ![无卡模式开机](https://picx.zhimg.com/v2-118eead549c35ad06d946b00cd93c668.png)
73
+
74
+ ## 三、部署环境
75
+
76
+ ### 3.1 进入终端
77
+
78
+ 打开jupyterLab,进入数据盘(autodl-tmp),打开终端,将Linly-Talker模型下载到数据盘中。
79
+
80
+ ![进入终端](https://pica.zhimg.com/v2-7c29a3f68c04b9af0963d4acaa63f59c.png)
81
+
82
+
83
+
84
+ ### 3.2 下载代码文件
85
+
86
+ 根据Github上的说明,使用命令行下载模型文件和代码文件,利用学术加速会快一点
87
+
88
+ ```bash
89
+ # 开启学术镜像,更快的clone代码 参考 https://www.autodl.com/docs/network_turbo/
90
+ source /etc/network_turbo
91
+
92
+ cd /root/autodl-tmp/
93
+ # 下载代码
94
+ git clone https://github.com/Kedreamix/Linly-Talker.git
95
+
96
+ # 取消学术加速
97
+ unset http_proxy && unset https_proxy
98
+ ```
99
+
100
+
101
+
102
+ ### 3.3 下载模型文件
103
+
104
+ 安装git lfs
105
+
106
+ ```sh
107
+ curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
108
+ sudo apt-get install git-lfs
109
+ ```
110
+
111
+ ![安装git lfs](https://pic1.zhimg.com/v2-084b5f048d1a70d8712bd6a6891c673c.png)
112
+
113
+ 根据 [https://www.modelscope.cn/Kedreamix/Linly-Talker](https://www.modelscope.cn/Kedreamix/Linly-Talker) 下载模型文件,走modelscope还是很快的,不过文件有点多,需要等一会儿,记住要在Linly-Talker代码路径下执行以下命令
114
+
115
+ ```bash
116
+ cd /root/autodl-tmp/Linly-Talker/
117
+ git lfs install
118
+ git lfs clone https://www.modelscope.cn/Kedreamix/Linly-Talker.git
119
+ ```
120
+
121
+ ![下载模型文件](https://picx.zhimg.com/v2-1f61be9be69053ebe16c72762c350d18.png)
122
+
123
+ 等待一段时间下载完以后,利用命令将模型移动到指定目录,直接复制即可
124
+
125
+ ```bash
126
+ # 移动所有模型到当前目录
127
+ # checkpoint中含有SadTalker和Wav2Lip
128
+ mv Linly-Talker/checkpoints/* ./checkpoints
129
+
130
+ # SadTalker的增强GFPGAN
131
+ # pip install gfpgan
132
+ # mv Linly-Talker/gfpan ./
133
+
134
+ # 语音克隆模型
135
+ mv Linly-Talker/GPT_SoVITS/pretrained_models/* ./GPT_SoVITS/pretrained_models/
136
+
137
+ # Qwen大模型
138
+ mv Linly-Talker/Qwen ./
139
+ ```
140
+
141
+
142
+
143
+ ## 四、Linly-Talker项目
144
+
145
+ ### 4.1 环境安装
146
+
147
+ 进入代码路径,安装环境,由于所选镜像已含有PyTorch,所以只需要安装其他依赖即可
148
+
149
+ ```bash
150
+ cd /root/autodl-tmp/Linly-Talker
151
+
152
+ conda install -q ffmpeg # ffmpeg==4.2.2
153
+
154
+ # 安装Linly-Talker对应依赖
155
+ pip install -r requirements_app.txt
156
+
157
+ # 安装语音克隆对应的依赖
158
+ pip install -r VITS/requirements_gptsovits.txt
159
+ ```
160
+
161
+
162
+
163
+ ### 4.2 端口设置
164
+
165
+ 由于autodl开放的似乎是6006端口,所以这里的端口映射也可以改成6006,只需要修改configs.py文件里面的port为6006即可(示意见下方代码片段)
166
+
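+ 下面给出一个最小的 configs.py 修改示意(其中 ip、port 两个变量名是按照 webui.py 末尾 demo.launch(server_name=ip, server_port=port) 的写法推测的,请以仓库中实际的 configs.py 为准):
+ 
+ ```python
+ # configs.py 修改示意(仅作示意,变量名以实际文件为准)
+ ip = "0.0.0.0"   # 全局端口转发,便于 AutoDL 做端口映射
+ port = 6006      # AutoDL 自定义服务默认开放 6006 端口
+ ```
+ 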
167
+ ![端口设置](https://picx.zhimg.com/v2-d9516998de3e6ffea282c4c6776c126f.png)
168
+
169
+ 除此之外,我发现autodl其实不太支持https的端口映射,所以只需要在webui.py的最后几行注释掉ssl相关代码即可
170
+
171
+ ```python
172
+ demo.launch(server_name="127.0.0.1", # 本地端口localhost:127.0.0.1 全局端口转发:"0.0.0.0"
173
+ server_port=port,
174
+ # 似乎在Gradio4.0以上版本可以不使用证书也可以进行麦克风对话
175
+ # ssl_certfile=ssl_certfile,
176
+ # ssl_keyfile=ssl_keyfile,
177
+ # ssl_verify=False,
178
+ debug=True,
179
+ )
180
+ ```
181
+
182
+ 如果使用app.py同理
183
+
184
+ ![http设置](https://picx.zhimg.com/v2-fd4aed3c765b8a4a425e12520a1461ae.png)
185
+
186
+ ### 4.3 有卡开机
187
+
188
+ 进入autodl容器实例界面,执行关机操作,然后进行有卡开机,开机后打开jupyterLab。
189
+
190
+ 查看配置
191
+
192
+ ```bash
193
+ nvidia-smi
194
+ ```
195
+
196
+ ![有卡开机](https://pic1.zhimg.com/v2-baff41e6634fd25fb5d39806b86b1c40.png)
197
+
198
+
199
+
200
+ ### 4.4 运行网页版对话webui
201
+
202
+ 需要有卡模式开机,执行下边命令,这里面就跟代码是一模一样的了
203
+
204
+ ```bash
205
+ python webui.py
206
+ ```
207
+
208
+ ![运行网页版对话webui](https://pic1.zhimg.com/v2-c1c8a2653cef9d4b044fb85555411928.png)
209
+
210
+
211
+
212
+ ### 4.4 端口映射
213
+
214
+ 这里可以直接打开autodl的自定义服务,默认是6006端口,前面已经设置过了,直接使用即可
215
+
216
+ ![端口映射](https://picx.zhimg.com/v2-19a51c44478a1fad9e0de593ee4d547f.png)
217
+
218
+
219
+
220
+ ### 4.5 体验Linly-Talker(成功)
221
+
222
+ 点开网页,即可正确执行Linly-Talker,这一部分就跟视频一模一样了
223
+
224
+ ![体验Linly-Talker](https://picx.zhimg.com/v2-1b67099f2903fdcac732b784477851aa.png)
225
+
226
+
227
+
228
+
229
+
230
+ ssh端口映射工具:windows:[https://autodl-public.ks3-cn-beijing.ksyuncs.com/tool/AutoDL-SSH-Tools.zip](https://autodl-public.ks3-cn-beijing.ksyuncs.com/tool/AutoDL-SSH-Tools.zip)
231
+
232
+ **!!!注意:不用的时候,一定要去控制台=》容器实例,把镜像实例关机,它是按时收费的,不关机会一直扣费。**
233
+
234
+ **建议选北京区的,稍微便宜一些。可以晚上部署,网速快,便宜的GPU也充足。白天部署,北京区的GPU容易没有。**
FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/.mdl ADDED
Binary file (79 Bytes). View file
 
FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/.msc ADDED
Binary file (566 Bytes). View file
 
FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/README.md ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tasks:
3
+ - punctuation
4
+ domain:
5
+ - audio
6
+ model-type:
7
+ - Classification
8
+ frameworks:
9
+ - pytorch
10
+ metrics:
11
+ - f1_score
12
+ license: Apache License 2.0
13
+ language:
14
+ - cn
15
+ tags:
16
+ - FunASR
17
+ - CT-Transformer
18
+ - Alibaba
19
+ - ICASSP 2020
20
+ datasets:
21
+ train:
22
+ - 33M-samples online data
23
+ test:
24
+ - wikipedia data test
25
+ - 10000 industrial Mandarin sentences test
26
+ widgets:
27
+ - task: punctuation
28
+ model_revision: v2.0.4
29
+ inputs:
30
+ - type: text
31
+ name: input
32
+ title: 文本
33
+ examples:
34
+ - name: 1
35
+ title: 示例1
36
+ inputs:
37
+ - name: input
38
+ data: 我们都是木头人不会讲话不会动
39
+ inferencespec:
40
+ cpu: 1 #CPU数量
41
+ memory: 4096
42
+ ---
43
+
44
+ # Controllable Time-delay Transformer模型介绍
45
+
46
+ [//]: # (Controllable Time-delay Transformer 模型是一种端到端标点分类模型。)
47
+
48
+ [//]: # (常规的Transformer会依赖很远的未来信息,导致长时间结果不固定。Controllable Time-delay Transformer 在效果无损的情况下,有效控制标点的延时。)
49
+
50
+ # Highlights
51
+ - 中文标点通用模型:可用于语音识别模型输出文本的标点预测。
52
+ - 基于[Paraformer-large长音频模型](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)场景的使用
53
+ - 基于[FunASR框架](https://github.com/alibaba-damo-academy/FunASR),可进行ASR,VAD,标点的自由组合
54
+ - 基于纯文本输入的标点预测
55
+
56
+ ## <strong>[FunASR开源项目介绍](https://github.com/alibaba-damo-academy/FunASR)</strong>
57
+ <strong>[FunASR](https://github.com/alibaba-damo-academy/FunASR)</strong>希望在语音识别的学术研究和工业应用之间架起一座桥梁。通过发布工业级语音识别模型的训练和微调,研究人员和开发人员可以更方便地进行语音识别模型的研究和生产,并推动语音识别生态的发展。让语音识别更有趣!
58
+
59
+ [**github仓库**](https://github.com/alibaba-damo-academy/FunASR)
60
+ | [**最新动态**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
61
+ | [**环境安装**](https://github.com/alibaba-damo-academy/FunASR#installation)
62
+ | [**服务部署**](https://www.funasr.com)
63
+ | [**模型库**](https://github.com/alibaba-damo-academy/FunASR/tree/main/model_zoo)
64
+ | [**联系我们**](https://github.com/alibaba-damo-academy/FunASR#contact)
65
+
66
+
67
+ ## 模型原理介绍
68
+
69
+ Controllable Time-delay Transformer是达摩院语音团队提出的高效后处理框架中的标点模块。本项目为中文通用标点模型,模型可以被应用于文本类输入的标点预测,也可应用于语音识别结果的后处理步骤,协助语音识别模块输出具有可读性的文本结果。
70
+
71
+ <p align="center">
72
+ <img src="fig/struct.png" alt="Controllable Time-delay Transformer模型结构" width="500" />
73
+
74
+ Controllable Time-delay Transformer 模型结构如上图所示,由 Embedding、Encoder 和 Predictor 三部分组成。Embedding 是词向量叠加位置向量。Encoder可以采用不同的网络结构,例如self-attention,conformer,SAN-M等。Predictor 预测每个token后的标点类型。
75
+
76
+ 在模型的选择上采用了性能优越的Transformer模型。Transformer模型在获得良好性能的同时,由于模型自身序列化输入等特性,会给系统带来较大时延。常规的Transformer可以看到未来的全部信息,导致标点会依赖很远的未来信息。这会给用户带来一种标点一直在变化刷新,长时间结果不固定的不良感受。基于这一问题,我们创新性的提出了可控时延的Transformer模型(Controllable Time-Delay Transformer, CT-Transformer),在模型性能无损失的情况下,有效控制标点的延时。
77
+
78
+ 更详细的细节见:
79
+ - 论文: [CONTROLLABLE TIME-DELAY TRANSFORMER FOR REAL-TIME PUNCTUATION PREDICTION AND DISFLUENCY DETECTION](https://arxiv.org/pdf/2003.01309.pdf)
80
+
81
+ ## 基于ModelScope进行推理
82
+
83
+ 以下为三种支持的输入格式,api调用方式可参考如下范例:
84
+ - text.scp文件路径,例如example/punc_example.txt,格式为: key + "\t" + value
85
+ ```sh
86
+ cat example/punc_example.txt
87
+ 1 跨境河流是养育沿岸人民的生命之源
88
+ 2 从存储上来说仅仅是全景图片它就会是图片的四倍的容量
89
+ 3 那今天的会就到这里吧happy new year明年见
90
+ ```
91
+ ```python
92
+ from modelscope.pipelines import pipeline
93
+ from modelscope.utils.constant import Tasks
94
+
95
+ inference_pipline = pipeline(
96
+ task=Tasks.punctuation,
97
+ model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
98
+ model_revision="v2.0.4")
99
+
100
+ rec_result = inference_pipline(input='example/punc_example.txt')
101
+ print(rec_result)
102
+ ```
103
+ - text二进制数据,例如:用户直接从文件里读出bytes数据
104
+ ```python
105
+ rec_result = inference_pipline(input='我们都是木头人不会讲话不会动')
106
+ ```
107
+ - text文件url,例如:https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt
108
+ ```python
109
+ rec_result = inference_pipline(input='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt')
110
+ ```
111
+
112
+
113
+ ## 基于FunASR进行推理
114
+
115
+ 下面为快速上手教程,测试音频([中文](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav),[英文](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav))
116
+
117
+ ### 可执行命令行
118
+ 在命令行终端执行:
119
+
120
+ ```shell
121
+ funasr ++model=paraformer-zh ++vad_model="fsmn-vad" ++punc_model="ct-punc" ++input=vad_example.wav
122
+ ```
123
+
124
+ 注:支持单条音频文件识别,也支持文件列表,列表为kaldi风格wav.scp:`wav_id wav_path`
125
+
126
+ ### python示例
127
+ #### 非实时语音识别
128
+ ```python
129
+ from funasr import AutoModel
130
+ # paraformer-zh is a multi-functional asr model
131
+ # use vad, punc, spk or not as you need
132
+ model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
133
+ vad_model="fsmn-vad", vad_model_revision="v2.0.4",
134
+ punc_model="ct-punc-c", punc_model_revision="v2.0.4",
135
+ # spk_model="cam++", spk_model_revision="v2.0.2",
136
+ )
137
+ res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
138
+ batch_size_s=300,
139
+ hotword='魔搭')
140
+ print(res)
141
+ ```
142
+ 注:`model_hub`:表示模型仓库,`ms`为选择modelscope下载,`hf`为选择huggingface下载。
143
+
144
+ #### 实时语音识别
145
+
146
+ ```python
147
+ from funasr import AutoModel
148
+
149
+ chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
150
+ encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
151
+ decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
152
+
153
+ model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.4")
154
+
155
+ import soundfile
156
+ import os
157
+
158
+ wav_file = os.path.join(model.model_path, "example/asr_example.wav")
159
+ speech, sample_rate = soundfile.read(wav_file)
160
+ chunk_stride = chunk_size[1] * 960 # 600ms
161
+
162
+ cache = {}
163
+ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
164
+ for i in range(total_chunk_num):
165
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
166
+ is_final = i == total_chunk_num - 1
167
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
168
+ print(res)
169
+ ```
170
+
171
+ 注:`chunk_size`为流式延时配置,`[0,10,5]`表示上屏实时出字粒度为`10*60=600ms`,未来信息为`5*60=300ms`。每次推理输入为`600ms`(采样点数为`16000*0.6=960`),输出为对应文字,最后一个语音片段输入需要设置`is_final=True`来强制输出最后一个字。
172
+
173
+ #### 语音端点检测(非实时)
174
+ ```python
175
+ from funasr import AutoModel
176
+
177
+ model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
178
+
179
+ wav_file = f"{model.model_path}/example/asr_example.wav"
180
+ res = model.generate(input=wav_file)
181
+ print(res)
182
+ ```
183
+
184
+ #### 语音端点检测(实时)
185
+ ```python
186
+ from funasr import AutoModel
187
+
188
+ chunk_size = 200 # ms
189
+ model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
190
+
191
+ import soundfile
192
+
193
+ wav_file = f"{model.model_path}/example/vad_example.wav"
194
+ speech, sample_rate = soundfile.read(wav_file)
195
+ chunk_stride = int(chunk_size * sample_rate / 1000)
196
+
197
+ cache = {}
198
+ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
199
+ for i in range(total_chunk_num):
200
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
201
+ is_final = i == total_chunk_num - 1
202
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
203
+ if len(res[0]["value"]):
204
+ print(res)
205
+ ```
206
+
207
+ #### 标点恢复
208
+ ```python
209
+ from funasr import AutoModel
210
+
211
+ model = AutoModel(model="ct-punc", model_revision="v2.0.4")
212
+
213
+ res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
214
+ print(res)
215
+ ```
216
+
217
+ #### 时间戳预测
218
+ ```python
219
+ from funasr import AutoModel
220
+
221
+ model = AutoModel(model="fa-zh", model_revision="v2.0.4")
222
+
223
+ wav_file = f"{model.model_path}/example/asr_example.wav"
224
+ text_file = f"{model.model_path}/example/text.txt"
225
+ res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
226
+ print(res)
227
+ ```
228
+
229
+ 更多详细用法([示例](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
230
+
231
+
232
+ ## 微调
233
+
234
+ 详细用法([示例](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
235
+
236
+
237
+
238
+
239
+
240
+ ## Benchmark
241
+ 中文标点预测通用模型在自采集的通用领域业务场景数据上有良好效果。训练数据大约33M个sample,每个sample可能包含1句或多句。
242
+
243
+ ### 自采集数据(20000+ samples)
244
+
245
+ | precision | recall | f1_score |
246
+ |:------------------------------------:|:-------------------------------------:|:-------------------------------------:|
247
+ | <div style="width: 150pt">53.8</div> | <div style="width: 150pt">60.0</div> | <div style="width: 150pt">56.5</div> |
248
+
249
+ ## 使用方式以及适用范围
250
+
251
+ 运行范围
252
+ - 支持Linux-x86_64、Mac和Windows运行。
253
+
254
+ 使用方式
255
+ - 直接推理:可以直接对输入文本进行计算,输出带有标点的目标文字。
256
+
257
+ 使用范围与目标场景
258
+ - 适合对文本数据进行标点预测,文本长度不限。
259
+
260
+ ## 相关论文以及引用信息
261
+
262
+ ```BibTeX
263
+ @inproceedings{chen2020controllable,
264
+ title={Controllable Time-Delay Transformer for Real-Time Punctuation Prediction and Disfluency Detection},
265
+ author={Chen, Qian and Chen, Mengzhe and Li, Bo and Wang, Wen},
266
+ booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
267
+ pages={8069--8073},
268
+ year={2020},
269
+ organization={IEEE}
270
+ }
271
+ ```
272
+
FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: CTTransformer
2
+ model_conf:
3
+ ignore_id: 0
4
+ embed_unit: 256
5
+ att_unit: 256
6
+ dropout_rate: 0.1
7
+ punc_list:
8
+ - <unk>
9
+ - _
10
+ - ,
11
+ - 。
12
+ - ?
13
+ - 、
14
+ punc_weight:
15
+ - 1.0
16
+ - 1.0
17
+ - 1.0
18
+ - 1.0
19
+ - 1.0
20
+ - 1.0
21
+ sentence_end_id: 3
22
+
23
+ encoder: SANMEncoder
24
+ encoder_conf:
25
+ input_size: 256
26
+ output_size: 256
27
+ attention_heads: 8
28
+ linear_units: 1024
29
+ num_blocks: 4
30
+ dropout_rate: 0.1
31
+ positional_dropout_rate: 0.1
32
+ attention_dropout_rate: 0.0
33
+ input_layer: pe
34
+ pos_enc_class: SinusoidalPositionEncoder
35
+ normalize_before: true
36
+ kernel_size: 11
37
+ sanm_shfit: 0
38
+ selfattention_layer_type: sanm
39
+ padding_idx: 0
40
+
41
+ tokenizer: CharTokenizer
42
+ tokenizer_conf:
43
+ unk_symbol: <unk>
44
+
45
+
46
+
FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/configuration.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "framework": "pytorch",
3
+ "task" : "punctuation",
4
+ "model": {"type" : "funasr"},
5
+ "pipeline": {"type":"funasr-pipeline"},
6
+ "model_name_in_hub": {
7
+ "ms":"iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
8
+ "hf":""},
9
+ "file_path_metas": {
10
+ "init_param":"model.pt",
11
+ "config":"config.yaml",
12
+ "tokenizer_conf": {"token_list": "tokens.json"}}
13
+ }
FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/example/punc_example.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 1 跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益
2
+ 2 从存储上来说仅仅是全景图片它就会是图片的四倍的容量然后全景的视频会是普通视频八倍的这个存储的容要求而三d的模型会是图片的十倍这都对我们今天运行在的云计算的平台存储的平台提出了更高的要求
3
+ 3 那今天的会就到这里吧 happy new year 明年见
FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/fig/struct.png ADDED
FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5818bb9d933805a916eebe41eb41648f7f9caad30b4bd59d56f3ca135421916
3
+ size 291979892
FunASR/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/tokens.json ADDED
The diff for this file is too large to render. See raw diff
 
FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/.mdl ADDED
Binary file (67 Bytes). View file
 
FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/.msc ADDED
Binary file (497 Bytes). View file
 
FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/README.md ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tasks:
3
+ - voice-activity-detection
4
+ domain:
5
+ - audio
6
+ model-type:
7
+ - VAD model
8
+ frameworks:
9
+ - pytorch
10
+ backbone:
11
+ - fsmn
12
+ metrics:
13
+ - f1_score
14
+ license: Apache License 2.0
15
+ language:
16
+ - cn
17
+ tags:
18
+ - FunASR
19
+ - FSMN
20
+ - Alibaba
21
+ - Online
22
+ datasets:
23
+ train:
24
+ - 20,000 hour industrial Mandarin task
25
+ test:
26
+ - 20,000 hour industrial Mandarin task
27
+ widgets:
28
+ - task: voice-activity-detection
29
+ model_revision: v2.0.4
30
+ inputs:
31
+ - type: audio
32
+ name: input
33
+ title: 音频
34
+ examples:
35
+ - name: 1
36
+ title: 示例1
37
+ inputs:
38
+ - name: input
39
+ data: git://example/vad_example.wav
40
+ inferencespec:
41
+ cpu: 1 #CPU数量
42
+ memory: 4096
43
+ ---
44
+
45
+ # FSMN-Monophone VAD 模型介绍
46
+
47
+ [//]: # (FSMN-Monophone VAD 模型)
48
+
49
+ ## Highlight
50
+ - 16k中文通用VAD模型:可用于检测长语音片段中有效语音的起止时间点。
51
+ - 基于[Paraformer-large长音频模型](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)场景的使用
52
+ - 基于[FunASR框架](https://github.com/alibaba-damo-academy/FunASR),可进行ASR,VAD,[中文标点](https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary)的自由组合
53
+ - 基于音频数据的有效语音片段起止时间点检测
54
+
55
+ ## <strong>[FunASR开源项目介绍](https://github.com/alibaba-damo-academy/FunASR)</strong>
56
+ <strong>[FunASR](https://github.com/alibaba-damo-academy/FunASR)</strong>希望在语音识别的学术研究和工业应用之间架起一座桥梁。通过发布工业级语音识别模型的训练和微调,研究人员和开发人员可以更方便地进行语音识别模型的研究和生产,并推动语音识别生态的发展。让语音识别更有趣!
57
+
58
+ [**github仓库**](https://github.com/alibaba-damo-academy/FunASR)
59
+ | [**最新动态**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
60
+ | [**环境安装**](https://github.com/alibaba-damo-academy/FunASR#installation)
61
+ | [**服务部署**](https://www.funasr.com)
62
+ | [**模型库**](https://github.com/alibaba-damo-academy/FunASR/tree/main/model_zoo)
63
+ | [**联系我们**](https://github.com/alibaba-damo-academy/FunASR#contact)
64
+
65
+
66
+ ## 模型原理介绍
67
+
68
+ FSMN-Monophone VAD是达摩院语音团队提出的高效语音端点检测模型,用于检测输入音频中有效语音的起止时间点信息,并将检测出来的有效音频片段输入识别引擎进行识别,减少无效语音带来的识别错误。
69
+
70
+ <p align="center">
71
+ <img src="fig/struct.png" alt="VAD模型结构" width="500" />
72
+
73
+ FSMN-Monophone VAD模型结构如上图所示:模型结构层面,FSMN模型结构建模时可考虑上下文信息,训练和推理速度快,且时延可控;同时根据VAD模型size以及低时延的要求,对FSMN的网络结构、右看帧数进行了适配。在建模单元层面,speech信息比较丰富,仅用单类来表征学习能力有限,我们将单一speech类升级为Monophone。建模单元细分,可以避免参数平均,抽象学习能力增强,区分性更好。
74
+
75
+ ## 基于ModelScope进行推理
76
+
77
+ - 推理支持音频格式如下:
78
+ - wav文件路径,例如:data/test/audios/vad_example.wav
79
+ - wav文件url,例如:https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav
80
+ - wav二进制数据,格式bytes,例如:用户直接从文件里读出bytes数据或者是麦克风录出bytes数据。
81
+ - 已解析的audio音频,例如:audio, rate = soundfile.read("vad_example_zh.wav"),类型为numpy.ndarray或者torch.Tensor。
82
+ - wav.scp文件,需符合如下要求:
83
+
84
+ ```sh
85
+ cat wav.scp
86
+ vad_example1 data/test/audios/vad_example1.wav
87
+ vad_example2 data/test/audios/vad_example2.wav
88
+ ...
89
+ ```
90
+
91
+ - 若输入格式wav文件url,api调用方式可参考如下范例:
92
+
93
+ ```python
94
+ from modelscope.pipelines import pipeline
95
+ from modelscope.utils.constant import Tasks
96
+
97
+ inference_pipeline = pipeline(
98
+ task=Tasks.voice_activity_detection,
99
+ model='iic/speech_fsmn_vad_zh-cn-16k-common-pytorch',
100
+ model_revision="v2.0.4",
101
+ )
102
+
103
+ segments_result = inference_pipeline(input='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
104
+ print(segments_result)
105
+ ```
106
+
107
+ - 输入音频为pcm格式,调用api时需要传入音频采样率参数fs,例如:
108
+
109
+ ```python
110
+ segments_result = inference_pipeline(input='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.pcm', fs=16000)
111
+ ```
112
+
113
+ - 若输入格式为文件wav.scp(注:文件名需要以.scp结尾),可添加 output_dir 参数将识别结果写入文件中,参考示例如下:
114
+
115
+ ```python
116
+ inference_pipeline(input="wav.scp", output_dir='./output_dir')
117
+ ```
118
+ 识别结果输出路径结构如下:
119
+
120
+ ```sh
121
+ tree output_dir/
122
+ output_dir/
123
+ └── 1best_recog
124
+ └── text
125
+
126
+ 1 directory, 1 files
127
+ ```
128
+ text:VAD检测语音起止时间点结果文件(单位:ms)
129
+
130
+ - 若输入音频为已解析的audio音频,api调用方式可参考如下范例:
131
+
132
+ ```python
133
+ import soundfile
134
+
135
+ waveform, sample_rate = soundfile.read("vad_example_zh.wav")
136
+ segments_result = inference_pipeline(input=waveform)
137
+ print(segments_result)
138
+ ```
139
+
140
+ - VAD常用参数调整说明(参考:vad.yaml文件):
141
+ - max_end_silence_time:尾部连续检测到多长时间静音进行尾点判停,参数范围500ms~6000ms,默认值800ms(该值过低容易出现语音提前截断的情况)。
142
+ - speech_noise_thres:speech的得分减去noise的得分大于此值则判断为speech,参数范围:(-1,1)
143
+ - 取值越趋于-1,噪音被误判定为语音的概率越大,FA越高
144
+ - 取值越趋于+1,语音被误判定为噪音的概率越大,Pmiss越高
145
+ - 通常情况下,该值会根据当前模型在长语音测试集上的效果取balance
146
+
147
+
148
+
149
+
150
+ ## 基于FunASR进行推理
151
+
152
+ 下面为快速上手教程,测试音频([中文](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav),[英文](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav))
153
+
154
+ ### 可执行命令行
155
+ 在命令行终端执行:
156
+
157
+ ```shell
158
+ funasr ++model=paraformer-zh ++vad_model="fsmn-vad" ++punc_model="ct-punc" ++input=vad_example.wav
159
+ ```
160
+
161
+ 注:支持单条音频文件识别,也支持文件列表,列表为kaldi风格wav.scp:`wav_id wav_path`
162
+
163
+ ### python示例
164
+ #### 非实时语音识别
165
+ ```python
166
+ from funasr import AutoModel
167
+ # paraformer-zh is a multi-functional asr model
168
+ # use vad, punc, spk or not as you need
169
+ model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
170
+ vad_model="fsmn-vad", vad_model_revision="v2.0.4",
171
+ punc_model="ct-punc-c", punc_model_revision="v2.0.4",
172
+ # spk_model="cam++", spk_model_revision="v2.0.2",
173
+ )
174
+ res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
175
+ batch_size_s=300,
176
+ hotword='魔搭')
177
+ print(res)
178
+ ```
179
+ 注:`model_hub`:表示模型仓库,`ms`为选择modelscope下载,`hf`为选择huggingface下载。
180
+
181
+ #### 实时语音识别
182
+
183
+ ```python
184
+ from funasr import AutoModel
185
+
186
+ chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
187
+ encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
188
+ decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
189
+
190
+ model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.4")
191
+
192
+ import soundfile
193
+ import os
194
+
195
+ wav_file = os.path.join(model.model_path, "example/asr_example.wav")
196
+ speech, sample_rate = soundfile.read(wav_file)
197
+ chunk_stride = chunk_size[1] * 960 # 600ms
198
+
199
+ cache = {}
200
+ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
201
+ for i in range(total_chunk_num):
202
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
203
+ is_final = i == total_chunk_num - 1
204
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
205
+ print(res)
206
+ ```
207
+
208
+ 注:`chunk_size`为流式延时配置,`[0,10,5]`表示上屏实时出字粒度为`10*60=600ms`,未来信息为`5*60=300ms`。每次推理输入为`600ms`(采样点数为`16000*0.6=960`),输出为对应文字,最后一个语音片段输入需要设置`is_final=True`来强制输出最后一个字。
209
+
210
+ #### 语音端点检测(非实时)
211
+ ```python
212
+ from funasr import AutoModel
213
+
214
+ model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
215
+
216
+ wav_file = f"{model.model_path}/example/asr_example.wav"
217
+ res = model.generate(input=wav_file)
218
+ print(res)
219
+ ```
220
+
221
+ #### 语音端点检测(实时)
222
+ ```python
223
+ from funasr import AutoModel
224
+
225
+ chunk_size = 200 # ms
226
+ model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
227
+
228
+ import soundfile
229
+
230
+ wav_file = f"{model.model_path}/example/vad_example.wav"
231
+ speech, sample_rate = soundfile.read(wav_file)
232
+ chunk_stride = int(chunk_size * sample_rate / 1000)
233
+
234
+ cache = {}
235
+ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
236
+ for i in range(total_chunk_num):
237
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
238
+ is_final = i == total_chunk_num - 1
239
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
240
+ if len(res[0]["value"]):
241
+ print(res)
242
+ ```
243
+
244
+ #### 标点恢复
245
+ ```python
246
+ from funasr import AutoModel
247
+
248
+ model = AutoModel(model="ct-punc", model_revision="v2.0.4")
249
+
250
+ res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
251
+ print(res)
252
+ ```
253
+
254
+ #### 时间戳预测
255
+ ```python
256
+ from funasr import AutoModel
257
+
258
+ model = AutoModel(model="fa-zh", model_revision="v2.0.4")
259
+
260
+ wav_file = f"{model.model_path}/example/asr_example.wav"
261
+ text_file = f"{model.model_path}/example/text.txt"
262
+ res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
263
+ print(res)
264
+ ```
265
+
266
+ 更多详细用法([示例](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
267
+
268
+
269
+ ## 微调
270
+
271
+ 详细用法([示例](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
272
+
273
+
274
+
275
+
276
+
277
+ ## 使用方式以及适用范围
278
+
279
+ 运行范围
280
+ - 支持Linux-x86_64、Mac和Windows运行。
281
+
282
+ 使用方式
283
+ - 直接推理:可以直接对长语音数据进行计算,输出有效语音片段的起止时间点信息(单位:ms)。
284
+
285
+ ## 相关论文以及引用信息
286
+
287
+ ```BibTeX
288
+ @inproceedings{zhang2018deep,
289
+ title={Deep-FSMN for large vocabulary continuous speech recognition},
290
+ author={Zhang, Shiliang and Lei, Ming and Yan, Zhijie and Dai, Lirong},
291
+ booktitle={2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
292
+ pages={5869--5873},
293
+ year={2018},
294
+ organization={IEEE}
295
+ }
296
+ ```
FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/am.mvn ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <Nnet>
2
+ <Splice> 400 400
3
+ [ 0 ]
4
+ <AddShift> 400 400
5
+ <LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 
-13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
6
+ <Rescale> 400 400
7
+ <LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
8
+ </Nnet>
FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/config.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ frontend: WavFrontendOnline
2
+ frontend_conf:
3
+ fs: 16000
4
+ window: hamming
5
+ n_mels: 80
6
+ frame_length: 25
7
+ frame_shift: 10
8
+ dither: 0.0
9
+ lfr_m: 5
10
+ lfr_n: 1
11
+
12
+ model: FsmnVADStreaming
13
+ model_conf:
14
+ sample_rate: 16000
15
+ detect_mode: 1
16
+ snr_mode: 0
17
+ max_end_silence_time: 800
18
+ max_start_silence_time: 3000
19
+ do_start_point_detection: True
20
+ do_end_point_detection: True
21
+ window_size_ms: 200
22
+ sil_to_speech_time_thres: 150
23
+ speech_to_sil_time_thres: 150
24
+ speech_2_noise_ratio: 1.0
25
+ do_extend: 1
26
+ lookback_time_start_point: 200
27
+ lookahead_time_end_point: 100
28
+ max_single_segment_time: 60000
29
+ snr_thres: -100.0
30
+ noise_frame_num_used_for_snr: 100
31
+ decibel_thres: -100.0
32
+ speech_noise_thres: 0.6
33
+ fe_prior_thres: 0.0001
34
+ silence_pdf_num: 1
35
+ sil_pdf_ids: [0]
36
+ speech_noise_thresh_low: -0.1
37
+ speech_noise_thresh_high: 0.3
38
+ output_frame_probs: False
39
+ frame_in_ms: 10
40
+ frame_length_ms: 25
41
+
42
+ encoder: FSMN
43
+ encoder_conf:
44
+ input_dim: 400
45
+ input_affine_dim: 140
46
+ fsmn_layers: 4
47
+ linear_dim: 250
48
+ proj_dim: 128
49
+ lorder: 20
50
+ rorder: 0
51
+ lstride: 1
52
+ rstride: 0
53
+ output_affine_dim: 140
54
+ output_dim: 248
55
+
56
+
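
The frontend above stacks `lfr_m: 5` consecutive 80-dim fbank frames before they reach the FSMN encoder, which is where `encoder_conf.input_dim: 400` comes from. A quick sanity check of that reading (my interpretation of the config, not an official derivation):

```python
# Sanity check of the dimensions in the config above
# (assumption: LFR stacks lfr_m frames of n_mels features into one vector).
n_mels = 80      # frontend_conf.n_mels
lfr_m = 5        # frontend_conf.lfr_m
input_dim = 400  # encoder_conf.input_dim
assert n_mels * lfr_m == input_dim
```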
FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/configuration.json ADDED
@@ -0,0 +1,13 @@
+ {
+ "framework": "pytorch",
+ "task" : "voice-activity-detection",
+ "pipeline": {"type":"funasr-pipeline"},
+ "model": {"type" : "funasr"},
+ "file_path_metas": {
+ "init_param":"model.pt",
+ "config":"config.yaml",
+ "frontend_conf":{"cmvn_file": "am.mvn"}},
+ "model_name_in_hub": {
+ "ms":"iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ "hf":""}
+ }
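
This configuration.json is what lets ModelScope/FunASR pipelines consume the folder directly: it points `init_param` at model.pt, `config` at config.yaml and the CMVN stats at am.mvn. A minimal sketch of loading the folder as a local model directory, assuming the repository layout of this commit and that `funasr` is installed:

```python
# A minimal sketch, assuming this repo is checked out locally and funasr is installed.
from funasr import AutoModel

vad = AutoModel(model="FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch")
segments = vad.generate(input="vad_example.wav")   # input path is a placeholder
print(segments)  # e.g. [{'key': ..., 'value': [[start_ms, end_ms], ...]}]
```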
FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/fig/struct.png ADDED
FunASR/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b3be75be477f0780277f3bae0fe489f48718f585f3a6e45d7dd1fbb1a4255fc5
+ size 1721366
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/.mdl ADDED
Binary file (99 Bytes).
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/.msc ADDED
Binary file (838 Bytes).
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md ADDED
@@ -0,0 +1,357 @@
1
+ ---
2
+ tasks:
3
+ - auto-speech-recognition
4
+ domain:
5
+ - audio
6
+ model-type:
7
+ - Non-autoregressive
8
+ frameworks:
9
+ - pytorch
10
+ backbone:
11
+ - transformer/conformer
12
+ metrics:
13
+ - CER
14
+ license: Apache License 2.0
15
+ language:
16
+ - cn
17
+ tags:
18
+ - FunASR
19
+ - Paraformer
20
+ - Alibaba
21
+ - ICASSP2024
22
+ - Hotword
23
+ datasets:
24
+ train:
25
+ - 50,000 hour industrial Mandarin task
26
+ test:
27
+ - AISHELL-1-hotword dev/test
28
+ indexing:
29
+ results:
30
+ - task:
31
+ name: Automatic Speech Recognition
32
+ dataset:
33
+ name: 50,000 hour industrial Mandarin task
34
+ type: audio # optional
35
+ args: 16k sampling rate, 8404 characters # optional
36
+ metrics:
37
+ - type: CER
38
+ value: 8.53% # float
39
+ description: greedy search, without lm, avg.
40
+ args: default
41
+ - type: RTF
42
+ value: 0.0251 # float
43
+ description: GPU inference on V100
44
+ args: batch_size=1
45
+ widgets:
46
+ - task: auto-speech-recognition
47
+ inputs:
48
+ - type: audio
49
+ name: input
50
+ title: 音频
51
+ parameters:
52
+ - name: hotword
53
+ title: 热词
54
+ type: string
55
+ examples:
56
+ - name: 1
57
+ title: 示例1
58
+ inputs:
59
+ - name: input
60
+ data: git://example/asr_example.wav
61
+ parameters:
62
+ - name: hotword
63
+ value: 魔搭
64
+ model_revision: v2.0.4
65
+ inferencespec:
66
+ cpu: 8 #CPU数量
67
+ memory: 4096
68
+ ---
69
+
70
+ # Paraformer-large模型介绍
71
+
72
+ ## Highlights
73
+ Paraformer-large热词版模型支持热词定制功能:实现热词定制化功能,基于提供的热词列表进行激励增强,提升热词的召回率和准确率。
74
+
75
+
76
+ ## <strong>[FunASR开源项目介绍](https://github.com/alibaba-damo-academy/FunASR)</strong>
77
+ <strong>[FunASR](https://github.com/alibaba-damo-academy/FunASR)</strong>希望在语音识别的学术研究和工业应用之间架起一座桥梁。通过发布工业级语音识别模型的训练和微调,研究人员和开发人员可以更方便地进行语音识别模型的研究和生产,并推动语音识别生态的发展。让语音识别更有趣!
78
+
79
+ [**github仓库**](https://github.com/alibaba-damo-academy/FunASR)
80
+ | [**最新动态**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
81
+ | [**环境安装**](https://github.com/alibaba-damo-academy/FunASR#installation)
82
+ | [**服务部署**](https://www.funasr.com)
83
+ | [**模型库**](https://github.com/alibaba-damo-academy/FunASR/tree/main/model_zoo)
84
+ | [**联系我们**](https://github.com/alibaba-damo-academy/FunASR#contact)
85
+
86
+
87
+ ## 模型原理介绍
88
+
89
+ SeACoParaformer是阿里巴巴语音实验室提出的新一代热词定制化非自回归语音识别模型。相比于上一代基于CLAS的热词定制化方案,SeACoParaformer解耦了热词模块与ASR模型,通过后验概率融合的方式进行热词激励,使激励过程可见可控,并且热词召回率显著提升。
90
+
91
+ <p align="center">
92
+ <img src="fig/seaco.png" alt="SeACoParaformer模型结构" width="380" />
93
+
94
+
95
+ SeACoParaformer的模型结构与训练流程如上图所示,通过引入bias encoder进行热词embedding提取,bias decoder进行注意力建模,SeACoParaformer能够捕捉到Predictor输出和Decoder输出的信息与热词的相关性,并且预测与ASR结果同步的热词输出。通过后验概率的融合,实现热词激励。与ContextualParaformer相比,SeACoParaformer有明显的效果提升,如下图所示:
96
+
97
+ <p align="center">
98
+ <img src="fig/res.png" alt="SeACoParaformer模型结构" width="700" />
99
+
100
+ 更详细的细节见:
101
+ - 论文: [SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability](https://arxiv.org/abs/2308.03266)
102
+
103
+ ## 复现论文中的结果
104
+ ```python
105
+ from funasr import AutoModel
106
+
107
+ model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
108
+ model_revision="v2.0.4",
109
+ # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
110
+ # vad_model_revision="v2.0.4",
111
+ # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
112
+ # punc_model_revision="v2.0.4",
113
+ # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
114
+ # spk_model_revision="v2.0.2",
115
+ device="cuda:0"
116
+ )
117
+
118
+ res = model.generate(input="YOUR_PATH/aishell1_hotword_dev.scp",
119
+ hotword='./data/dev/hotword.txt',
120
+ batch_size_s=300,
121
+ )
122
+ fout1 = open("dev.output", 'w')
123
+ for resi in res:
124
+ fout1.write("{}\t{}\n".format(resi['key'], resi['text']))
125
+
126
+ res = model.generate(input="YOUR_PATH/aishell1_hotword_test.scp",
127
+ hotword='./data/test/hotword.txt',
128
+ batch_size_s=300,
129
+ )
130
+ fout2 = open("test.output", 'w')
131
+ for resi in res:
132
+ fout2.write("{}\t{}\n".format(resi['key'], resi['text']))
133
+ ```
134
+
135
+ ## 基于ModelScope进行推理
136
+
137
+ - 推理支持的音频格式如下:
138
+ - wav文件路径,例如:data/test/audios/asr_example.wav
139
+ - pcm文件路径,例如:data/test/audios/asr_example.pcm
140
+ - wav文件url,例如:https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav
141
+ - wav二进制数据,格式bytes,例如:用户直接从文件里读出bytes数据或者是麦克风录出bytes数据。
142
+ - 已解析的audio音频,例如:audio, rate = soundfile.read("asr_example_zh.wav"),类型为numpy.ndarray或者torch.Tensor。
143
+ - wav.scp文件,需符合如下要求:
144
+
145
+ ```sh
146
+ cat wav.scp
147
+ asr_example1 data/test/audios/asr_example1.wav
148
+ asr_example2 data/test/audios/asr_example2.wav
149
+ ...
150
+ ```
151
+
152
+ - 若输入格式wav文件url,api调用方式可参考如下范例:
153
+
154
+ ```python
155
+ from modelscope.pipelines import pipeline
156
+ from modelscope.utils.constant import Tasks
157
+
158
+ inference_pipeline = pipeline(
159
+ task=Tasks.auto_speech_recognition,
160
+ model='iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', model_revision="v2.0.4")
161
+
162
+ rec_result = inference_pipeline('https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav', hotword='达摩院 魔搭')
163
+ print(rec_result)
164
+ ```
165
+
166
+ - 输入音频为pcm格式,调用api时需要传入音频采样率参数audio_fs,例如:
167
+
168
+ ```python
169
+ rec_result = inference_pipeline('https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.pcm', fs=16000, hotword='达摩院 魔搭')
170
+ ```
171
+
172
+ - 输入音频为wav格式,api调用方式可参考如下范例:
173
+
174
+ ```python
175
+ rec_result = inference_pipeline('asr_example_zh.wav', hotword='达摩院 魔搭')
176
+ ```
177
+
178
+ - 若输入格式为文件wav.scp(注:文件名需要以.scp结尾),可添加 output_dir 参数将识别结果写入文件中,api调用方式可参考如下范例:
179
+
180
+ ```python
181
+ inference_pipeline("wav.scp", output_dir='./output_dir', hotword='达摩院 魔搭')
182
+ ```
183
+ 识别结果输出路径结构如下:
184
+
185
+ ```sh
186
+ tree output_dir/
187
+ output_dir/
188
+ └── 1best_recog
189
+ ├── score
190
+ └── text
191
+
192
+ 1 directory, 3 files
193
+ ```
194
+
195
+ score:识别路径得分
196
+
197
+ text:语音识别结果文件
198
+
199
+
200
+ - 若输入音频为已解析的audio音频,api调用方式可参考如下范例:
201
+
202
+ ```python
203
+ import soundfile
204
+
205
+ waveform, sample_rate = soundfile.read("asr_example_zh.wav")
206
+ rec_result = inference_pipeline(waveform, hotword='达摩院 魔搭')
207
+ ```
208
+
209
+ - ASR、VAD、PUNC模型自由组合
210
+
211
+ 可根据使用需求对VAD和PUNC标点模型进行自由组合,使用方式如下:
212
+ ```python
213
+ inference_pipeline = pipeline(
214
+ task=Tasks.auto_speech_recognition,
215
+ model='iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', model_revision="v2.0.4",
216
+ vad_model='iic/speech_fsmn_vad_zh-cn-16k-common-pytorch', vad_model_revision="v2.0.4",
217
+ punc_model='iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch', punc_model_revision="v2.0.3",
218
+ # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
219
+ # spk_model_revision="v2.0.2",
220
+ )
221
+ ```
222
+ 若不使用PUNC模型,可配置punc_model=None,或不传入punc_model参数,如需加入LM模型,可增加配置lm_model='iic/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',并设置lm_weight和beam_size参数。
223
+
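
A hedged sketch of the combination described above: punctuation restoration disabled and an external LM added. The `lm_weight` and `beam_size` values are illustrative assumptions, not recommended settings.

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', model_revision="v2.0.4",
    punc_model=None,  # skip punctuation restoration
    lm_model='iic/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
    lm_weight=0.15,   # illustrative value
    beam_size=10,     # illustrative value
)
```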
224
+ ## 基于FunASR进行推理
225
+
226
+ 下面为快速上手教程,测试音频([中文](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav),[英文](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav))
227
+
228
+ ### 可执行命令行
229
+ 在命令行终端执行:
230
+
231
+ ```shell
232
+ funasr +model=paraformer-zh +vad_model="fsmn-vad" +punc_model="ct-punc" +input=vad_example.wav
233
+ ```
234
+
235
+ 注:支持单条音频文件识别,也支持文件列表,列表为kaldi风格wav.scp:`wav_id wav_path`
236
+
237
+ ### python示例
238
+ #### 非实时语音识别
239
+ ```python
240
+ from funasr import AutoModel
241
+ # paraformer-zh is a multi-functional asr model
242
+ # use vad, punc, spk or not as you need
243
+ model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
244
+ vad_model="fsmn-vad", vad_model_revision="v2.0.4",
245
+ punc_model="ct-punc-c", punc_model_revision="v2.0.4",
246
+ # spk_model="cam++", spk_model_revision="v2.0.2",
247
+ )
248
+ res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
249
+ batch_size_s=300,
250
+ hotword='魔搭')
251
+ print(res)
252
+ ```
253
+ 注:`model_hub`:表示模型仓库,`ms`为选择modelscope下载,`hf`为选择huggingface下载。
254
+
255
+ #### 实时语音识别
256
+
257
+ ```python
258
+ from funasr import AutoModel
259
+
260
+ chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
261
+ encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
262
+ decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
263
+
264
+ model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.4")
265
+
266
+ import soundfile
267
+ import os
268
+
269
+ wav_file = os.path.join(model.model_path, "example/asr_example.wav")
270
+ speech, sample_rate = soundfile.read(wav_file)
271
+ chunk_stride = chunk_size[1] * 960 # 600ms
272
+
273
+ cache = {}
274
+ total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
275
+ for i in range(total_chunk_num):
276
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
277
+ is_final = i == total_chunk_num - 1
278
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
279
+ print(res)
280
+ ```
281
+
282
+ 注:`chunk_size`为流式延时配置,`[0,10,5]`表示上屏实时出字粒度为`10*60=600ms`,未来信息为`5*60=300ms`。每次推理输入为`600ms`(采样点数为`16000*0.6=9600`),输出为对应文字,最后一个语音片段输入需要设置`is_final=True`来强制输出最后一个字。
283
+
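
A worked check of the arithmetic in the note above (assuming 16 kHz input and 60 ms per `chunk_size` unit):

```python
fs = 16000
chunk_size = [0, 10, 5]
unit_ms = 60
new_audio_ms = chunk_size[1] * unit_ms     # 600 ms of new audio per generate() call
lookahead_ms = chunk_size[2] * unit_ms     # 300 ms of future context
chunk_stride = fs * new_audio_ms // 1000   # 9600 samples, i.e. chunk_size[1] * 960
print(new_audio_ms, lookahead_ms, chunk_stride)
```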
284
+ #### 语音端点检测(非实时)
285
+ ```python
286
+ from funasr import AutoModel
287
+
288
+ model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
289
+
290
+ wav_file = f"{model.model_path}/example/asr_example.wav"
291
+ res = model.generate(input=wav_file)
292
+ print(res)
293
+ ```
294
+
295
+ #### 语音端点检测(实时)
296
+ ```python
297
+ from funasr import AutoModel
298
+
299
+ chunk_size = 200 # ms
300
+ model = AutoModel(model="fsmn-vad", model_revision="v2.0.4")
301
+
302
+ import soundfile
303
+
304
+ wav_file = f"{model.model_path}/example/vad_example.wav"
305
+ speech, sample_rate = soundfile.read(wav_file)
306
+ chunk_stride = int(chunk_size * sample_rate / 1000)
307
+
308
+ cache = {}
309
+ total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
310
+ for i in range(total_chunk_num):
311
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
312
+ is_final = i == total_chunk_num - 1
313
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
314
+ if len(res[0]["value"]):
315
+ print(res)
316
+ ```
317
+
318
+ #### 标点恢复
319
+ ```python
320
+ from funasr import AutoModel
321
+
322
+ model = AutoModel(model="ct-punc", model_revision="v2.0.4")
323
+
324
+ res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
325
+ print(res)
326
+ ```
327
+
328
+ #### 时间戳预测
329
+ ```python
330
+ from funasr import AutoModel
331
+
332
+ model = AutoModel(model="fa-zh", model_revision="v2.0.4")
333
+
334
+ wav_file = f"{model.model_path}/example/asr_example.wav"
335
+ text_file = f"{model.model_path}/example/text.txt"
336
+ res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
337
+ print(res)
338
+ ```
339
+
340
+ 更多详细用法([示例](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
341
+
342
+
343
+ ## 微调
344
+
345
+ 详细用法([示例](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
346
+
347
+
348
+ ## 相关论文以及引用信息
349
+
350
+ ```BibTeX
351
+ @article{shi2023seaco,
352
+ title={SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability},
353
+ author={Shi, Xian and Yang, Yexin and Li, Zerui and Zhang, Shiliang},
354
+ journal={arXiv preprint arXiv:2308.03266 (accepted by ICASSP2024)},
355
+ year={2023}
356
+ }
357
+ ```
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn ADDED
@@ -0,0 +1,8 @@
1
+ <Nnet>
2
+ <Splice> 560 560
3
+ [ 0 ]
4
+ <AddShift> 560 560
5
+ <LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 
-13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
6
+ <Rescale> 560 560
7
+ <LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
8
+ </Nnet>
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml ADDED
@@ -0,0 +1,159 @@
1
+ # This is an example that demonstrates how to configure a model file.
2
+ # You can modify the configuration according to your own requirements.
3
+
4
+ # to print the register_table:
5
+ # from funasr.utils.register import registry_tables
6
+ # registry_tables.print()
7
+
8
+ # network architecture
9
+ model: SeacoParaformer
10
+ model_conf:
11
+ ctc_weight: 0.0
12
+ lsm_weight: 0.1
13
+ length_normalized_loss: true
14
+ predictor_weight: 1.0
15
+ predictor_bias: 1
16
+ sampling_ratio: 0.75
17
+ inner_dim: 512
18
+ bias_encoder_type: lstm
19
+ bias_encoder_bid: false
20
+ seaco_lsm_weight: 0.1
21
+ seaco_length_normal: true
22
+ train_decoder: true
23
+ NO_BIAS: 8377
24
+
25
+ # encoder
26
+ encoder: SANMEncoder
27
+ encoder_conf:
28
+ output_size: 512
29
+ attention_heads: 4
30
+ linear_units: 2048
31
+ num_blocks: 50
32
+ dropout_rate: 0.1
33
+ positional_dropout_rate: 0.1
34
+ attention_dropout_rate: 0.1
35
+ input_layer: pe
36
+ pos_enc_class: SinusoidalPositionEncoder
37
+ normalize_before: true
38
+ kernel_size: 11
39
+ sanm_shfit: 0
40
+ selfattention_layer_type: sanm
41
+
42
+ # decoder
43
+ decoder: ParaformerSANMDecoder
44
+ decoder_conf:
45
+ attention_heads: 4
46
+ linear_units: 2048
47
+ num_blocks: 16
48
+ dropout_rate: 0.1
49
+ positional_dropout_rate: 0.1
50
+ self_attention_dropout_rate: 0.1
51
+ src_attention_dropout_rate: 0.1
52
+ att_layer_num: 16
53
+ kernel_size: 11
54
+ sanm_shfit: 0
55
+
56
+ # seaco decoder
57
+ seaco_decoder: ParaformerSANMDecoder
58
+ seaco_decoder_conf:
59
+ attention_heads: 4
60
+ linear_units: 1024
61
+ num_blocks: 4
62
+ dropout_rate: 0.1
63
+ positional_dropout_rate: 0.1
64
+ self_attention_dropout_rate: 0.1
65
+ src_attention_dropout_rate: 0.1
66
+ kernel_size: 21
67
+ sanm_shfit: 0
68
+ use_output_layer: false
69
+ wo_input_layer: true
70
+
71
+ predictor: CifPredictorV3
72
+ predictor_conf:
73
+ idim: 512
74
+ threshold: 1.0
75
+ l_order: 1
76
+ r_order: 1
77
+ tail_threshold: 0.45
78
+ smooth_factor2: 0.25
79
+ noise_threshold2: 0.01
80
+ upsample_times: 3
81
+ use_cif1_cnn: false
82
+ upsample_type: cnn_blstm
83
+
84
+ # frontend related
85
+ frontend: WavFrontend
86
+ frontend_conf:
87
+ fs: 16000
88
+ window: hamming
89
+ n_mels: 80
90
+ frame_length: 25
91
+ frame_shift: 10
92
+ lfr_m: 7
93
+ lfr_n: 6
94
+ dither: 0.0
95
+
96
+ specaug: SpecAugLFR
97
+ specaug_conf:
98
+ apply_time_warp: false
99
+ time_warp_window: 5
100
+ time_warp_mode: bicubic
101
+ apply_freq_mask: true
102
+ freq_mask_width_range:
103
+ - 0
104
+ - 30
105
+ lfr_rate: 6
106
+ num_freq_mask: 1
107
+ apply_time_mask: true
108
+ time_mask_width_range:
109
+ - 0
110
+ - 12
111
+ num_time_mask: 1
112
+
113
+ train_conf:
114
+ accum_grad: 1
115
+ grad_clip: 5
116
+ max_epoch: 150
117
+ val_scheduler_criterion:
118
+ - valid
119
+ - acc
120
+ best_model_criterion:
121
+ - - valid
122
+ - acc
123
+ - max
124
+ keep_nbest_models: 10
125
+ log_interval: 50
126
+
127
+ optim: adam
128
+ optim_conf:
129
+ lr: 0.0005
130
+ scheduler: warmuplr
131
+ scheduler_conf:
132
+ warmup_steps: 30000
133
+
134
+ dataset: AudioDatasetHotword
135
+ dataset_conf:
136
+ seaco_id: 8377
137
+ index_ds: IndexDSJsonl
138
+ batch_sampler: DynamicBatchLocalShuffleSampler
139
+ batch_type: example # example or length
140
+ batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
141
+ max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
142
+ buffer_size: 500
143
+ shuffle: True
144
+ num_workers: 0
145
+
146
+ tokenizer: CharTokenizer
147
+ tokenizer_conf:
148
+ unk_symbol: <unk>
149
+ split_with_space: true
150
+
151
+
152
+ ctc_conf:
153
+ dropout_rate: 0.0
154
+ ctc_type: builtin
155
+ reduce: true
156
+ ignore_nan_grad: true
157
+
158
+ normalize: null
159
+ unused_parameters: true
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/configuration.json ADDED
@@ -0,0 +1,14 @@
+ {
+ "framework": "pytorch",
+ "task" : "auto-speech-recognition",
+ "model": {"type" : "funasr"},
+ "pipeline": {"type":"funasr-pipeline"},
+ "model_name_in_hub": {
+ "ms":"iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ "hf":""},
+ "file_path_metas": {
+ "init_param":"model.pt",
+ "config":"config.yaml",
+ "tokenizer_conf": {"token_list": "tokens.json", "seg_dict_file": "seg_dict"},
+ "frontend_conf":{"cmvn_file": "am.mvn"}}
+ }
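
As with the VAD folder, this configuration.json wires model.pt, config.yaml, tokens.json, seg_dict and am.mvn together so the directory can be used as a local FunASR model. A hedged sketch (the relative path is an assumption from this commit's layout; the hotword comes from example/hotword.txt):

```python
from funasr import AutoModel

asr = AutoModel(model="FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
res = asr.generate(input="asr_example_zh.wav", hotword="魔搭")  # input path is a placeholder
print(res[0]["text"])
```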
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/hotword.txt ADDED
@@ -0,0 +1 @@
+ 魔搭
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/fig/res.png ADDED
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/fig/seaco.png ADDED
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d491689244ec5dfbf9170ef3827c358aa10f1f20e42a7c59e15e688647946d1
+ size 989763045
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/seg_dict ADDED
The diff for this file is too large to render.
FunASR/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.json ADDED
The diff for this file is too large to render.
GPT_SoVITS/AR/__init__.py ADDED
File without changes
GPT_SoVITS/AR/data/__init__.py ADDED
File without changes
GPT_SoVITS/AR/data/bucket_sampler.py ADDED
@@ -0,0 +1,162 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
2
+ import itertools
3
+ import math
4
+ import random
5
+ from random import shuffle
6
+ from typing import Iterator
7
+ from typing import Optional
8
+ from typing import TypeVar
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch.utils.data import Dataset
13
+ from torch.utils.data import Sampler
14
+
15
+ __all__ = [
16
+ "DistributedBucketSampler",
17
+ ]
18
+
19
+ T_co = TypeVar("T_co", covariant=True)
20
+
21
+
22
+ class DistributedBucketSampler(Sampler[T_co]):
23
+ r"""
24
+ sort the dataset wrt. input length
25
+ divide samples into buckets
26
+ sort within buckets
27
+ divide buckets into batches
28
+ sort batches
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ dataset: Dataset,
34
+ num_replicas: Optional[int] = None,
35
+ rank: Optional[int] = None,
36
+ shuffle: bool = True,
37
+ seed: int = 0,
38
+ drop_last: bool = False,
39
+ batch_size: int = 32,
40
+ ) -> None:
41
+ if num_replicas is None:
42
+ if not dist.is_available():
43
+ raise RuntimeError("Requires distributed package to be available")
44
+ num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
45
+ if rank is None:
46
+ if not dist.is_available():
47
+ raise RuntimeError("Requires distributed package to be available")
48
+ rank = dist.get_rank() if torch.cuda.is_available() else 0
49
+ if torch.cuda.is_available():
50
+ torch.cuda.set_device(rank)
51
+ if rank >= num_replicas or rank < 0:
52
+ raise ValueError(
53
+ "Invalid rank {}, rank should be in the interval"
54
+ " [0, {}]".format(rank, num_replicas - 1)
55
+ )
56
+ self.dataset = dataset
57
+ self.num_replicas = num_replicas
58
+ self.rank = rank
59
+ self.epoch = 0
60
+ self.drop_last = drop_last
61
+ # If the dataset length is evenly divisible by # of replicas, then there
62
+ # is no need to drop any data, since the dataset will be split equally.
63
+ if (
64
+ self.drop_last and len(self.dataset) % self.num_replicas != 0
65
+ ): # type: ignore[arg-type]
66
+ # Split to nearest available length that is evenly divisible.
67
+ # This is to ensure each rank receives the same amount of data when
68
+ # using this Sampler.
69
+ self.num_samples = math.ceil(
70
+ (len(self.dataset) - self.num_replicas)
71
+ / self.num_replicas # type: ignore[arg-type]
72
+ )
73
+ else:
74
+ self.num_samples = math.ceil(
75
+ len(self.dataset) / self.num_replicas
76
+ ) # type: ignore[arg-type]
77
+ self.total_size = self.num_samples * self.num_replicas
78
+ self.shuffle = shuffle
79
+ self.seed = seed
80
+ self.batch_size = batch_size
81
+ self.id_with_length = self._get_sample_lengths()
82
+ self.id_buckets = self.make_buckets(bucket_width=2.0)
83
+
84
+ def _get_sample_lengths(self):
85
+ id_with_lengths = []
86
+ for i in range(len(self.dataset)):
87
+ id_with_lengths.append((i, self.dataset.get_sample_length(i)))
88
+ id_with_lengths.sort(key=lambda x: x[1])
89
+ return id_with_lengths
90
+
91
+ def make_buckets(self, bucket_width: float = 2.0):
92
+ buckets = []
93
+ cur = []
94
+ max_sec = bucket_width
95
+ for id, sec in self.id_with_length:
96
+ if sec < max_sec:
97
+ cur.append(id)
98
+ else:
99
+ buckets.append(cur)
100
+ cur = [id]
101
+ max_sec += bucket_width
102
+ if len(cur) > 0:
103
+ buckets.append(cur)
104
+ return buckets
105
+
106
+ def __iter__(self) -> Iterator[T_co]:
107
+ if self.shuffle:
108
+ # deterministically shuffle based on epoch and seed
109
+ g = torch.Generator()
110
+ g.manual_seed(self.seed + self.epoch)
111
+ random.seed(self.epoch + self.seed)
112
+ shuffled_bucket = []
113
+ for buc in self.id_buckets:
114
+ buc_copy = buc.copy()
115
+ shuffle(buc_copy)
116
+ shuffled_bucket.append(buc_copy)
117
+ grouped_batch_size = self.batch_size * self.num_replicas
118
+ shuffled_bucket = list(itertools.chain(*shuffled_bucket))
119
+ n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
120
+ batches = [
121
+ shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
122
+ for b in range(n_batch)
123
+ ]
124
+ shuffle(batches)
125
+ indices = list(itertools.chain(*batches))
126
+ else:
127
+ # type: ignore[arg-type]
128
+ indices = list(range(len(self.dataset)))
129
+
130
+ if not self.drop_last:
131
+ # add extra samples to make it evenly divisible
132
+ padding_size = self.total_size - len(indices)
133
+ if padding_size <= len(indices):
134
+ indices += indices[:padding_size]
135
+ else:
136
+ indices += (indices * math.ceil(padding_size / len(indices)))[
137
+ :padding_size
138
+ ]
139
+ else:
140
+ # remove tail of data to make it evenly divisible.
141
+ indices = indices[: self.total_size]
142
+ assert len(indices) == self.total_size
143
+
144
+ # subsample
145
+ indices = indices[self.rank : self.total_size : self.num_replicas]
146
+ assert len(indices) == self.num_samples
147
+
148
+ return iter(indices)
149
+
150
+ def __len__(self) -> int:
151
+ return self.num_samples
152
+
153
+ def set_epoch(self, epoch: int) -> None:
154
+ r"""
155
+ Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
156
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
157
+ sampler will yield the same ordering.
158
+
159
+ Args:
160
+ epoch (int): Epoch number.
161
+ """
162
+ self.epoch = epoch
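
A toy, single-process usage sketch of the sampler above. `ToyDataset` is an illustrative stand-in (the real Text2SemanticDataset exposes the same `get_sample_length` interface); `num_replicas=1` and `rank=0` are passed explicitly so no torch.distributed setup is needed.

```python
from torch.utils.data import Dataset
# from AR.data.bucket_sampler import DistributedBucketSampler

class ToyDataset(Dataset):
    def __init__(self, durations):
        self.durations = durations            # seconds per sample
    def __len__(self):
        return len(self.durations)
    def __getitem__(self, idx):
        return idx
    def get_sample_length(self, idx):         # interface the sampler relies on
        return self.durations[idx]

sampler = DistributedBucketSampler(
    ToyDataset([1.2, 0.8, 3.5, 2.1]), num_replicas=1, rank=0, batch_size=2
)
sampler.set_epoch(0)
print(list(iter(sampler)))                    # length-bucketed, shuffled indices
```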
GPT_SoVITS/AR/data/data_module.py ADDED
@@ -0,0 +1,74 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
2
+ from pytorch_lightning import LightningDataModule
3
+ from AR.data.bucket_sampler import DistributedBucketSampler
4
+ from AR.data.dataset import Text2SemanticDataset
5
+ from torch.utils.data import DataLoader
6
+
7
+
8
+ class Text2SemanticDataModule(LightningDataModule):
9
+ def __init__(
10
+ self,
11
+ config,
12
+ train_semantic_path,
13
+ train_phoneme_path,
14
+ dev_semantic_path=None,
15
+ dev_phoneme_path=None,
16
+ ):
17
+ super().__init__()
18
+ self.config = config
19
+ self.train_semantic_path = train_semantic_path
20
+ self.train_phoneme_path = train_phoneme_path
21
+ self.dev_semantic_path = dev_semantic_path
22
+ self.dev_phoneme_path = dev_phoneme_path
23
+ self.num_workers = self.config["data"]["num_workers"]
24
+
25
+ def prepare_data(self):
26
+ pass
27
+
28
+ def setup(self, stage=None, output_logs=False):
29
+ self._train_dataset = Text2SemanticDataset(
30
+ phoneme_path=self.train_phoneme_path,
31
+ semantic_path=self.train_semantic_path,
32
+ max_sec=self.config["data"]["max_sec"],
33
+ pad_val=self.config["data"]["pad_val"],
34
+ )
35
+ self._dev_dataset = self._train_dataset
36
+ # self._dev_dataset = Text2SemanticDataset(
37
+ # phoneme_path=self.dev_phoneme_path,
38
+ # semantic_path=self.dev_semantic_path,
39
+ # max_sample=self.config['data']['max_eval_sample'],
40
+ # max_sec=self.config['data']['max_sec'],
41
+ # pad_val=self.config['data']['pad_val'])
42
+
43
+ def train_dataloader(self):
44
+ batch_size = max(min(self.config["train"]["batch_size"],len(self._train_dataset)//4),1)#防止不保存
45
+ sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
46
+ return DataLoader(
47
+ self._train_dataset,
48
+ batch_size=batch_size,
49
+ sampler=sampler,
50
+ collate_fn=self._train_dataset.collate,
51
+ num_workers=self.num_workers,
52
+ persistent_workers=True,
53
+ prefetch_factor=16,
54
+ )
55
+
56
+ def val_dataloader(self):
57
+ return DataLoader(
58
+ self._dev_dataset,
59
+ batch_size=1,
60
+ shuffle=False,
61
+ collate_fn=self._train_dataset.collate,
62
+ num_workers=max(self.num_workers, 12),
63
+ persistent_workers=True,
64
+ prefetch_factor=16,
65
+ )
66
+
67
+ # 这个会使用到嘛?
68
+ def test_dataloader(self):
69
+ return DataLoader(
70
+ self._dev_dataset,
71
+ batch_size=1,
72
+ shuffle=False,
73
+ collate_fn=self._train_dataset.collate,
74
+ )
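
A minimal sketch of the config keys the data module reads above. The paths are placeholders for the `2-name2text.txt` / `6-name2semantic.tsv` files produced by GPT-SoVITS data preparation; `setup()` expects them to exist.

```python
config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
    "train": {"batch_size": 8},
}
dm = Text2SemanticDataModule(
    config,
    train_semantic_path="logs/exp/6-name2semantic.tsv",   # placeholder path
    train_phoneme_path="logs/exp/2-name2text.txt",        # placeholder path
)
# dm.setup(); loader = dm.train_dataloader()  # DistributedBucketSampler + dataset.collate
```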
GPT_SoVITS/AR/data/dataset.py ADDED
@@ -0,0 +1,320 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
2
+ import pdb
3
+ import sys
4
+
5
+ # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
6
+ import traceback, os
7
+ from typing import Dict
8
+ from typing import List
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch, json
13
+ from torch.utils.data import DataLoader
14
+ from torch.utils.data import Dataset
15
+ from transformers import AutoTokenizer
16
+
17
+ from text import cleaned_text_to_sequence
18
+
19
+ # from config import exp_dir
20
+
21
+
22
+ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
23
+ seq = sequences[0]
24
+ ndim = seq.ndim
25
+ if axis < 0:
26
+ axis += ndim
27
+ dtype = seq.dtype
28
+ pad_value = dtype.type(pad_value)
29
+ seq_lengths = [seq.shape[axis] for seq in sequences]
30
+ max_length = np.max(seq_lengths)
31
+
32
+ padded_sequences = []
33
+ for seq, length in zip(sequences, seq_lengths):
34
+ padding = (
35
+ [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
36
+ )
37
+ padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
38
+ padded_sequences.append(padded_seq)
39
+ batch = np.stack(padded_sequences)
40
+ return batch
41
+
42
+
43
+ class Text2SemanticDataset(Dataset):
44
+ """dataset class for text tokens to semantic model training."""
45
+
46
+ def __init__(
47
+ self,
48
+ phoneme_path: str,
49
+ semantic_path: str,
50
+ max_sample: int = None,
51
+ max_sec: int = 100,
52
+ pad_val: int = 1024,
53
+ # min value of phoneme/sec
54
+ min_ps_ratio: int = 3,
55
+ # max value of phoneme/sec
56
+ max_ps_ratio: int = 25,
57
+ ) -> None:
58
+ super().__init__()
59
+
60
+ self.semantic_data = pd.read_csv(
61
+ semantic_path, delimiter="\t", encoding="utf-8"
62
+ )
63
+ # get dict
64
+ self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
65
+ self.path3 = "%s/3-bert" % (
66
+ os.path.basename(phoneme_path)
67
+ ) # "%s/3-bert"%exp_dir#bert_dir
68
+ self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
69
+ assert os.path.exists(self.path2)
70
+ assert os.path.exists(self.path6)
71
+ self.phoneme_data = {}
72
+ with open(self.path2, "r", encoding="utf8") as f:
73
+ lines = f.read().strip("\n").split("\n")
74
+
75
+ for line in lines:
76
+ tmp = line.split("\t")
77
+ if len(tmp) != 4:
78
+ continue
79
+ self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
80
+
81
+ # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
82
+ # pad for semantic tokens
83
+ self.PAD: int = pad_val
84
+ # self.hz = 25
85
+ # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
86
+ # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
87
+ # self.hz=int(data[:-2])#
88
+ self.hz = int(os.environ.get("hz", "25hz")[:-2])
89
+
90
+ # max seconds of semantic token
91
+ self.max_sec = max_sec
92
+ self.min_ps_ratio = min_ps_ratio
93
+ self.max_ps_ratio = max_ps_ratio
94
+
95
+ if max_sample is not None:
96
+ self.semantic_data = self.semantic_data[:max_sample]
97
+
98
+ # {idx: (semantic, phoneme)}
99
+ # semantic list, phoneme list
100
+ self.semantic_phoneme = []
101
+ self.item_names = []
102
+
103
+ self.inited = False
104
+
105
+ if not self.inited:
106
+ # 调用初始化函数
107
+ self.init_batch()
108
+ self.inited = True
109
+ del self.semantic_data
110
+ del self.phoneme_data
111
+ # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
112
+ # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
113
+
114
+ def init_batch(self):
115
+ semantic_data_len = len(self.semantic_data)
116
+ phoneme_data_len = len(self.phoneme_data.keys())
117
+ print("semantic_data_len:", semantic_data_len)
118
+ print("phoneme_data_len:", phoneme_data_len)
119
+ print(self.semantic_data)
120
+ idx = 0
121
+ num_not_in = 0
122
+ num_deleted_bigger = 0
123
+ num_deleted_ps = 0
124
+ for i in range(semantic_data_len):
125
+ # 先依次遍历
126
+ # get str
127
+ item_name = self.semantic_data.iloc[i,0]
128
+ # print(self.phoneme_data)
129
+ try:
130
+ phoneme, word2ph, text = self.phoneme_data[item_name]
131
+ except Exception:
132
+ traceback.print_exc()
133
+ # print(f"{item_name} not in self.phoneme_data !")
134
+ num_not_in += 1
135
+ continue
136
+
137
+ semantic_str = self.semantic_data.iloc[i,1]
138
+ # get token list
139
+ semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
140
+ # (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
141
+ # 过滤掉太长的样本
142
+ if (
143
+ len(semantic_ids) > self.max_sec * self.hz
144
+ ): #########1###根据token个数推测总时长过滤时长60s(config里)#40*25=1k
145
+ num_deleted_bigger += 1
146
+ continue
147
+ # (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理####
148
+ phoneme = phoneme.split(" ")
149
+
150
+ try:
151
+ phoneme_ids = cleaned_text_to_sequence(phoneme)
152
+ except Exception:
153
+ traceback.print_exc()
154
+ # print(f"{item_name} not in self.phoneme_data !")
155
+ num_not_in += 1
156
+ continue
157
+ # if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
158
+ if (
159
+ len(phoneme_ids) > self.max_sec * self.hz / 2.5
160
+ ): ###########2:改为恒定限制为semantic/2.5就行
161
+ num_deleted_ps += 1
162
+ continue
163
+ # if len(semantic_ids) > 1000:###########3
164
+ # num_deleted_bigger += 1
165
+ # continue
166
+
167
+ ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
168
+
169
+ if (
170
+ ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
171
+ ): ##########4#3~25#每秒多少个phone
172
+ num_deleted_ps += 1
173
+ # print(item_name)
174
+ continue
175
+
176
+ self.semantic_phoneme.append((semantic_ids, phoneme_ids))
177
+ idx += 1
178
+ self.item_names.append(item_name)
179
+
180
+ min_num = 100 # 20直接不补#30补了也不存ckpt
181
+ leng = len(self.semantic_phoneme)
182
+ if leng < min_num:
183
+ tmp1 = self.semantic_phoneme
184
+ tmp2 = self.item_names
185
+ self.semantic_phoneme = []
186
+ self.item_names = []
187
+ for _ in range(max(2, int(min_num / leng))):
188
+ self.semantic_phoneme += tmp1
189
+ self.item_names += tmp2
190
+ if num_not_in > 0:
191
+ print(f"there are {num_not_in} semantic datas not in phoneme datas")
192
+ if num_deleted_bigger > 0:
193
+ print(
194
+ f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
195
+ )
196
+ if num_deleted_ps > 0:
197
+ # 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
198
+ print(
199
+ f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
200
+ )
201
+ """
202
+ there are 31 semantic datas not in phoneme datas
203
+ deleted 34 audios who's duration are bigger than 54 seconds
204
+ deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
205
+ dataset.__len__(): 366463
206
+
207
+ """
208
+ # 345410 for LibriTTS
209
+ print("dataset.__len__():", self.__len__())
210
+
211
+ def __get_item_names__(self) -> List[str]:
212
+ return self.item_names
213
+
214
+ def __len__(self) -> int:
215
+ return len(self.semantic_phoneme)
216
+
217
+ def __getitem__(self, idx: int) -> Dict:
218
+ semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
219
+ item_name = self.item_names[idx]
220
+ phoneme_ids_len = len(phoneme_ids)
221
+ # semantic tokens target
222
+ semantic_ids_len = len(semantic_ids)
223
+
224
+ flag = 0
225
+ path_bert = "%s/%s.pt" % (self.path3, item_name)
226
+ if os.path.exists(path_bert):
227
+ bert_feature = torch.load(path_bert, map_location="cpu")
228
+ else:
229
+ flag = 1
230
+ if flag == 1:
231
+ # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
232
+ bert_feature = None
233
+ else:
234
+ assert bert_feature.shape[-1] == len(phoneme_ids)
235
+ return {
236
+ "idx": idx,
237
+ "phoneme_ids": phoneme_ids,
238
+ "phoneme_ids_len": phoneme_ids_len,
239
+ "semantic_ids": semantic_ids,
240
+ "semantic_ids_len": semantic_ids_len,
241
+ "bert_feature": bert_feature,
242
+ }
243
+
244
+ def get_sample_length(self, idx: int):
245
+ semantic_ids = self.semantic_phoneme[idx][0]
246
+ sec = 1.0 * len(semantic_ids) / self.hz
247
+ return sec
248
+
249
+ def collate(self, examples: List[Dict]) -> Dict:
250
+ sample_index: List[int] = []
251
+ phoneme_ids: List[torch.Tensor] = []
252
+ phoneme_ids_lens: List[int] = []
253
+ semantic_ids: List[torch.Tensor] = []
254
+ semantic_ids_lens: List[int] = []
255
+ # return
256
+
257
+ for item in examples:
258
+ sample_index.append(item["idx"])
259
+ phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
260
+ semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
261
+ phoneme_ids_lens.append(item["phoneme_ids_len"])
262
+ semantic_ids_lens.append(item["semantic_ids_len"])
263
+
264
+ # pad 0
265
+ phoneme_ids = batch_sequences(phoneme_ids)
266
+ semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
267
+
268
+ # # convert each batch to torch.tensor
269
+ phoneme_ids = torch.tensor(phoneme_ids)
270
+ semantic_ids = torch.tensor(semantic_ids)
271
+ phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
272
+ semantic_ids_lens = torch.tensor(semantic_ids_lens)
273
+ bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
274
+ bert_padded.zero_()
275
+
276
+ for idx, item in enumerate(examples):
277
+ bert = item["bert_feature"]
278
+ if bert is not None:
279
+ bert_padded[idx, :, : bert.shape[-1]] = bert
280
+
281
+ return {
282
+ # List[int]
283
+ "ids": sample_index,
284
+ # torch.Tensor (B, max_phoneme_length)
285
+ "phoneme_ids": phoneme_ids,
286
+ # torch.Tensor (B)
287
+ "phoneme_ids_len": phoneme_ids_lens,
288
+ # torch.Tensor (B, max_semantic_ids_length)
289
+ "semantic_ids": semantic_ids,
290
+ # torch.Tensor (B)
291
+ "semantic_ids_len": semantic_ids_lens,
292
+ # torch.Tensor (B, 1024, max_phoneme_length)
293
+ "bert_feature": bert_padded,
294
+ }
295
+
296
+
297
+ if __name__ == "__main__":
298
+ root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
299
+ dataset = Text2SemanticDataset(
300
+ phoneme_path=root_dir + "phoneme_train.npy",
301
+ semantic_path=root_dir + "semantic_train.tsv",
302
+ )
303
+
304
+ batch_size = 12
305
+ dataloader = DataLoader(
306
+ dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
307
+ )
308
+ for i, batch in enumerate(dataloader):
309
+ if i % 1000 == 0:
310
+ print(i)
311
+ # if i == 0:
312
+ # print('batch["ids"]:', batch["ids"])
313
+ # print('batch["phoneme_ids"]:', batch["phoneme_ids"],
314
+ # batch["phoneme_ids"].shape)
315
+ # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
316
+ # batch["phoneme_ids_len"].shape)
317
+ # print('batch["semantic_ids"]:', batch["semantic_ids"],
318
+ # batch["semantic_ids"].shape)
319
+ # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
320
+ # batch["semantic_ids_len"].shape)
GPT_SoVITS/AR/models/__init__.py ADDED
File without changes
GPT_SoVITS/AR/models/t2s_lightning_module.py ADDED
@@ -0,0 +1,140 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
2
+ import os, sys
3
+
4
+ now_dir = os.getcwd()
5
+ sys.path.append(now_dir)
6
+ from typing import Dict
7
+
8
+ import torch
9
+ from pytorch_lightning import LightningModule
10
+ from AR.models.t2s_model import Text2SemanticDecoder
11
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
12
+ from AR.modules.optim import ScaledAdam
13
+
14
+
15
+ class Text2SemanticLightningModule(LightningModule):
16
+ def __init__(self, config, output_dir, is_train=True):
17
+ super().__init__()
18
+ self.config = config
19
+ self.top_k = 3
20
+ self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
21
+ pretrained_s1 = config.get("pretrained_s1")
22
+ if pretrained_s1 and is_train:
23
+ # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
24
+ print(
25
+ self.load_state_dict(
26
+ torch.load(pretrained_s1, map_location="cpu")["weight"]
27
+ )
28
+ )
29
+ if is_train:
30
+ self.automatic_optimization = False
31
+ self.save_hyperparameters()
32
+ self.eval_dir = output_dir / "eval"
33
+ self.eval_dir.mkdir(parents=True, exist_ok=True)
34
+
35
+ def training_step(self, batch: Dict, batch_idx: int):
36
+ opt = self.optimizers()
37
+ scheduler = self.lr_schedulers()
38
+ loss, acc = self.model.forward(
39
+ batch["phoneme_ids"],
40
+ batch["phoneme_ids_len"],
41
+ batch["semantic_ids"],
42
+ batch["semantic_ids_len"],
43
+ batch["bert_feature"],
44
+ )
45
+ self.manual_backward(loss)
46
+ if batch_idx > 0 and batch_idx % 4 == 0:
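+ # manual optimization: step the optimizer only every 4 micro-batches, letting gradients accumulate across the intervening training_step calls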
47
+ opt.step()
48
+ opt.zero_grad()
49
+ scheduler.step()
50
+
51
+ self.log(
52
+ "total_loss",
53
+ loss,
54
+ on_step=True,
55
+ on_epoch=True,
56
+ prog_bar=True,
57
+ sync_dist=True,
58
+ )
59
+ self.log(
60
+ "lr",
61
+ scheduler.get_last_lr()[0],
62
+ on_epoch=True,
63
+ prog_bar=True,
64
+ sync_dist=True,
65
+ )
66
+ self.log(
67
+ f"top_{self.top_k}_acc",
68
+ acc,
69
+ on_step=True,
70
+ on_epoch=True,
71
+ prog_bar=True,
72
+ sync_dist=True,
73
+ )
74
+
75
+ def validation_step(self, batch: Dict, batch_idx: int):
76
+ return
77
+
78
+ # # get loss
79
+ # loss, acc = self.model.forward(
80
+ # batch['phoneme_ids'], batch['phoneme_ids_len'],
81
+ # batch['semantic_ids'], batch['semantic_ids_len'],
82
+ # batch['bert_feature']
83
+ # )
84
+ #
85
+ # self.log(
86
+ # "val_total_loss",
87
+ # loss,
88
+ # on_step=True,
89
+ # on_epoch=True,
90
+ # prog_bar=True,
91
+ # sync_dist=True)
92
+ # self.log(
93
+ # f"val_top_{self.top_k}_acc",
94
+ # acc,
95
+ # on_step=True,
96
+ # on_epoch=True,
97
+ # prog_bar=True,
98
+ # sync_dist=True)
99
+ #
100
+ # # get infer output
101
+ # semantic_len = batch['semantic_ids'].size(1)
102
+ # prompt_len = min(int(semantic_len * 0.5), 150)
103
+ # prompt = batch['semantic_ids'][:, :prompt_len]
104
+ # pred_semantic = self.model.infer(batch['phoneme_ids'],
105
+ # batch['phoneme_ids_len'], prompt,
106
+ # batch['bert_feature']
107
+ # )
108
+ # save_name = f'semantic_toks_{batch_idx}.pt'
109
+ # save_path = os.path.join(self.eval_dir, save_name)
110
+ # torch.save(pred_semantic.detach().cpu(), save_path)
111
+
112
+ def configure_optimizers(self):
113
+ model_parameters = self.model.parameters()
114
+ parameters_names = []
115
+ parameters_names.append(
116
+ [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
117
+ )
118
+ lm_opt = ScaledAdam(
119
+ model_parameters,
120
+ lr=0.01,
121
+ betas=(0.9, 0.95),
122
+ clipping_scale=2.0,
123
+ parameters_names=parameters_names,
124
+ show_dominant_parameters=False,
125
+ clipping_update_period=1000,
126
+ )
127
+
128
+ return {
129
+ "optimizer": lm_opt,
130
+ "lr_scheduler": {
131
+ "scheduler": WarmupCosineLRSchedule(
132
+ lm_opt,
133
+ init_lr=self.config["optimizer"]["lr_init"],
134
+ peak_lr=self.config["optimizer"]["lr"],
135
+ end_lr=self.config["optimizer"]["lr_end"],
136
+ warmup_steps=self.config["optimizer"]["warmup_steps"],
137
+ total_steps=self.config["optimizer"]["decay_steps"],
138
+ )
139
+ },
140
+ }
GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py ADDED
@@ -0,0 +1,106 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
2
+ import os, sys
3
+
4
+ now_dir = os.getcwd()
5
+ sys.path.append(now_dir)
6
+ from typing import Dict
7
+
8
+ import torch
9
+ from pytorch_lightning import LightningModule
10
+ from AR.models.t2s_model_onnx import Text2SemanticDecoder
11
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
12
+ from AR.modules.optim import ScaledAdam
13
+
14
+
15
+ class Text2SemanticLightningModule(LightningModule):
16
+ def __init__(self, config, output_dir, is_train=True):
17
+ super().__init__()
18
+ self.config = config
19
+ self.top_k = 3
20
+ self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
21
+ pretrained_s1 = config.get("pretrained_s1")
22
+ if pretrained_s1 and is_train:
23
+ # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
24
+ print(
25
+ self.load_state_dict(
26
+ torch.load(pretrained_s1, map_location="cpu")["weight"]
27
+ )
28
+ )
29
+ if is_train:
30
+ self.automatic_optimization = False
31
+ self.save_hyperparameters()
32
+ self.eval_dir = output_dir / "eval"
33
+ self.eval_dir.mkdir(parents=True, exist_ok=True)
34
+
35
+ def training_step(self, batch: Dict, batch_idx: int):
36
+ opt = self.optimizers()
37
+ scheduler = self.lr_schedulers()
38
+ loss, acc = self.model.forward(
39
+ batch["phoneme_ids"],
40
+ batch["phoneme_ids_len"],
41
+ batch["semantic_ids"],
42
+ batch["semantic_ids_len"],
43
+ batch["bert_feature"],
44
+ )
45
+ self.manual_backward(loss)
46
+ if batch_idx > 0 and batch_idx % 4 == 0:
47
+ opt.step()
48
+ opt.zero_grad()
49
+ scheduler.step()
50
+
51
+ self.log(
52
+ "total_loss",
53
+ loss,
54
+ on_step=True,
55
+ on_epoch=True,
56
+ prog_bar=True,
57
+ sync_dist=True,
58
+ )
59
+ self.log(
60
+ "lr",
61
+ scheduler.get_last_lr()[0],
62
+ on_epoch=True,
63
+ prog_bar=True,
64
+ sync_dist=True,
65
+ )
66
+ self.log(
67
+ f"top_{self.top_k}_acc",
68
+ acc,
69
+ on_step=True,
70
+ on_epoch=True,
71
+ prog_bar=True,
72
+ sync_dist=True,
73
+ )
74
+
75
+ def validation_step(self, batch: Dict, batch_idx: int):
76
+ return
77
+
78
+ def configure_optimizers(self):
79
+ model_parameters = self.model.parameters()
80
+ parameters_names = []
81
+ parameters_names.append(
82
+ [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
83
+ )
84
+ lm_opt = ScaledAdam(
85
+ model_parameters,
86
+ lr=0.01,
87
+ betas=(0.9, 0.95),
88
+ clipping_scale=2.0,
89
+ parameters_names=parameters_names,
90
+ show_dominant_parameters=False,
91
+ clipping_update_period=1000,
92
+ )
93
+
94
+ return {
95
+ "optimizer": lm_opt,
96
+ "lr_scheduler": {
97
+ "scheduler": WarmupCosineLRSchedule(
98
+ lm_opt,
99
+ init_lr=self.config["optimizer"]["lr_init"],
100
+ peak_lr=self.config["optimizer"]["lr"],
101
+ end_lr=self.config["optimizer"]["lr_end"],
102
+ warmup_steps=self.config["optimizer"]["warmup_steps"],
103
+ total_steps=self.config["optimizer"]["decay_steps"],
104
+ )
105
+ },
106
+ }
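This ONNX-oriented variant keeps the same optimizer setup but strips the validation logic and loads pretrained weights from a "weight" key instead of a Lightning "state_dict". A rough construction sketch, assuming the GPT_SoVITS directory is on sys.path and a standard s1 YAML config is available (the paths below are placeholders):

```python
from pathlib import Path

import yaml

from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule

config_path = "GPT_SoVITS/configs/s1longer.yaml"   # placeholder path
output_dir = Path("logs/s1")                        # placeholder path

with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# is_train=False skips loading `pretrained_s1` and creating the eval directory.
module = Text2SemanticLightningModule(config, output_dir, is_train=False)
module.eval()
```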
GPT_SoVITS/AR/models/t2s_model.py ADDED
@@ -0,0 +1,327 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from AR.models.utils import make_pad_mask
6
+ from AR.models.utils import (
7
+ topk_sampling,
8
+ sample,
9
+ logits_to_probs,
10
+ multinomial_sample_one_no_sync,
11
+ )
12
+ from AR.modules.embedding import SinePositionalEmbedding
13
+ from AR.modules.embedding import TokenEmbedding
14
+ from AR.modules.transformer import LayerNorm
15
+ from AR.modules.transformer import TransformerEncoder
16
+ from AR.modules.transformer import TransformerEncoderLayer
17
+ from torch import nn
18
+ from torch.nn import functional as F
19
+ from torchmetrics.classification import MulticlassAccuracy
20
+
21
+ default_config = {
22
+ "embedding_dim": 512,
23
+ "hidden_dim": 512,
24
+ "num_head": 8,
25
+ "num_layers": 12,
26
+ "num_codebook": 8,
27
+ "p_dropout": 0.0,
28
+ "vocab_size": 1024 + 1,
29
+ "phoneme_vocab_size": 512,
30
+ "EOS": 1024,
31
+ }
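Note that Text2SemanticDecoder below does not read this flat default_config directly; its constructor expects the values nested under config["model"], with keys such as "hidden_dim", "head", "n_layer" and "dropout". A minimal config dict matching what __init__ actually accesses (values mirror default_config; this is an illustration, not a file from the repository):

```python
minimal_config = {
    "model": {
        "hidden_dim": 512,
        "embedding_dim": 512,
        "head": 8,
        "n_layer": 12,
        "vocab_size": 1024 + 1,
        "phoneme_vocab_size": 512,
        "dropout": 0.0,
        "EOS": 1024,
    }
}
```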
32
+
33
+
34
+ class Text2SemanticDecoder(nn.Module):
35
+ def __init__(self, config, norm_first=False, top_k=3):
36
+ super(Text2SemanticDecoder, self).__init__()
37
+ self.model_dim = config["model"]["hidden_dim"]
38
+ self.embedding_dim = config["model"]["embedding_dim"]
39
+ self.num_head = config["model"]["head"]
40
+ self.num_layers = config["model"]["n_layer"]
41
+ self.norm_first = norm_first
42
+ self.vocab_size = config["model"]["vocab_size"]
43
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
44
+ self.p_dropout = config["model"]["dropout"]
45
+ self.EOS = config["model"]["EOS"]
46
+ self.norm_first = norm_first
47
+ assert self.EOS == self.vocab_size - 1
48
+ # should be same as num of kmeans bin
49
+ # assert self.EOS == 1024
50
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
51
+ self.ar_text_embedding = TokenEmbedding(
52
+ self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
53
+ )
54
+ self.ar_text_position = SinePositionalEmbedding(
55
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True
56
+ )
57
+ self.ar_audio_embedding = TokenEmbedding(
58
+ self.embedding_dim, self.vocab_size, self.p_dropout
59
+ )
60
+ self.ar_audio_position = SinePositionalEmbedding(
61
+ self.embedding_dim, dropout=0.1, scale=False, alpha=True
62
+ )
63
+
64
+ self.h = TransformerEncoder(
65
+ TransformerEncoderLayer(
66
+ d_model=self.model_dim,
67
+ nhead=self.num_head,
68
+ dim_feedforward=self.model_dim * 4,
69
+ dropout=0.1,
70
+ batch_first=True,
71
+ norm_first=norm_first,
72
+ ),
73
+ num_layers=self.num_layers,
74
+ norm=LayerNorm(self.model_dim) if norm_first else None,
75
+ )
76
+
77
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
78
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
79
+
80
+ self.ar_accuracy_metric = MulticlassAccuracy(
81
+ self.vocab_size,
82
+ top_k=top_k,
83
+ average="micro",
84
+ multidim_average="global",
85
+ ignore_index=self.EOS,
86
+ )
87
+
88
+ def forward(self, x, x_lens, y, y_lens, bert_feature):
89
+ """
90
+ x: phoneme_ids
91
+ y: semantic_ids
92
+ """
93
+ x = self.ar_text_embedding(x)
94
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
95
+ x = self.ar_text_position(x)
96
+ x_mask = make_pad_mask(x_lens)
97
+
98
+ y_mask = make_pad_mask(y_lens)
99
+ y_mask_int = y_mask.type(torch.int64)
100
+ codes = y.type(torch.int64) * (1 - y_mask_int)
101
+
102
+ # Training
103
+ # AR Decoder
104
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
105
+ x_len = x_lens.max()
106
+ y_len = y_lens.max()
107
+ y_emb = self.ar_audio_embedding(y)
108
+ y_pos = self.ar_audio_position(y_emb)
109
+
110
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
111
+ ar_xy_padding_mask = xy_padding_mask
112
+
113
+ x_attn_mask = F.pad(
114
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
115
+ (0, y_len),
116
+ value=True,
117
+ )
118
+ y_attn_mask = F.pad(
119
+ torch.triu(
120
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
121
+ diagonal=1,
122
+ ),
123
+ (x_len, 0),
124
+ value=False,
125
+ )
126
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
127
+ bsz, src_len = x.shape[0], x_len + y_len
128
+ _xy_padding_mask = (
129
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
130
+ .expand(-1, self.num_head, -1, -1)
131
+ .reshape(bsz * self.num_head, 1, src_len)
132
+ )
133
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
134
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
135
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
136
+ xy_attn_mask = new_attn_mask
137
+ # x and the full y are fed to the model in a single pass
138
+ xy_pos = torch.concat([x, y_pos], dim=1)
139
+ xy_dec, _ = self.h(
140
+ (xy_pos, None),
141
+ mask=xy_attn_mask,
142
+ )
143
+ logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
144
+ # loss
145
+ # from feiteng: the longer the duration, the larger the gradient update should be, hence the sum reduction
146
+ loss = F.cross_entropy(logits, targets, reduction="sum")
147
+ acc = self.ar_accuracy_metric(logits.detach(), targets).item()
148
+ return loss, acc
149
+
150
+ # TODO: check how this differs from forward, and what to feed as prompts when no semantic tokens are available
151
+ def infer(
152
+ self,
153
+ x,
154
+ x_lens,
155
+ prompts,
156
+ bert_feature,
157
+ top_k: int = -100,
158
+ early_stop_num: int = -1,
159
+ temperature: float = 1.0,
160
+ ):
161
+ x = self.ar_text_embedding(x)
162
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
163
+ x = self.ar_text_position(x)
164
+
165
+ # AR Decoder
166
+ y = prompts
167
+ prefix_len = y.shape[1]
168
+ x_len = x.shape[1]
169
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
170
+ stop = False
171
+ for _ in tqdm(range(1500)):
172
+ y_emb = self.ar_audio_embedding(y)
173
+ y_pos = self.ar_audio_position(y_emb)
174
+ # x and the progressively growing y are fed to the model together
175
+ xy_pos = torch.concat([x, y_pos], dim=1)
176
+ y_len = y.shape[1]
177
+ x_attn_mask_pad = F.pad(
178
+ x_attn_mask,
179
+ (0, y_len),
180
+ value=True,
181
+ )
182
+ y_attn_mask = F.pad(
183
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
184
+ (x_len, 0),
185
+ value=False,
186
+ )
187
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
188
+ y.device
189
+ )
190
+
191
+ xy_dec, _ = self.h(
192
+ (xy_pos, None),
193
+ mask=xy_attn_mask,
194
+ )
195
+ logits = self.ar_predict_layer(xy_dec[:, -1])
196
+ samples = topk_sampling(
197
+ logits, top_k=top_k, top_p=1.0, temperature=temperature
198
+ )
199
+
200
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
201
+ print("use early stop num:", early_stop_num)
202
+ stop = True
203
+
204
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
205
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
206
+ stop = True
207
+ if stop:
208
+ if prompts.shape[1] == y.shape[1]:
209
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
210
+ print("bad zero prediction")
211
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
212
+ break
213
+ # the semantic_ids generated this step are appended to the previous y to form the new y
214
+ # print(samples.shape)  # [1, 1]; the first 1 is the batch size
215
+ # import os
216
+ # os._exit(2333)
217
+ y = torch.concat([y, samples], dim=1)
218
+ return y
219
+
220
+ def pad_y_eos(self, y, y_mask_int, eos_id):
221
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
222
+ y_mask_int, (0, 1), value=1
223
+ )
224
+ # shift by one position (teacher forcing)
225
+ return targets[:, :-1], targets[:, 1:]
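pad_y_eos appends the EOS id at padded positions, then shifts by one token so that the decoder input and its prediction target are offset by a single step. A small worked example with toy values (eos_id = 1024; the tensors are illustrative only):

```python
import torch
import torch.nn.functional as F

eos_id = 1024
y = torch.tensor([[5, 7, 0]])            # last position is padding
y_mask_int = torch.tensor([[0, 0, 1]])   # 1 marks padded positions

targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
inputs, shifted = targets[:, :-1], targets[:, 1:]
print(inputs)   # tensor([[   5,    7, 1024]])  -> decoder input
print(shifted)  # tensor([[   7, 1024, 1024]])  -> prediction target
```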
226
+
227
+ def infer_panel(
228
+ self,
229
+ x,  ##### all text tokens
230
+ x_lens,
231
+ prompts,  #### reference audio tokens
232
+ bert_feature,
233
+ top_k: int = -100,
234
+ early_stop_num: int = -1,
235
+ temperature: float = 1.0,
236
+ ):
237
+ x = self.ar_text_embedding(x)
238
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
239
+ x = self.ar_text_position(x)
240
+
241
+ # AR Decoder
242
+ y = prompts
243
+ prefix_len = y.shape[1]
244
+ x_len = x.shape[1]
245
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
246
+ stop = False
247
+ # print(1111111,self.num_layers)
248
+ cache = {
249
+ "all_stage": self.num_layers,
250
+ "k": [None] * self.num_layers, ###根据配置自己手写
251
+ "v": [None] * self.num_layers,
252
+ # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
253
+ "y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行
254
+ # "logits":None,###原版就已经只对结尾求再拼接了,不用管
255
+ # "xy_dec":None,###不需要,本来只需要最后一个做logits
256
+ "first_infer": 1,
257
+ "stage": 0,
258
+ }
259
+ for idx in tqdm(range(1500)):
260
+ if cache["first_infer"] == 1:
261
+ y_emb = self.ar_audio_embedding(y)
262
+ else:
263
+ y_emb = torch.cat(
264
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
265
+ )
266
+ cache["y_emb"] = y_emb
267
+ y_pos = self.ar_audio_position(y_emb)
268
+ # x and the progressively growing y are fed to the model together
269
+ if cache["first_infer"] == 1:
270
+ xy_pos = torch.concat([x, y_pos], dim=1)
271
+ else:
272
+ xy_pos = y_pos[:, -1:]
273
+ y_len = y_pos.shape[1]
274
+ ### the following three are not cached
275
+ if cache["first_infer"] == 1:
276
+ x_attn_mask_pad = F.pad(
277
+ x_attn_mask,
278
+ (0, y_len),  ### extend the all-zero x-x block with all-ones over the y columns, giving shape (x, x+y)
279
+ value=True,
280
+ )
281
+ y_attn_mask = F.pad(  ### extend the upper-triangular ones of y-y with zeros for the x columns on the left, giving shape (y, x+y)
282
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
283
+ (x_len, 0),
284
+ value=False,
285
+ )
286
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
287
+ y.device
288
+ )
289
+ else:
290
+ ### rightmost column only (this is wrong)
291
+ # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
292
+ # xy_attn_mask[:,-1]=False
293
+ ### bottom row only (this is correct)
294
+ xy_attn_mask = torch.zeros(
295
+ (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
296
+ )
297
+ # pdb.set_trace()
298
+ ### the core of the KV cache
299
+ # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
300
+ xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
301
+ logits = self.ar_predict_layer(
302
+ xy_dec[:, -1]
303
+ )  ## no change needed: with the cache there is only one frame anyway, so taking the last frame is equivalent
304
+ # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
305
+ if idx == 0:  ### the first step must not emit EOS, otherwise nothing would be generated
306
+ logits = logits[:, :-1]  ### drop the probability of the EOS token (1024)
307
+ samples = sample(
308
+ logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
309
+ )[0].unsqueeze(0)
310
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
311
+ print("use early stop num:", early_stop_num)
312
+ stop = True
313
+
314
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
315
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
316
+ stop = True
317
+ if stop:
318
+ if prompts.shape[1] == y.shape[1]:
319
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
320
+ print("bad zero prediction")
321
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
322
+ break
323
+ # the semantic_ids generated this step are appended to the previous y to form the new y
324
+ # print(samples.shape)  # [1, 1]; the first 1 is the batch size
325
+ y = torch.concat([y, samples], dim=1)
326
+ cache["first_infer"] = 0
327
+ return y, idx
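Both forward and infer above assemble the joint attention mask the same way: text positions attend only to other text positions, while audio positions attend to all of the text plus a causal prefix of the audio. A standalone reproduction of that mask with toy lengths (True, shown as 1, means the position is masked out):

```python
import torch
import torch.nn.functional as F

x_len, y_len = 3, 2  # toy text / audio lengths
x_attn_mask = F.pad(
    torch.zeros((x_len, x_len), dtype=torch.bool), (0, y_len), value=True
)
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),
    value=False,
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
print(xy_attn_mask.int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)
```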
GPT_SoVITS/AR/models/t2s_model_onnx.py ADDED
@@ -0,0 +1,337 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from AR.modules.embedding_onnx import SinePositionalEmbedding
6
+ from AR.modules.embedding_onnx import TokenEmbedding
7
+ from AR.modules.transformer_onnx import LayerNorm
8
+ from AR.modules.transformer_onnx import TransformerEncoder
9
+ from AR.modules.transformer_onnx import TransformerEncoderLayer
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from torchmetrics.classification import MulticlassAccuracy
13
+
14
+ default_config = {
15
+ "embedding_dim": 512,
16
+ "hidden_dim": 512,
17
+ "num_head": 8,
18
+ "num_layers": 12,
19
+ "num_codebook": 8,
20
+ "p_dropout": 0.0,
21
+ "vocab_size": 1024 + 1,
22
+ "phoneme_vocab_size": 512,
23
+ "EOS": 1024,
24
+ }
25
+
26
+ inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
27
+
28
+ def logits_to_probs(
29
+ logits,
30
+ previous_tokens = None,
31
+ temperature: float = 1.0,
32
+ top_k = None,
33
+ top_p = None,
34
+ repetition_penalty: float = 1.0,
35
+ ):
36
+ previous_tokens = previous_tokens.squeeze()
37
+ if previous_tokens is not None and repetition_penalty != 1.0:
38
+ previous_tokens = previous_tokens.long()
39
+ score = torch.gather(logits, dim=0, index=previous_tokens)
40
+ score = torch.where(
41
+ score < 0, score * repetition_penalty, score / repetition_penalty
42
+ )
43
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
44
+
45
+ if top_p is not None and top_p < 1.0:
46
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
47
+ cum_probs = torch.cumsum(
48
+ torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
49
+ )
50
+ sorted_indices_to_remove = cum_probs > top_p
51
+ sorted_indices_to_remove[0] = False # keep at least one option
52
+ indices_to_remove = sorted_indices_to_remove.scatter(
53
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
54
+ )
55
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
56
+
57
+ logits = logits / max(temperature, 1e-5)
58
+
59
+ if top_k is not None:
60
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
61
+ pivot = v.select(-1, -1).unsqueeze(-1)
62
+ logits = torch.where(logits < pivot, inf_tensor_value, logits)
63
+
64
+ probs = torch.nn.functional.softmax(logits, dim=-1)
65
+ return probs
66
+
67
+
68
+ def multinomial_sample_one_no_sync(
69
+ probs_sort
70
+ ): # Does multinomial sampling without a cuda synchronization
71
+ q = torch.randn_like(probs_sort)
72
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
73
+
74
+
75
+ def sample(
76
+ logits,
77
+ previous_tokens,
78
+ **sampling_kwargs,
79
+ ):
80
+ probs = logits_to_probs(
81
+ logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
82
+ )
83
+ idx_next = multinomial_sample_one_no_sync(probs)
84
+ return idx_next, probs
85
+
86
+
87
+ class OnnxEncoder(nn.Module):
88
+ def __init__(self, ar_text_embedding, bert_proj, ar_text_position):
89
+ super().__init__()
90
+ self.ar_text_embedding = ar_text_embedding
91
+ self.bert_proj = bert_proj
92
+ self.ar_text_position = ar_text_position
93
+
94
+ def forward(self, x, bert_feature):
95
+ x = self.ar_text_embedding(x)
96
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
97
+ return self.ar_text_position(x)
98
+
99
+
100
+ class T2SFirstStageDecoder(nn.Module):
101
+ def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
102
+ top_k, early_stop_num, num_layers):
103
+ super().__init__()
104
+ self.ar_audio_embedding = ar_audio_embedding
105
+ self.ar_audio_position = ar_audio_position
106
+ self.h = h
107
+ self.ar_predict_layer = ar_predict_layer
108
+ self.loss_fct = loss_fct
109
+ self.ar_accuracy_metric = ar_accuracy_metric
110
+ self.top_k = top_k
111
+ self.early_stop_num = early_stop_num
112
+ self.num_layers = num_layers
113
+
114
+ def forward(self, x, prompt):
115
+ y = prompt
116
+ x_example = x[:,:,0] * 0.0
117
+ #N, 1, 512
118
+ cache = {
119
+ "all_stage": self.num_layers,
120
+ "k": None,
121
+ "v": None,
122
+ "y_emb": None,
123
+ "first_infer": 1,
124
+ "stage": 0,
125
+ }
126
+
127
+ y_emb = self.ar_audio_embedding(y)
128
+
129
+ cache["y_emb"] = y_emb
130
+ y_pos = self.ar_audio_position(y_emb)
131
+
132
+ xy_pos = torch.concat([x, y_pos], dim=1)
133
+
134
+ y_example = y_pos[:,:,0] * 0.0
135
+ x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
136
+ y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
137
+ y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
138
+ torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
139
+ )
140
+ y_attn_mask = y_attn_mask > 0
141
+
142
+ x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
143
+ y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
144
+ x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
145
+ y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
146
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
147
+ cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
148
+ .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
149
+ cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
150
+ .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
151
+
152
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
153
+ logits = self.ar_predict_layer(xy_dec[:, -1])
154
+ samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
155
+
156
+ y = torch.concat([y, samples], dim=1)
157
+
158
+ return y, cache["k"], cache["v"], cache["y_emb"], x_example
159
+
160
+
161
+ class T2SStageDecoder(nn.Module):
162
+ def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
163
+ top_k, early_stop_num, num_layers):
164
+ super().__init__()
165
+ self.ar_audio_embedding = ar_audio_embedding
166
+ self.ar_audio_position = ar_audio_position
167
+ self.h = h
168
+ self.ar_predict_layer = ar_predict_layer
169
+ self.loss_fct = loss_fct
170
+ self.ar_accuracy_metric = ar_accuracy_metric
171
+ self.top_k = top_k
172
+ self.early_stop_num = early_stop_num
173
+ self.num_layers = num_layers
174
+
175
+ def forward(self, y, k, v, y_emb, x_example):
176
+ cache = {
177
+ "all_stage": self.num_layers,
178
+ "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
179
+ "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
180
+ "y_emb": y_emb,
181
+ "first_infer": 0,
182
+ "stage": 0,
183
+ }
184
+
185
+ y_emb = torch.cat(
186
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
187
+ )
188
+ cache["y_emb"] = y_emb
189
+ y_pos = self.ar_audio_position(y_emb)
190
+
191
+ xy_pos = y_pos[:, -1:]
192
+
193
+ y_example = y_pos[:,:,0] * 0.0
194
+
195
+ xy_attn_mask = torch.cat([x_example, y_example], dim=1)
196
+ xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
197
+
198
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
199
+ logits = self.ar_predict_layer(xy_dec[:, -1])
200
+ samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
201
+
202
+ y = torch.concat([y, samples], dim=1)
203
+
204
+ return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
205
+
206
+
207
+ class Text2SemanticDecoder(nn.Module):
208
+ def __init__(self, config, norm_first=False, top_k=3):
209
+ super(Text2SemanticDecoder, self).__init__()
210
+ self.model_dim = config["model"]["hidden_dim"]
211
+ self.embedding_dim = config["model"]["embedding_dim"]
212
+ self.num_head = config["model"]["head"]
213
+ self.num_layers = config["model"]["n_layer"]
214
+ self.norm_first = norm_first
215
+ self.vocab_size = config["model"]["vocab_size"]
216
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
217
+ self.p_dropout = float(config["model"]["dropout"])
218
+ self.EOS = config["model"]["EOS"]
219
+ self.norm_first = norm_first
220
+ assert self.EOS == self.vocab_size - 1
221
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
222
+ self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
223
+ self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
224
+ self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
225
+ self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
226
+ self.h = TransformerEncoder(
227
+ TransformerEncoderLayer(
228
+ d_model=self.model_dim,
229
+ nhead=self.num_head,
230
+ dim_feedforward=self.model_dim * 4,
231
+ dropout=0.1,
232
+ batch_first=True,
233
+ norm_first=norm_first,
234
+ ),
235
+ num_layers=self.num_layers,
236
+ norm=LayerNorm(self.model_dim) if norm_first else None,
237
+ )
238
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
239
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
240
+ self.ar_accuracy_metric = MulticlassAccuracy(
241
+ self.vocab_size,
242
+ top_k=top_k,
243
+ average="micro",
244
+ multidim_average="global",
245
+ ignore_index=self.EOS,
246
+ )
247
+ self.top_k = torch.LongTensor([1])
248
+ self.early_stop_num = torch.LongTensor([-1])
249
+
250
+ def init_onnx(self):
251
+ self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
252
+ self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
253
+ self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
254
+ self.num_layers)
255
+ self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
256
+ self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
257
+ self.num_layers)
258
+
259
+ def forward(self, x, prompts, bert_feature):
260
+ early_stop_num = self.early_stop_num
261
+ prefix_len = prompts.shape[1]
262
+
263
+ x = self.onnx_encoder(x, bert_feature)
264
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
265
+
266
+ stop = False
267
+ for idx in range(1, 1500):
268
+ enco = self.stage_decoder(y, k, v, y_emb, x_example)
269
+ y, k, v, y_emb, logits, samples = enco
270
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
271
+ stop = True
272
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
273
+ stop = True
274
+ if stop:
275
+ break
276
+ y[0, -1] = 0
277
+ return y, idx
278
+
279
+ def infer(self, x, prompts, bert_feature):
280
+ top_k = self.top_k
281
+ early_stop_num = self.early_stop_num
282
+
283
+ x = self.onnx_encoder(x, bert_feature)
284
+
285
+ y = prompts
286
+ prefix_len = y.shape[1]
287
+ x_len = x.shape[1]
288
+ x_example = x[:,:,0] * 0.0
289
+ x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
290
+ x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
291
+
292
+ stop = False
293
+ cache = {
294
+ "all_stage": self.num_layers,
295
+ "k": [None] * self.num_layers,
296
+ "v": [None] * self.num_layers,
297
+ "y_emb": None,
298
+ "first_infer": 1,
299
+ "stage": 0,
300
+ }
301
+ for idx in range(1500):
302
+ if cache["first_infer"] == 1:
303
+ y_emb = self.ar_audio_embedding(y)
304
+ else:
305
+ y_emb = torch.cat(
306
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
307
+ )
308
+ cache["y_emb"] = y_emb
309
+ y_pos = self.ar_audio_position(y_emb)
310
+ if cache["first_infer"] == 1:
311
+ xy_pos = torch.concat([x, y_pos], dim=1)
312
+ else:
313
+ xy_pos = y_pos[:, -1:]
314
+ y_len = y_pos.shape[1]
315
+ if cache["first_infer"] == 1:
316
+ x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
317
+ y_attn_mask = F.pad(
318
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
319
+ (x_len, 0), value=False
320
+ )
321
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
322
+ else:
323
+ xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
324
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
325
+ logits = self.ar_predict_layer(xy_dec[:, -1])
326
+ samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
327
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
328
+ stop = True
329
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
330
+ stop = True
331
+ if stop:
332
+ if prompts.shape[1] == y.shape[1]:
333
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
334
+ break
335
+ y = torch.concat([y, samples], dim=1)
336
+ cache["first_infer"] = 0
337
+ return y, idx
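init_onnx splits the decoder into three export-friendly pieces: the text encoder, a first decoding step that builds the key/value cache, and a per-step decoder that extends it. The sketch below mirrors the forward method above to show how those pieces are driven; it assumes model is a Text2SemanticDecoder built from a valid config, and all tensor shapes are illustrative only:

```python
import torch

model.init_onnx()

phones = torch.randint(0, 100, (1, 20))      # phoneme ids, (B, T_text)
bert_feature = torch.randn(1, 1024, 20)      # BERT features, (B, 1024, T_text)
prompts = torch.randint(0, 1024, (1, 30))    # reference audio tokens, (B, T_prompt)

x = model.onnx_encoder(phones, bert_feature)
y, k, v, y_emb, x_example = model.first_stage_decoder(x, prompts)

for _ in range(200):                          # shortened decode loop for the sketch
    y, k, v, y_emb, logits, samples = model.stage_decoder(y, k, v, y_emb, x_example)
    if samples[0, 0] == model.EOS:
        break
```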
GPT_SoVITS/AR/models/utils.py ADDED
@@ -0,0 +1,160 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def sequence_mask(length, max_length=None):
7
+ if max_length is None:
8
+ max_length = length.max()
9
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
10
+ return x.unsqueeze(0) < length.unsqueeze(1)
11
+
12
+
13
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
14
+ """
15
+ Args:
16
+ lengths:
17
+ A 1-D tensor containing sentence lengths.
18
+ max_len:
19
+ The length of masks.
20
+ Returns:
21
+ Return a 2-D bool tensor, where masked positions
22
+ are filled with `True` and non-masked positions are
23
+ filled with `False`.
24
+
25
+ #>>> lengths = torch.tensor([1, 3, 2, 5])
26
+ #>>> make_pad_mask(lengths)
27
+ tensor([[False, True, True, True, True],
28
+ [False, False, False, True, True],
29
+ [False, False, True, True, True],
30
+ [False, False, False, False, False]])
31
+ """
32
+ assert lengths.ndim == 1, lengths.ndim
33
+ max_len = max(max_len, lengths.max())
34
+ n = lengths.size(0)
35
+ seq_range = torch.arange(0, max_len, device=lengths.device)
36
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
37
+
38
+ return expanded_lengths >= lengths.unsqueeze(-1)
39
+
40
+
41
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
42
+ def top_k_top_p_filtering(
43
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
44
+ ):
45
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
46
+ Args:
47
+ logits: logits distribution shape (batch size, vocabulary size)
48
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
49
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
50
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
51
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
52
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
53
+ """
54
+ if top_k > 0:
55
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
56
+ # Remove all tokens with a probability less than the last token of the top-k
57
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
58
+ logits[indices_to_remove] = filter_value
59
+
60
+ if top_p < 1.0:
61
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
62
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
63
+
64
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
65
+ sorted_indices_to_remove = cumulative_probs > top_p
66
+ if min_tokens_to_keep > 1:
67
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
68
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
69
+ # Shift the indices to the right to keep also the first token above the threshold
70
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
71
+ sorted_indices_to_remove[..., 0] = 0
72
+
73
+ # scatter sorted tensors to original indexing
74
+ indices_to_remove = sorted_indices_to_remove.scatter(
75
+ 1, sorted_indices, sorted_indices_to_remove
76
+ )
77
+ logits[indices_to_remove] = filter_value
78
+ return logits
79
+
80
+
81
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
82
+ # temperature: (`optional`) float
83
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
84
+ # top_k: (`optional`) int
85
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
86
+ # top_p: (`optional`) float
87
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
88
+
89
+ # Temperature (higher temperature => more likely to sample low probability tokens)
90
+ if temperature != 1.0:
91
+ logits = logits / temperature
92
+ # Top-p/top-k filtering
93
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
94
+ # Sample
95
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
96
+ return token
97
+
98
+
99
+ from typing import Optional, Tuple
100
+
101
+
102
+ def multinomial_sample_one_no_sync(
103
+ probs_sort,
104
+ ): # Does multinomial sampling without a cuda synchronization
105
+ q = torch.empty_like(probs_sort).exponential_(1)
106
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
107
+
108
+
109
+ def logits_to_probs(
110
+ logits,
111
+ previous_tokens: Optional[torch.Tensor] = None,
112
+ temperature: float = 1.0,
113
+ top_k: Optional[int] = None,
114
+ top_p: Optional[int] = None,
115
+ repetition_penalty: float = 1.0,
116
+ ):
117
+ previous_tokens = previous_tokens.squeeze()
118
+ # print(logits.shape,previous_tokens.shape)
119
+ # pdb.set_trace()
120
+ if previous_tokens is not None and repetition_penalty != 1.0:
121
+ previous_tokens = previous_tokens.long()
122
+ score = torch.gather(logits, dim=0, index=previous_tokens)
123
+ score = torch.where(
124
+ score < 0, score * repetition_penalty, score / repetition_penalty
125
+ )
126
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
127
+
128
+ if top_p is not None and top_p < 1.0:
129
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
130
+ cum_probs = torch.cumsum(
131
+ torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
132
+ )
133
+ sorted_indices_to_remove = cum_probs > top_p
134
+ sorted_indices_to_remove[0] = False # keep at least one option
135
+ indices_to_remove = sorted_indices_to_remove.scatter(
136
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
137
+ )
138
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
139
+
140
+ logits = logits / max(temperature, 1e-5)
141
+
142
+ if top_k is not None:
143
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
144
+ pivot = v.select(-1, -1).unsqueeze(-1)
145
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
146
+
147
+ probs = torch.nn.functional.softmax(logits, dim=-1)
148
+ return probs
149
+
150
+
151
+ def sample(
152
+ logits,
153
+ previous_tokens: Optional[torch.Tensor] = None,
154
+ **sampling_kwargs,
155
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
156
+ probs = logits_to_probs(
157
+ logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
158
+ )
159
+ idx_next = multinomial_sample_one_no_sync(probs)
160
+ return idx_next, probs
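To make the two samplers above concrete, here is a toy invocation with a six-token vocabulary; it assumes the repository root is on sys.path so that AR.models.utils is importable, and the values are illustrative only:

```python
import torch

from AR.models.utils import sample, topk_sampling

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0, -2.0]])  # (batch, vocab)

# topk_sampling: temperature scaling, then top-k/top-p filtering, then multinomial.
token = topk_sampling(logits.clone(), top_k=3, top_p=0.9, temperature=0.8)
print(token.shape)  # torch.Size([1, 1])

# sample: additionally applies a repetition penalty against previously generated ids;
# it expects 1-D logits plus the token history.
previous_tokens = torch.tensor([[0, 2]])
idx_next, probs = sample(
    logits[0].clone(), previous_tokens, top_k=3, top_p=0.9, repetition_penalty=1.35
)
print(idx_next.item(), float(probs.sum()))  # sampled id and total probability mass (1.0)
```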