Downgrade Gradio verison

#27
by multimodalart HF staff - opened
This view is limited to 50 files because it contains too many changes.  See the raw diff here.
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. .gitattributes +0 -35
  3. .github/ISSUE_TEMPLATE/bug_report.yml +0 -50
  4. .github/ISSUE_TEMPLATE/config.yml +0 -1
  5. .github/ISSUE_TEMPLATE/feature_request.yml +0 -62
  6. .github/ISSUE_TEMPLATE/help_wanted.yml +0 -50
  7. .github/ISSUE_TEMPLATE/question.yml +0 -26
  8. .github/workflows/pre-commit.yaml +0 -14
  9. .github/workflows/publish-docker-image.yaml +0 -60
  10. .gitmodules +0 -3
  11. .pre-commit-config.yaml +0 -14
  12. Dockerfile +3 -5
  13. README_REPO.md +115 -73
  14. api.py +0 -132
  15. app.py +524 -584
  16. app_local.py +236 -0
  17. cog.py +180 -0
  18. data/.DS_Store +0 -0
  19. data/Emilia_ZH_EN_pinyin/vocab.txt +2545 -2545
  20. data/librispeech_pc_test_clean_cross_sentence.lst +0 -0
  21. finetune-cli.py +42 -61
  22. finetune_gradio.py +356 -570
  23. gradio_app.py +0 -824
  24. inference-cli.py +292 -70
  25. model/__init__.py +0 -3
  26. model/backbones/dit.py +44 -49
  27. model/backbones/mmdit.py +28 -38
  28. model/backbones/unett.py +52 -70
  29. model/cfm.py +67 -75
  30. model/dataset.py +80 -117
  31. model/ecapa_tdnn.py +35 -97
  32. model/modules.py +114 -120
  33. model/trainer.py +83 -133
  34. model/utils.py +154 -194
  35. model/utils_infer.py +0 -357
  36. packages.txt +1 -0
  37. pyproject.toml +0 -62
  38. requirements.txt +2 -5
  39. ruff.toml +0 -10
  40. scripts/count_max_epoch.py +3 -4
  41. scripts/count_params_gflops.py +6 -10
  42. scripts/eval_infer_batch.py +61 -60
  43. scripts/eval_librispeech_test_clean.py +4 -6
  44. scripts/eval_seedtts_testset.py +8 -10
  45. scripts/prepare_csv_wavs.py +10 -16
  46. scripts/prepare_emilia.py +16 -100
  47. scripts/prepare_wenetspeech4tts.py +12 -15
  48. speech_edit.py +44 -50
  49. src/f5_tts/api.py +0 -166
  50. src/f5_tts/configs/E2TTS_Base_train.yaml +0 -44
.DS_Store DELETED
Binary file (6.15 kB)
 
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/ISSUE_TEMPLATE/bug_report.yml DELETED
@@ -1,50 +0,0 @@
1
- name: "Bug Report"
2
- description: |
3
- Please provide as much details to help address the issue, including logs and screenshots.
4
- labels:
5
- - bug
6
- body:
7
- - type: checkboxes
8
- attributes:
9
- label: Checks
10
- description: "To ensure timely help, please confirm the following:"
11
- options:
12
- - label: This template is only for bug reports, usage problems go with 'Help Wanted'.
13
- required: true
14
- - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
15
- required: true
16
- - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
17
- required: true
18
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
19
- required: true
20
- - type: textarea
21
- attributes:
22
- label: Environment Details
23
- description: "Provide details such as OS, Python version, and any relevant software or dependencies."
24
- placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8
25
- validations:
26
- required: true
27
- - type: textarea
28
- attributes:
29
- label: Steps to Reproduce
30
- description: |
31
- Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
32
- placeholder: |
33
- 1. Create a new conda environment.
34
- 2. Clone the repository, install as local editable and properly set up.
35
- 3. Run the command: `accelerate launch src/f5_tts/train/train.py`.
36
- 4. Have following error message... (attach logs).
37
- validations:
38
- required: true
39
- - type: textarea
40
- attributes:
41
- label: ✔️ Expected Behavior
42
- placeholder: Describe what you expected to happen.
43
- validations:
44
- required: false
45
- - type: textarea
46
- attributes:
47
- label: ❌ Actual Behavior
48
- placeholder: Describe what actually happened.
49
- validations:
50
- required: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/ISSUE_TEMPLATE/config.yml DELETED
@@ -1 +0,0 @@
1
- blank_issues_enabled: false
 
 
.github/ISSUE_TEMPLATE/feature_request.yml DELETED
@@ -1,62 +0,0 @@
1
- name: "Feature Request"
2
- description: |
3
- Some constructive suggestions and new ideas regarding current repo.
4
- labels:
5
- - enhancement
6
- body:
7
- - type: checkboxes
8
- attributes:
9
- label: Checks
10
- description: "To help us grasp quickly, please confirm the following:"
11
- options:
12
- - label: This template is only for feature request.
13
- required: true
14
- - label: I have thoroughly reviewed the project documentation but couldn't find any relevant information that meets my needs.
15
- required: true
16
- - label: I have searched for existing issues, including closed ones, and found not discussion yet.
17
- required: true
18
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
19
- required: true
20
- - type: textarea
21
- attributes:
22
- label: 1. Is this request related to a challenge you're experiencing? Tell us your story.
23
- description: |
24
- Describe the specific problem or scenario you're facing in detail. For example:
25
- *"I was trying to use [feature] for [specific task], but encountered [issue]. This was frustrating because...."*
26
- placeholder: Please describe the situation in as much detail as possible.
27
- validations:
28
- required: true
29
-
30
- - type: textarea
31
- attributes:
32
- label: 2. What is your suggested solution?
33
- description: |
34
- Provide a clear description of the feature or enhancement you'd like to propose.
35
- How would this feature solve your issue or improve the project?
36
- placeholder: Describe your idea or proposed solution here.
37
- validations:
38
- required: true
39
-
40
- - type: textarea
41
- attributes:
42
- label: 3. Additional context or comments
43
- description: |
44
- Any other relevant information, links, documents, or screenshots that provide clarity.
45
- Use this section for anything not covered above.
46
- placeholder: Add any extra details here.
47
- validations:
48
- required: false
49
-
50
- - type: checkboxes
51
- attributes:
52
- label: 4. Can you help us with this feature?
53
- description: |
54
- Let us know if you're interested in contributing. This is not a commitment but a way to express interest in collaboration.
55
- options:
56
- - label: I am interested in contributing to this feature.
57
- required: false
58
-
59
- - type: markdown
60
- attributes:
61
- value: |
62
- **Note:** Please submit only one request per issue to keep discussions focused and manageable.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/ISSUE_TEMPLATE/help_wanted.yml DELETED
@@ -1,50 +0,0 @@
1
- name: "Help Wanted"
2
- description: |
3
- Please provide as much details to help address the issue, including logs and screenshots.
4
- labels:
5
- - help wanted
6
- body:
7
- - type: checkboxes
8
- attributes:
9
- label: Checks
10
- description: "To ensure timely help, please confirm the following:"
11
- options:
12
- - label: This template is only for usage issues encountered.
13
- required: true
14
- - label: I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
15
- required: true
16
- - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
17
- required: true
18
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
19
- required: true
20
- - type: textarea
21
- attributes:
22
- label: Environment Details
23
- description: "Provide details such as OS, Python version, and any relevant software or dependencies."
24
- placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
25
- validations:
26
- required: true
27
- - type: textarea
28
- attributes:
29
- label: Steps to Reproduce
30
- description: |
31
- Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
32
- placeholder: |
33
- 1. Create a new conda environment.
34
- 2. Clone the repository and install as pip package.
35
- 3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
36
- 4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
37
- validations:
38
- required: true
39
- - type: textarea
40
- attributes:
41
- label: ✔️ Expected Behavior
42
- placeholder: Describe what you expected to happen, e.g. output a generated audio
43
- validations:
44
- required: false
45
- - type: textarea
46
- attributes:
47
- label: ❌ Actual Behavior
48
- placeholder: Describe what actually happened, failure messages, etc.
49
- validations:
50
- required: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/ISSUE_TEMPLATE/question.yml DELETED
@@ -1,26 +0,0 @@
1
- name: "Question"
2
- description: |
3
- Pure question or inquiry about the project, usage issue goes with "help wanted".
4
- labels:
5
- - question
6
- body:
7
- - type: checkboxes
8
- attributes:
9
- label: Checks
10
- description: "To help us grasp quickly, please confirm the following:"
11
- options:
12
- - label: This template is only for question, not feature requests or bug reports.
13
- required: true
14
- - label: I have thoroughly reviewed the project documentation and read the related paper(s).
15
- required: true
16
- - label: I have searched for existing issues, including closed ones, no similar questions.
17
- required: true
18
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
19
- required: true
20
- - type: textarea
21
- attributes:
22
- label: Question details
23
- description: |
24
- Question details, clearly stated using proper markdown syntax.
25
- validations:
26
- required: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/workflows/pre-commit.yaml DELETED
@@ -1,14 +0,0 @@
1
- name: pre-commit
2
-
3
- on:
4
- pull_request:
5
- push:
6
- branches: [main]
7
-
8
- jobs:
9
- pre-commit:
10
- runs-on: ubuntu-latest
11
- steps:
12
- - uses: actions/checkout@v3
13
- - uses: actions/setup-python@v3
14
- - uses: pre-commit/action@v3.0.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/workflows/publish-docker-image.yaml DELETED
@@ -1,60 +0,0 @@
1
- name: Create and publish a Docker image
2
-
3
- # Configures this workflow to run every time a change is pushed to the branch called `release`.
4
- on:
5
- push:
6
- branches: ['main']
7
-
8
- # Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
9
- env:
10
- REGISTRY: ghcr.io
11
- IMAGE_NAME: ${{ github.repository }}
12
-
13
- # There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
14
- jobs:
15
- build-and-push-image:
16
- runs-on: ubuntu-latest
17
- # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
18
- permissions:
19
- contents: read
20
- packages: write
21
- #
22
- steps:
23
- - name: Checkout repository
24
- uses: actions/checkout@v4
25
- - name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
26
- uses: jlumbroso/free-disk-space@main
27
- with:
28
- # This might remove tools that are actually needed, if set to "true" but frees about 6 GB
29
- tool-cache: false
30
-
31
- # All of these default to true, but feel free to set to "false" if necessary for your workflow
32
- android: true
33
- dotnet: true
34
- haskell: true
35
- large-packages: false
36
- swap-storage: false
37
- docker-images: false
38
- # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
39
- - name: Log in to the Container registry
40
- uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
41
- with:
42
- registry: ${{ env.REGISTRY }}
43
- username: ${{ github.actor }}
44
- password: ${{ secrets.GITHUB_TOKEN }}
45
- # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
46
- - name: Extract metadata (tags, labels) for Docker
47
- id: meta
48
- uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
49
- with:
50
- images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
51
- # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
52
- # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
53
- # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
54
- - name: Build and push Docker image
55
- uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
56
- with:
57
- context: .
58
- push: true
59
- tags: ${{ steps.meta.outputs.tags }}
60
- labels: ${{ steps.meta.outputs.labels }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitmodules DELETED
@@ -1,3 +0,0 @@
1
- [submodule "src/third_party/BigVGAN"]
2
- path = src/third_party/BigVGAN
3
- url = https://github.com/NVIDIA/BigVGAN.git
 
 
 
 
.pre-commit-config.yaml DELETED
@@ -1,14 +0,0 @@
1
- repos:
2
- - repo: https://github.com/astral-sh/ruff-pre-commit
3
- # Ruff version.
4
- rev: v0.7.0
5
- hooks:
6
- # Run the linter.
7
- - id: ruff
8
- args: [--fix]
9
- # Run the formatter.
10
- - id: ruff-format
11
- - repo: https://github.com/pre-commit/pre-commit-hooks
12
- rev: v2.3.0
13
- hooks:
14
- - id: check-yaml
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Dockerfile CHANGED
@@ -10,17 +10,15 @@ RUN set -x \
10
  && apt-get update \
11
  && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
12
  && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
13
- && apt-get install -y librdmacm1 libibumad3 librdmacm-dev libibverbs1 libibverbs-dev ibverbs-utils ibverbs-providers \
14
  && rm -rf /var/lib/apt/lists/* \
15
  && apt-get clean
16
-
17
  WORKDIR /workspace
18
 
19
  RUN git clone https://github.com/SWivid/F5-TTS.git \
20
  && cd F5-TTS \
21
- && git submodule update --init --recursive \
22
- && sed -i '7iimport sys\nsys.path.append(os.path.dirname(os.path.abspath(__file__)))' src/third_party/BigVGAN/bigvgan.py \
23
- && pip install -e . --no-cache-dir
24
 
25
  ENV SHELL=/bin/bash
26
 
 
10
  && apt-get update \
11
  && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
12
  && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
 
13
  && rm -rf /var/lib/apt/lists/* \
14
  && apt-get clean
15
+
16
  WORKDIR /workspace
17
 
18
  RUN git clone https://github.com/SWivid/F5-TTS.git \
19
  && cd F5-TTS \
20
+ && pip install --no-cache-dir -r requirements.txt \
21
+ && pip install --no-cache-dir -r requirements_eval.txt
 
22
 
23
  ENV SHELL=/bin/bash
24
 
README_REPO.md CHANGED
@@ -16,133 +16,176 @@
16
 
17
  ### Thanks to all the contributors !
18
 
19
- ## News
20
- - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
21
-
22
  ## Installation
23
 
24
- ```bash
25
- # Create a python 3.10 conda env (you could also use virtualenv)
26
- conda create -n f5-tts python=3.10
27
- conda activate f5-tts
28
 
29
- # Install pytorch with your CUDA version, e.g.
30
- pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
 
31
  ```
32
 
33
- Then you can choose from a few options below:
34
-
35
- ### 1. As a pip package (if just for inference)
36
 
37
  ```bash
38
- pip install git+https://github.com/SWivid/F5-TTS.git
 
39
  ```
40
 
41
- ### 2. Local editable (if also do training, finetuning)
42
 
43
  ```bash
44
- git clone https://github.com/SWivid/F5-TTS.git
45
- cd F5-TTS
46
- # git submodule update --init --recursive # (optional, if need bigvgan)
47
- pip install -e .
48
- ```
49
- If initialize submodule, you should add the following code at the beginning of `src/third_party/BigVGAN/bigvgan.py`.
50
- ```python
51
- import os
52
- import sys
53
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
54
  ```
55
 
56
- ### 3. Docker usage
 
 
 
57
  ```bash
58
- # Build from Dockerfile
59
- docker build -t f5tts:v1 .
 
 
 
60
 
61
- # Or pull from GitHub Container Registry
62
- docker pull ghcr.io/swivid/f5-tts:main
63
  ```
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  ## Inference
67
 
68
- ### 1. Gradio App
69
 
70
- Currently supported features:
 
 
 
 
 
 
71
 
72
- - Basic TTS with Chunk Inference
73
- - Multi-Style / Multi-Speaker Generation
74
- - Voice Chat powered by Qwen2.5-3B-Instruct
75
- - [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
76
 
77
  ```bash
78
- # Launch a Gradio app (web interface)
79
- f5-tts_infer-gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- # Specify the port/host
82
- f5-tts_infer-gradio --port 7860 --host 0.0.0.0
 
 
 
83
 
84
- # Launch a share link
85
- f5-tts_infer-gradio --share
 
 
86
  ```
87
 
88
- ### 2. CLI Inference
89
 
90
  ```bash
91
- # Run with flags
92
- # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
93
- f5-tts_infer-cli \
94
- --model "F5-TTS" \
95
- --ref_audio "ref_audio.wav" \
96
- --ref_text "The content, subtitle or transcription of reference audio." \
97
- --gen_text "Some text you want TTS model generate for you."
98
 
99
- # Run with default setting. src/f5_tts/infer/examples/basic/basic.toml
100
- f5-tts_infer-cli
101
- # Or with your own .toml file
102
- f5-tts_infer-cli -c custom.toml
103
 
104
- # Multi voice. See src/f5_tts/infer/README.md
105
- f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
106
  ```
107
 
108
- ### 3. More instructions
 
 
109
 
110
- - In order to have better generation results, take a moment to read [detailed guidance](src/f5_tts/infer).
111
- - The [Issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very useful, please try to find the solution by properly searching the keywords of problem encountered. If no answer found, then feel free to open an issue.
 
 
 
112
 
 
113
 
114
- ## Training
 
 
 
 
115
 
116
- ### 1. Gradio App
117
 
118
- Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
119
 
120
  ```bash
121
- # Quick start with Gradio web interface
122
- f5-tts_finetune-gradio
 
123
  ```
124
 
 
125
 
126
- ## [Evaluation](src/f5_tts/eval)
 
 
127
 
 
128
 
129
- ## Development
130
 
131
- Use pre-commit to ensure code quality (will run linters and formatters automatically)
 
 
 
 
 
 
132
 
133
  ```bash
134
- pip install pre-commit
135
- pre-commit install
136
  ```
137
 
138
- When making a pull request, before each commit, run:
139
 
140
  ```bash
141
- pre-commit run --all-files
142
  ```
143
 
144
- Note: Some model components have linting exceptions for E722 to accommodate tensor notation
 
 
 
145
 
 
 
 
146
 
147
  ## Acknowledgements
148
 
@@ -154,8 +197,7 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
154
  - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
155
  - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
156
  - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
157
- - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
158
- - [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
159
 
160
  ## Citation
161
  If our work and codebase is useful for you, please cite as:
 
16
 
17
  ### Thanks to all the contributors !
18
 
 
 
 
19
  ## Installation
20
 
21
+ Clone the repository:
 
 
 
22
 
23
+ ```bash
24
+ git clone https://github.com/SWivid/F5-TTS.git
25
+ cd F5-TTS
26
  ```
27
 
28
+ Install torch with your CUDA version, e.g. :
 
 
29
 
30
  ```bash
31
+ pip install torch==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
32
+ pip install torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
33
  ```
34
 
35
+ Install other packages:
36
 
37
  ```bash
38
+ pip install -r requirements.txt
 
 
 
 
 
 
 
 
 
39
  ```
40
 
41
+ ## Prepare Dataset
42
+
43
+ Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
44
+
45
  ```bash
46
+ # prepare custom dataset up to your need
47
+ # download corresponding dataset first, and fill in the path in scripts
48
+
49
+ # Prepare the Emilia dataset
50
+ python scripts/prepare_emilia.py
51
 
52
+ # Prepare the Wenetspeech4TTS dataset
53
+ python scripts/prepare_wenetspeech4tts.py
54
  ```
55
 
56
+ ## Training & Finetuning
57
+
58
+ Once your datasets are prepared, you can start the training process.
59
+
60
+ ```bash
61
+ # setup accelerate config, e.g. use multi-gpu ddp, fp16
62
+ # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
63
+ accelerate config
64
+ accelerate launch train.py
65
+ ```
66
+ An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
67
+
68
+ Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
69
 
70
  ## Inference
71
 
72
+ The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
73
 
74
+ Currently support 30s for a single generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
75
+ - To avoid possible inference failures, make sure you have seen through the following instructions.
76
+ - A longer prompt audio allows shorter generated output. The part longer than 30s cannot be generated properly. Consider using a prompt audio <15s.
77
+ - Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
78
+ - Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses. If first few words skipped in code-switched generation (cuz different speed with different languages), this might help.
79
+
80
+ ### CLI Inference
81
 
82
+ Either you can specify everything in `inference-cli.toml` or override with flags. Leave `--ref_text ""` will have ASR model transcribe the reference audio automatically (use extra GPU memory). If encounter network error, consider use local ckpt, just set `ckpt_path` in `inference-cli.py`
 
 
 
83
 
84
  ```bash
85
+ python inference-cli.py \
86
+ --model "F5-TTS" \
87
+ --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
88
+ --ref_text "Some call me nature, others call me mother nature." \
89
+ --gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
90
+
91
+ python inference-cli.py \
92
+ --model "E2-TTS" \
93
+ --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" \
94
+ --ref_text "对,这就是我,万人敬仰的太乙真人。" \
95
+ --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺���了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
96
+
97
+ # Multi voice
98
+ python inference-cli.py -c samples/story.toml
99
+ ```
100
 
101
+ ### Gradio App
102
+ Currently supported features:
103
+ - Chunk inference
104
+ - Podcast Generation
105
+ - Multiple Speech-Type Generation
106
 
107
+ You can launch a Gradio app (web interface) to launch a GUI for inference (will load ckpt from Huggingface, you may set `ckpt_path` to local file in `gradio_app.py`). Currently load ASR model, F5-TTS and E2 TTS all in once, thus use more GPU memory than `inference-cli`.
108
+
109
+ ```bash
110
+ python gradio_app.py
111
  ```
112
 
113
+ You can specify the port/host:
114
 
115
  ```bash
116
+ python gradio_app.py --port 7860 --host 0.0.0.0
117
+ ```
 
 
 
 
 
118
 
119
+ Or launch a share link:
 
 
 
120
 
121
+ ```bash
122
+ python gradio_app.py --share
123
  ```
124
 
125
+ ### Speech Editing
126
+
127
+ To test speech editing capabilities, use the following command.
128
 
129
+ ```bash
130
+ python speech_edit.py
131
+ ```
132
+
133
+ ## Evaluation
134
 
135
+ ### Prepare Test Datasets
136
 
137
+ 1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
138
+ 2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
139
+ 3. Unzip the downloaded datasets and place them in the data/ directory.
140
+ 4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`
141
+ 5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo
142
 
143
+ ### Batch Inference for Test Set
144
 
145
+ To run batch inference for evaluations, execute the following commands:
146
 
147
  ```bash
148
+ # batch inference for evaluations
149
+ accelerate config # if not set before
150
+ bash scripts/eval_infer_batch.sh
151
  ```
152
 
153
+ ### Download Evaluation Model Checkpoints
154
 
155
+ 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
156
+ 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
157
+ 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
158
 
159
+ ### Objective Evaluation
160
 
161
+ Install packages for evaluation:
162
 
163
+ ```bash
164
+ pip install -r requirements_eval.txt
165
+ ```
166
+
167
+ **Some Notes**
168
+
169
+ For faster-whisper with CUDA 11:
170
 
171
  ```bash
172
+ pip install --force-reinstall ctranslate2==3.24.0
 
173
  ```
174
 
175
+ (Recommended) To avoid possible ASR failures, such as abnormal repetitions in output:
176
 
177
  ```bash
178
+ pip install faster-whisper==0.10.1
179
  ```
180
 
181
+ Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
182
+ ```bash
183
+ # Evaluation for Seed-TTS test set
184
+ python scripts/eval_seedtts_testset.py
185
 
186
+ # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
187
+ python scripts/eval_librispeech_test_clean.py
188
+ ```
189
 
190
  ## Acknowledgements
191
 
 
197
  - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
198
  - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
199
  - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
200
+ - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation of F5-TTS, with the MLX framework.
 
201
 
202
  ## Citation
203
  If our work and codebase is useful for you, please cite as:
api.py DELETED
@@ -1,132 +0,0 @@
1
- import soundfile as sf
2
- import torch
3
- import tqdm
4
- from cached_path import cached_path
5
-
6
- from model import DiT, UNetT
7
- from model.utils import save_spectrogram
8
-
9
- from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
10
- from model.utils import seed_everything
11
- import random
12
- import sys
13
-
14
-
15
- class F5TTS:
16
- def __init__(
17
- self,
18
- model_type="F5-TTS",
19
- ckpt_file="",
20
- vocab_file="",
21
- ode_method="euler",
22
- use_ema=True,
23
- local_path=None,
24
- device=None,
25
- ):
26
- # Initialize parameters
27
- self.final_wave = None
28
- self.target_sample_rate = 24000
29
- self.n_mel_channels = 100
30
- self.hop_length = 256
31
- self.target_rms = 0.1
32
- self.seed = -1
33
-
34
- # Set device
35
- self.device = device or (
36
- "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
37
- )
38
-
39
- # Load models
40
- self.load_vocoder_model(local_path)
41
- self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
42
-
43
- def load_vocoder_model(self, local_path):
44
- self.vocos = load_vocoder(local_path is not None, local_path, self.device)
45
-
46
- def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
47
- if model_type == "F5-TTS":
48
- if not ckpt_file:
49
- ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
50
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
51
- model_cls = DiT
52
- elif model_type == "E2-TTS":
53
- if not ckpt_file:
54
- ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
55
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
56
- model_cls = UNetT
57
- else:
58
- raise ValueError(f"Unknown model type: {model_type}")
59
-
60
- self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
61
-
62
- def export_wav(self, wav, file_wave, remove_silence=False):
63
- sf.write(file_wave, wav, self.target_sample_rate)
64
-
65
- if remove_silence:
66
- remove_silence_for_generated_wav(file_wave)
67
-
68
- def export_spectrogram(self, spect, file_spect):
69
- save_spectrogram(spect, file_spect)
70
-
71
- def infer(
72
- self,
73
- ref_file,
74
- ref_text,
75
- gen_text,
76
- show_info=print,
77
- progress=tqdm,
78
- target_rms=0.1,
79
- cross_fade_duration=0.15,
80
- sway_sampling_coef=-1,
81
- cfg_strength=2,
82
- nfe_step=32,
83
- speed=1.0,
84
- fix_duration=None,
85
- remove_silence=False,
86
- file_wave=None,
87
- file_spect=None,
88
- seed=-1,
89
- ):
90
- if seed == -1:
91
- seed = random.randint(0, sys.maxsize)
92
- seed_everything(seed)
93
- self.seed = seed
94
- wav, sr, spect = infer_process(
95
- ref_file,
96
- ref_text,
97
- gen_text,
98
- self.ema_model,
99
- show_info=show_info,
100
- progress=progress,
101
- target_rms=target_rms,
102
- cross_fade_duration=cross_fade_duration,
103
- nfe_step=nfe_step,
104
- cfg_strength=cfg_strength,
105
- sway_sampling_coef=sway_sampling_coef,
106
- speed=speed,
107
- fix_duration=fix_duration,
108
- device=self.device,
109
- )
110
-
111
- if file_wave is not None:
112
- self.export_wav(wav, file_wave, remove_silence)
113
-
114
- if file_spect is not None:
115
- self.export_spectrogram(spect, file_spect)
116
-
117
- return wav, sr, spect
118
-
119
-
120
- if __name__ == "__main__":
121
- f5tts = F5TTS()
122
-
123
- wav, sr, spect = f5tts.infer(
124
- ref_file="tests/ref_audio/test_en_1_ref_short.wav",
125
- ref_text="some call me nature, others call me mother nature.",
126
- gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
127
- file_wave="tests/out.wav",
128
- file_spect="tests/out.png",
129
- seed=-1, # random seed = -1
130
- )
131
-
132
- print("seed :", f5tts.seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,169 +1,403 @@
1
- # ruff: noqa: E402
2
- # Above allows ruff to ignore E402: module level import not at top of file
3
-
4
  import re
5
- import tempfile
6
- from collections import OrderedDict
7
- from importlib.resources import files
8
-
9
- import click
10
  import gradio as gr
11
  import numpy as np
12
- import soundfile as sf
13
- import torchaudio
 
 
 
14
  from cached_path import cached_path
15
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
16
 
17
  try:
18
  import spaces
19
-
20
  USING_SPACES = True
21
  except ImportError:
22
  USING_SPACES = False
23
 
24
-
25
  def gpu_decorator(func):
26
  if USING_SPACES:
27
  return spaces.GPU(func)
28
  else:
29
  return func
30
 
31
-
32
- from f5_tts.model import DiT, UNetT
33
- from f5_tts.infer.utils_infer import (
34
- load_vocoder,
35
- load_model,
36
- preprocess_ref_audio_text,
37
- infer_process,
38
- remove_silence_for_generated_wav,
39
- save_spectrogram,
40
  )
41
 
 
42
 
43
- DEFAULT_TTS_MODEL = "F5-TTS"
44
- tts_model_choice = DEFAULT_TTS_MODEL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  # load models
 
 
 
 
48
 
49
- vocoder = load_vocoder()
50
-
 
 
 
 
51
 
52
- def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
53
- F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
54
- return load_model(DiT, F5TTS_model_cfg, ckpt_path)
55
 
 
 
 
56
 
57
- def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
58
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
59
- return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
61
 
62
- def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
63
- ckpt_path, vocab_path = ckpt_path.strip(), vocab_path.strip()
64
- if ckpt_path.startswith("hf://"):
65
- ckpt_path = str(cached_path(ckpt_path))
66
- if vocab_path.startswith("hf://"):
67
- vocab_path = str(cached_path(vocab_path))
68
- if model_cfg is None:
69
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
70
- return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
71
 
 
 
 
 
 
 
72
 
73
- F5TTS_ema_model = load_f5tts()
74
- E2TTS_ema_model = load_e2tts() if USING_SPACES else None
75
- custom_ema_model, pre_custom_path = None, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- chat_model_state = None
78
- chat_tokenizer_state = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
 
 
 
80
 
81
- @gpu_decorator
82
- def generate_response(messages, model, tokenizer):
83
- """Generate response using Qwen"""
84
- text = tokenizer.apply_chat_template(
85
- messages,
86
- tokenize=False,
87
- add_generation_prompt=True,
88
- )
89
 
90
- model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
91
- generated_ids = model.generate(
92
- **model_inputs,
93
- max_new_tokens=512,
94
- temperature=0.7,
95
- top_p=0.95,
96
- )
97
 
98
- generated_ids = [
99
- output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
100
- ]
101
- return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
102
 
 
 
103
 
104
- @gpu_decorator
105
- def infer(
106
- ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1, show_info=gr.Info
107
- ):
108
- ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
 
109
 
110
- if model == "F5-TTS":
111
- ema_model = F5TTS_ema_model
112
- elif model == "E2-TTS":
113
- global E2TTS_ema_model
114
- if E2TTS_ema_model is None:
115
- show_info("Loading E2-TTS model...")
116
- E2TTS_ema_model = load_e2tts()
117
- ema_model = E2TTS_ema_model
118
- elif isinstance(model, list) and model[0] == "Custom":
119
- assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
120
- global custom_ema_model, pre_custom_path
121
- if pre_custom_path != model[1]:
122
- show_info("Loading Custom TTS model...")
123
- custom_ema_model = load_custom(model[1], vocab_path=model[2])
124
- pre_custom_path = model[1]
125
- ema_model = custom_ema_model
126
-
127
- final_wave, final_sample_rate, combined_spectrogram = infer_process(
128
- ref_audio,
129
- ref_text,
130
- gen_text,
131
- ema_model,
132
- vocoder,
133
- cross_fade_duration=cross_fade_duration,
134
- speed=speed,
135
- show_info=show_info,
136
- progress=gr.Progress(),
137
- )
138
 
139
  # Remove silence
140
  if remove_silence:
141
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
142
- sf.write(f.name, final_wave, final_sample_rate)
143
- remove_silence_for_generated_wav(f.name)
 
 
 
 
 
 
144
  final_wave, _ = torchaudio.load(f.name)
145
  final_wave = final_wave.squeeze().cpu().numpy()
146
 
147
- # Save the spectrogram
 
 
148
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
149
  spectrogram_path = tmp_spectrogram.name
150
  save_spectrogram(combined_spectrogram, spectrogram_path)
151
 
152
- return (final_sample_rate, final_wave), spectrogram_path, ref_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  with gr.Blocks() as app_credits:
156
  gr.Markdown("""
157
  # Credits
158
 
159
  * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
160
- * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
161
- * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
162
  """)
163
  with gr.Blocks() as app_tts:
164
  gr.Markdown("# Batched TTS")
165
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
166
  gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
 
 
 
167
  generate_btn = gr.Button("Synthesize", variant="primary")
168
  with gr.Accordion("Advanced Settings", open=False):
169
  ref_text_input = gr.Textbox(
@@ -180,7 +414,7 @@ with gr.Blocks() as app_tts:
180
  label="Speed",
181
  minimum=0.3,
182
  maximum=2.0,
183
- value=1.0,
184
  step=0.1,
185
  info="Adjust the speed of the audio.",
186
  )
@@ -192,302 +426,301 @@ with gr.Blocks() as app_tts:
192
  step=0.01,
193
  info="Set the duration of the cross-fade between audio clips.",
194
  )
 
195
 
196
  audio_output = gr.Audio(label="Synthesized Audio")
197
  spectrogram_output = gr.Image(label="Spectrogram")
198
 
199
- @gpu_decorator
200
- def basic_tts(
201
- ref_audio_input,
202
- ref_text_input,
203
- gen_text_input,
204
- remove_silence,
205
- cross_fade_duration_slider,
206
- speed_slider,
207
- ):
208
- audio_out, spectrogram_path, ref_text_out = infer(
209
- ref_audio_input,
210
- ref_text_input,
211
- gen_text_input,
212
- tts_model_choice,
213
- remove_silence,
214
- cross_fade_duration_slider,
215
- speed_slider,
216
- )
217
- return audio_out, spectrogram_path, gr.update(value=ref_text_out)
218
-
219
  generate_btn.click(
220
- basic_tts,
221
  inputs=[
222
  ref_audio_input,
223
  ref_text_input,
224
  gen_text_input,
 
225
  remove_silence,
226
  cross_fade_duration_slider,
227
- speed_slider,
228
  ],
229
- outputs=[audio_output, spectrogram_output, ref_text_input],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  )
 
 
231
 
 
 
232
 
233
- def parse_speechtypes_text(gen_text):
234
- # Pattern to find {speechtype}
235
- pattern = r"\{(.*?)\}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
  # Split the text by the pattern
238
  tokens = re.split(pattern, gen_text)
239
 
240
  segments = []
241
 
242
- current_style = "Regular"
243
 
244
  for i in range(len(tokens)):
245
  if i % 2 == 0:
246
  # This is text
247
  text = tokens[i].strip()
248
  if text:
249
- segments.append({"style": current_style, "text": text})
250
  else:
251
- # This is style
252
- style = tokens[i].strip()
253
- current_style = style
254
 
255
  return segments
256
 
257
-
258
- with gr.Blocks() as app_multistyle:
259
- # New section for multistyle generation
260
  gr.Markdown(
261
  """
262
  # Multiple Speech-Type Generation
263
 
264
- This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, and the system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
265
- """
266
- )
267
 
268
- with gr.Row():
269
- gr.Markdown(
270
- """
271
- **Example Input:**
272
- {Regular} Hello, I'd like to order a sandwich please.
273
- {Surprised} What do you mean you're out of bread?
274
- {Sad} I really wanted a sandwich though...
275
- {Angry} You know what, darn you and your little shop!
276
- {Whisper} I'll just go back home and cry now.
277
- {Shouting} Why me?!
278
- """
279
- )
280
 
281
- gr.Markdown(
282
- """
283
- **Example Input 2:**
284
- {Speaker1_Happy} Hello, I'd like to order a sandwich please.
285
- {Speaker2_Regular} Sorry, we're out of bread.
286
- {Speaker1_Sad} I really wanted a sandwich though...
287
- {Speaker2_Whisper} I'll give you the last one I was hiding.
288
- """
289
- )
290
-
291
- gr.Markdown(
292
- "Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
293
  )
294
 
 
 
295
  # Regular speech type (mandatory)
296
  with gr.Row():
297
- with gr.Column():
298
- regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
299
- regular_insert = gr.Button("Insert Label", variant="secondary")
300
- regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
301
- regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
302
 
303
- # Regular speech type (max 100)
304
  max_speech_types = 100
305
- speech_type_rows = [] # 99
306
- speech_type_names = [regular_name] # 100
307
- speech_type_audios = [regular_audio] # 100
308
- speech_type_ref_texts = [regular_ref_text] # 100
309
- speech_type_delete_btns = [] # 99
310
- speech_type_insert_btns = [regular_insert] # 100
311
-
312
- # Additional speech types (99 more)
313
  for i in range(max_speech_types - 1):
314
- with gr.Row(visible=False) as row:
315
- with gr.Column():
316
- name_input = gr.Textbox(label="Speech Type Name")
317
- delete_btn = gr.Button("Delete Type", variant="secondary")
318
- insert_btn = gr.Button("Insert Label", variant="secondary")
319
- audio_input = gr.Audio(label="Reference Audio", type="filepath")
320
- ref_text_input = gr.Textbox(label="Reference Text", lines=2)
321
- speech_type_rows.append(row)
322
  speech_type_names.append(name_input)
323
  speech_type_audios.append(audio_input)
324
  speech_type_ref_texts.append(ref_text_input)
325
  speech_type_delete_btns.append(delete_btn)
326
- speech_type_insert_btns.append(insert_btn)
327
 
328
  # Button to add speech type
329
  add_speech_type_btn = gr.Button("Add Speech Type")
330
 
331
  # Keep track of current number of speech types
332
- speech_type_count = gr.State(value=1)
333
 
334
  # Function to add a speech type
335
  def add_speech_type_fn(speech_type_count):
336
- if speech_type_count < max_speech_types:
337
  speech_type_count += 1
338
- # Prepare updates for the rows
339
- row_updates = []
340
- for i in range(1, max_speech_types):
 
 
 
341
  if i < speech_type_count:
342
- row_updates.append(gr.update(visible=True))
 
 
 
343
  else:
344
- row_updates.append(gr.update())
 
 
 
345
  else:
346
  # Optionally, show a warning
347
- row_updates = [gr.update() for _ in range(1, max_speech_types)]
348
- return [speech_type_count] + row_updates
 
 
 
 
349
 
350
  add_speech_type_btn.click(
351
- add_speech_type_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows
 
 
352
  )
353
 
354
  # Function to delete a speech type
355
  def make_delete_speech_type_fn(index):
356
  def delete_speech_type_fn(speech_type_count):
357
  # Prepare updates
358
- row_updates = []
 
 
 
359
 
360
- for i in range(1, max_speech_types):
361
  if i == index:
362
- row_updates.append(gr.update(visible=False))
 
 
 
363
  else:
364
- row_updates.append(gr.update())
 
 
 
365
 
366
- speech_type_count = max(1, speech_type_count)
367
 
368
- return [speech_type_count] + row_updates
369
 
370
  return delete_speech_type_fn
371
 
372
- # Update delete button clicks
373
  for i, delete_btn in enumerate(speech_type_delete_btns):
374
  delete_fn = make_delete_speech_type_fn(i)
375
- delete_btn.click(delete_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows)
 
 
 
 
376
 
377
  # Text input for the prompt
378
- gen_text_input_multistyle = gr.Textbox(
379
- label="Text to Generate",
380
- lines=10,
381
- placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
382
- )
383
 
384
- def make_insert_speech_type_fn(index):
385
- def insert_speech_type_fn(current_text, speech_type_name):
386
- current_text = current_text or ""
387
- speech_type_name = speech_type_name or "None"
388
- updated_text = current_text + f"{{{speech_type_name}}} "
389
- return gr.update(value=updated_text)
390
-
391
- return insert_speech_type_fn
392
-
393
- for i, insert_btn in enumerate(speech_type_insert_btns):
394
- insert_fn = make_insert_speech_type_fn(i)
395
- insert_btn.click(
396
- insert_fn,
397
- inputs=[gen_text_input_multistyle, speech_type_names[i]],
398
- outputs=gen_text_input_multistyle,
399
- )
400
 
401
  with gr.Accordion("Advanced Settings", open=False):
402
- remove_silence_multistyle = gr.Checkbox(
403
  label="Remove Silences",
404
  value=True,
405
  )
406
 
407
  # Generate button
408
- generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
409
 
410
  # Output audio
411
- audio_output_multistyle = gr.Audio(label="Synthesized Audio")
412
-
413
  @gpu_decorator
414
- def generate_multistyle_speech(
 
 
415
  gen_text,
416
  *args,
417
  ):
418
- speech_type_names_list = args[:max_speech_types]
419
- speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
420
- speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]
421
- remove_silence = args[3 * max_speech_types]
 
 
 
422
  # Collect the speech types and their audios into a dict
423
- speech_types = OrderedDict()
424
 
425
- ref_text_idx = 0
426
- for name_input, audio_input, ref_text_input in zip(
427
- speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
428
- ):
429
  if name_input and audio_input:
430
- speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
431
- else:
432
- speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
433
- ref_text_idx += 1
434
 
435
  # Parse the gen_text into segments
436
  segments = parse_speechtypes_text(gen_text)
437
 
438
  # For each segment, generate speech
439
  generated_audio_segments = []
440
- current_style = "Regular"
441
 
442
  for segment in segments:
443
- style = segment["style"]
444
- text = segment["text"]
445
 
446
- if style in speech_types:
447
- current_style = style
448
  else:
449
- # If style not available, default to Regular
450
- current_style = "Regular"
451
 
452
- ref_audio = speech_types[current_style]["audio"]
453
- ref_text = speech_types[current_style].get("ref_text", "")
454
 
455
  # Generate speech for this segment
456
- audio_out, _, ref_text_out = infer(
457
- ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
458
- ) # show_info=print no pull to top when generating
459
- sr, audio_data = audio_out
460
 
461
  generated_audio_segments.append(audio_data)
462
- speech_types[current_style]["ref_text"] = ref_text_out
463
 
464
  # Concatenate all audio segments
465
  if generated_audio_segments:
466
  final_audio_data = np.concatenate(generated_audio_segments)
467
- return [(sr, final_audio_data)] + [
468
- gr.update(value=speech_types[style]["ref_text"]) for style in speech_types
469
- ]
470
  else:
471
  gr.Warning("No audio generated.")
472
- return [None] + [gr.update(value=speech_types[style]["ref_text"]) for style in speech_types]
473
 
474
- generate_multistyle_btn.click(
475
- generate_multistyle_speech,
476
  inputs=[
477
- gen_text_input_multistyle,
478
- ]
479
- + speech_type_names
480
- + speech_type_audios
481
- + speech_type_ref_texts
482
- + [
483
- remove_silence_multistyle,
484
  ],
485
- outputs=[audio_output_multistyle] + speech_type_ref_texts,
486
  )
487
 
488
  # Validation function to disable Generate button if speech types are missing
489
- def validate_speech_types(gen_text, regular_name, *args):
490
- speech_type_names_list = args[:max_speech_types]
 
 
 
 
 
491
 
492
  # Collect the speech types names
493
  speech_types_available = set()
@@ -498,8 +731,8 @@ with gr.Blocks() as app_multistyle:
498
  speech_types_available.add(name_input)
499
 
500
  # Parse the gen_text to get the speech types used
501
- segments = parse_speechtypes_text(gen_text)
502
- speech_types_in_text = set(segment["style"] for segment in segments)
503
 
504
  # Check if all speech types in text are available
505
  missing_speech_types = speech_types_in_text - speech_types_available
@@ -511,221 +744,11 @@ with gr.Blocks() as app_multistyle:
511
  # Enable the generate button
512
  return gr.update(interactive=True)
513
 
514
- gen_text_input_multistyle.change(
515
  validate_speech_types,
516
- inputs=[gen_text_input_multistyle, regular_name] + speech_type_names,
517
- outputs=generate_multistyle_btn,
518
  )
519
-
520
-
521
- with gr.Blocks() as app_chat:
522
- gr.Markdown(
523
- """
524
- # Voice Chat
525
- Have a conversation with an AI using your reference voice!
526
- 1. Upload a reference audio clip and optionally its transcript.
527
- 2. Load the chat model.
528
- 3. Record your message through your microphone.
529
- 4. The AI will respond using the reference voice.
530
- """
531
- )
532
-
533
- if not USING_SPACES:
534
- load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
535
-
536
- chat_interface_container = gr.Column(visible=False)
537
-
538
- @gpu_decorator
539
- def load_chat_model():
540
- global chat_model_state, chat_tokenizer_state
541
- if chat_model_state is None:
542
- show_info = gr.Info
543
- show_info("Loading chat model...")
544
- model_name = "Qwen/Qwen2.5-3B-Instruct"
545
- chat_model_state = AutoModelForCausalLM.from_pretrained(
546
- model_name, torch_dtype="auto", device_map="auto"
547
- )
548
- chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
549
- show_info("Chat model loaded.")
550
-
551
- return gr.update(visible=False), gr.update(visible=True)
552
-
553
- load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
554
-
555
- else:
556
- chat_interface_container = gr.Column()
557
-
558
- if chat_model_state is None:
559
- model_name = "Qwen/Qwen2.5-3B-Instruct"
560
- chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
561
- chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
562
-
563
- with chat_interface_container:
564
- with gr.Row():
565
- with gr.Column():
566
- ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
567
- with gr.Column():
568
- with gr.Accordion("Advanced Settings", open=False):
569
- remove_silence_chat = gr.Checkbox(
570
- label="Remove Silences",
571
- value=True,
572
- )
573
- ref_text_chat = gr.Textbox(
574
- label="Reference Text",
575
- info="Optional: Leave blank to auto-transcribe",
576
- lines=2,
577
- )
578
- system_prompt_chat = gr.Textbox(
579
- label="System Prompt",
580
- value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
581
- lines=2,
582
- )
583
-
584
- chatbot_interface = gr.Chatbot(label="Conversation")
585
-
586
- with gr.Row():
587
- with gr.Column():
588
- audio_input_chat = gr.Microphone(
589
- label="Speak your message",
590
- type="filepath",
591
- )
592
- audio_output_chat = gr.Audio(autoplay=True)
593
- with gr.Column():
594
- text_input_chat = gr.Textbox(
595
- label="Type your message",
596
- lines=1,
597
- )
598
- send_btn_chat = gr.Button("Send Message")
599
- clear_btn_chat = gr.Button("Clear Conversation")
600
-
601
- conversation_state = gr.State(
602
- value=[
603
- {
604
- "role": "system",
605
- "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
606
- }
607
- ]
608
- )
609
-
610
- # Modify process_audio_input to use model and tokenizer from state
611
- @gpu_decorator
612
- def process_audio_input(audio_path, text, history, conv_state):
613
- """Handle audio or text input from user"""
614
-
615
- if not audio_path and not text.strip():
616
- return history, conv_state, ""
617
-
618
- if audio_path:
619
- text = preprocess_ref_audio_text(audio_path, text)[1]
620
-
621
- if not text.strip():
622
- return history, conv_state, ""
623
-
624
- conv_state.append({"role": "user", "content": text})
625
- history.append((text, None))
626
-
627
- response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
628
-
629
- conv_state.append({"role": "assistant", "content": response})
630
- history[-1] = (text, response)
631
-
632
- return history, conv_state, ""
633
-
634
- @gpu_decorator
635
- def generate_audio_response(history, ref_audio, ref_text, remove_silence):
636
- """Generate TTS audio for AI response"""
637
- if not history or not ref_audio:
638
- return None
639
-
640
- last_user_message, last_ai_response = history[-1]
641
- if not last_ai_response:
642
- return None
643
-
644
- audio_result, _, ref_text_out = infer(
645
- ref_audio,
646
- ref_text,
647
- last_ai_response,
648
- tts_model_choice,
649
- remove_silence,
650
- cross_fade_duration=0.15,
651
- speed=1.0,
652
- show_info=print, # show_info=print no pull to top when generating
653
- )
654
- return audio_result, gr.update(value=ref_text_out)
655
-
656
- def clear_conversation():
657
- """Reset the conversation"""
658
- return [], [
659
- {
660
- "role": "system",
661
- "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
662
- }
663
- ]
664
-
665
- def update_system_prompt(new_prompt):
666
- """Update the system prompt and reset the conversation"""
667
- new_conv_state = [{"role": "system", "content": new_prompt}]
668
- return [], new_conv_state
669
-
670
- # Handle audio input
671
- audio_input_chat.stop_recording(
672
- process_audio_input,
673
- inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
674
- outputs=[chatbot_interface, conversation_state],
675
- ).then(
676
- generate_audio_response,
677
- inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
678
- outputs=[audio_output_chat, ref_text_chat],
679
- ).then(
680
- lambda: None,
681
- None,
682
- audio_input_chat,
683
- )
684
-
685
- # Handle text input
686
- text_input_chat.submit(
687
- process_audio_input,
688
- inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
689
- outputs=[chatbot_interface, conversation_state],
690
- ).then(
691
- generate_audio_response,
692
- inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
693
- outputs=[audio_output_chat, ref_text_chat],
694
- ).then(
695
- lambda: None,
696
- None,
697
- text_input_chat,
698
- )
699
-
700
- # Handle send button
701
- send_btn_chat.click(
702
- process_audio_input,
703
- inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
704
- outputs=[chatbot_interface, conversation_state],
705
- ).then(
706
- generate_audio_response,
707
- inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
708
- outputs=[audio_output_chat, ref_text_chat],
709
- ).then(
710
- lambda: None,
711
- None,
712
- text_input_chat,
713
- )
714
-
715
- # Handle clear button
716
- clear_btn_chat.click(
717
- clear_conversation,
718
- outputs=[chatbot_interface, conversation_state],
719
- )
720
-
721
- # Handle system prompt change and reset conversation
722
- system_prompt_chat.change(
723
- update_system_prompt,
724
- inputs=system_prompt_chat,
725
- outputs=[chatbot_interface, conversation_state],
726
- )
727
-
728
-
729
  with gr.Blocks() as app:
730
  gr.Markdown(
731
  """
@@ -736,89 +759,14 @@ This is a local web UI for F5 TTS with advanced batch processing support. This a
736
  * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
737
  * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
738
 
739
- The checkpoints currently support English and Chinese.
740
 
741
- If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
742
 
743
  **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
744
  """
745
  )
746
-
747
- last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom.txt")
748
-
749
- def load_last_used_custom():
750
- try:
751
- with open(last_used_custom, "r") as f:
752
- return f.read().split(",")
753
- except FileNotFoundError:
754
- last_used_custom.parent.mkdir(parents=True, exist_ok=True)
755
- return [
756
- "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
757
- "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
758
- ]
759
-
760
- def switch_tts_model(new_choice):
761
- global tts_model_choice
762
- if new_choice == "Custom": # override in case webpage is refreshed
763
- custom_ckpt_path, custom_vocab_path = load_last_used_custom()
764
- tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
765
- return gr.update(visible=True, value=custom_ckpt_path), gr.update(visible=True, value=custom_vocab_path)
766
- else:
767
- tts_model_choice = new_choice
768
- return gr.update(visible=False), gr.update(visible=False)
769
-
770
- def set_custom_model(custom_ckpt_path, custom_vocab_path):
771
- global tts_model_choice
772
- tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
773
- with open(last_used_custom, "w") as f:
774
- f.write(f"{custom_ckpt_path},{custom_vocab_path}")
775
-
776
- with gr.Row():
777
- if not USING_SPACES:
778
- choose_tts_model = gr.Radio(
779
- choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
780
- )
781
- else:
782
- choose_tts_model = gr.Radio(
783
- choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
784
- )
785
- custom_ckpt_path = gr.Dropdown(
786
- choices=["hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"],
787
- value=load_last_used_custom()[0],
788
- allow_custom_value=True,
789
- label="MODEL CKPT: local_path | hf://user_id/repo_id/model_ckpt",
790
- visible=False,
791
- )
792
- custom_vocab_path = gr.Dropdown(
793
- choices=["hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"],
794
- value=load_last_used_custom()[1],
795
- allow_custom_value=True,
796
- label="VOCAB FILE: local_path | hf://user_id/repo_id/vocab_file",
797
- visible=False,
798
- )
799
-
800
- choose_tts_model.change(
801
- switch_tts_model,
802
- inputs=[choose_tts_model],
803
- outputs=[custom_ckpt_path, custom_vocab_path],
804
- show_progress="hidden",
805
- )
806
- custom_ckpt_path.change(
807
- set_custom_model,
808
- inputs=[custom_ckpt_path, custom_vocab_path],
809
- show_progress="hidden",
810
- )
811
- custom_vocab_path.change(
812
- set_custom_model,
813
- inputs=[custom_ckpt_path, custom_vocab_path],
814
- show_progress="hidden",
815
- )
816
-
817
- gr.TabbedInterface(
818
- [app_tts, app_multistyle, app_chat, app_credits],
819
- ["Basic-TTS", "Multi-Speech", "Voice-Chat", "Credits"],
820
- )
821
-
822
 
823
  @click.command()
824
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@@ -831,21 +779,13 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
831
  help="Share the app via Gradio share link",
832
  )
833
  @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
834
- @click.option(
835
- "--root_path",
836
- "-r",
837
- default=None,
838
- type=str,
839
- help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
840
- )
841
- def main(port, host, share, api, root_path):
842
  global app
843
- print("Starting app...")
844
- app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api, root_path=root_path)
 
 
845
 
846
 
847
- if __name__ == "__main__":
848
- if not USING_SPACES:
849
- main()
850
- else:
851
- app.queue().launch()
 
 
 
 
1
  import re
2
+ import torch
3
+ import torchaudio
 
 
 
4
  import gradio as gr
5
  import numpy as np
6
+ import tempfile
7
+ from einops import rearrange
8
+ from vocos import Vocos
9
+ from pydub import AudioSegment, silence
10
+ from model import CFM, UNetT, DiT, MMDiT
11
  from cached_path import cached_path
12
+ from model.utils import (
13
+ load_checkpoint,
14
+ get_tokenizer,
15
+ convert_char_to_pinyin,
16
+ save_spectrogram,
17
+ )
18
+ from transformers import pipeline
19
+ import click
20
+ import soundfile as sf
21
 
22
  try:
23
  import spaces
 
24
  USING_SPACES = True
25
  except ImportError:
26
  USING_SPACES = False
27
 
 
28
  def gpu_decorator(func):
29
  if USING_SPACES:
30
  return spaces.GPU(func)
31
  else:
32
  return func
33
 
34
+ device = (
35
+ "cuda"
36
+ if torch.cuda.is_available()
37
+ else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
 
 
38
  )
39
 
40
+ print(f"Using {device} device")
41
 
42
+ pipe = pipeline(
43
+ "automatic-speech-recognition",
44
+ model="openai/whisper-large-v3-turbo",
45
+ torch_dtype=torch.float16,
46
+ device=device,
47
+ )
48
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
49
+
50
+ # --------------------- Settings -------------------- #
51
+
52
+ target_sample_rate = 24000
53
+ n_mel_channels = 100
54
+ hop_length = 256
55
+ target_rms = 0.1
56
+ nfe_step = 32 # 16, 32
57
+ cfg_strength = 2.0
58
+ ode_method = "euler"
59
+ sway_sampling_coef = -1.0
60
+ speed = 1.0
61
+ fix_duration = None
62
+
63
+
64
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
65
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
66
+ # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
67
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
68
+ model = CFM(
69
+ transformer=model_cls(
70
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
71
+ ),
72
+ mel_spec_kwargs=dict(
73
+ target_sample_rate=target_sample_rate,
74
+ n_mel_channels=n_mel_channels,
75
+ hop_length=hop_length,
76
+ ),
77
+ odeint_kwargs=dict(
78
+ method=ode_method,
79
+ ),
80
+ vocab_char_map=vocab_char_map,
81
+ ).to(device)
82
+
83
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
84
+
85
+ return model
86
 
87
 
88
  # load models
89
+ F5TTS_model_cfg = dict(
90
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
91
+ )
92
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
93
 
94
+ F5TTS_ema_model = load_model(
95
+ "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
96
+ )
97
+ E2TTS_ema_model = load_model(
98
+ "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
99
+ )
100
 
101
+ def chunk_text(text, max_chars=135):
102
+ """
103
+ Splits the input text into chunks, each with a maximum number of characters.
104
 
105
+ Args:
106
+ text (str): The text to be split.
107
+ max_chars (int): The maximum number of characters per chunk.
108
 
109
+ Returns:
110
+ List[str]: A list of text chunks.
111
+ """
112
+ chunks = []
113
+ current_chunk = ""
114
+ # Split the text into sentences based on punctuation followed by whitespace
115
+ sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
116
+
117
+ for sentence in sentences:
118
+ if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
119
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
120
+ else:
121
+ if current_chunk:
122
+ chunks.append(current_chunk.strip())
123
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
124
 
125
+ if current_chunk:
126
+ chunks.append(current_chunk.strip())
127
 
128
+ return chunks
 
 
 
 
 
 
 
 
129
 
130
+ @gpu_decorator
131
+ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration=0.15, progress=gr.Progress()):
132
+ if exp_name == "F5-TTS":
133
+ ema_model = F5TTS_ema_model
134
+ elif exp_name == "E2-TTS":
135
+ ema_model = E2TTS_ema_model
136
 
137
+ audio, sr = ref_audio
138
+ if audio.shape[0] > 1:
139
+ audio = torch.mean(audio, dim=0, keepdim=True)
140
+
141
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
142
+ if rms < target_rms:
143
+ audio = audio * target_rms / rms
144
+ if sr != target_sample_rate:
145
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
146
+ audio = resampler(audio)
147
+ audio = audio.to(device)
148
+
149
+ generated_waves = []
150
+ spectrograms = []
151
+
152
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
153
+ # Prepare the text
154
+ if len(ref_text[-1].encode('utf-8')) == 1:
155
+ ref_text = ref_text + " "
156
+ text_list = [ref_text + gen_text]
157
+ final_text_list = convert_char_to_pinyin(text_list)
158
+
159
+ # Calculate duration
160
+ ref_audio_len = audio.shape[-1] // hop_length
161
+ zh_pause_punc = r"。,、;:?!"
162
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
163
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
164
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
165
+
166
+ # inference
167
+ with torch.inference_mode():
168
+ generated, _ = ema_model.sample(
169
+ cond=audio,
170
+ text=final_text_list,
171
+ duration=duration,
172
+ steps=nfe_step,
173
+ cfg_strength=cfg_strength,
174
+ sway_sampling_coef=sway_sampling_coef,
175
+ )
176
 
177
+ generated = generated[:, ref_audio_len:, :]
178
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
179
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
180
+ if rms < target_rms:
181
+ generated_wave = generated_wave * rms / target_rms
182
+
183
+ # wav -> numpy
184
+ generated_wave = generated_wave.squeeze().cpu().numpy()
185
+
186
+ generated_waves.append(generated_wave)
187
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
188
+
189
+ # Combine all generated waves with cross-fading
190
+ if cross_fade_duration <= 0:
191
+ # Simply concatenate
192
+ final_wave = np.concatenate(generated_waves)
193
+ else:
194
+ final_wave = generated_waves[0]
195
+ for i in range(1, len(generated_waves)):
196
+ prev_wave = final_wave
197
+ next_wave = generated_waves[i]
198
 
199
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
200
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
201
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
202
 
203
+ if cross_fade_samples <= 0:
204
+ # No overlap possible, concatenate
205
+ final_wave = np.concatenate([prev_wave, next_wave])
206
+ continue
 
 
 
 
207
 
208
+ # Overlapping parts
209
+ prev_overlap = prev_wave[-cross_fade_samples:]
210
+ next_overlap = next_wave[:cross_fade_samples]
 
 
 
 
211
 
212
+ # Fade out and fade in
213
+ fade_out = np.linspace(1, 0, cross_fade_samples)
214
+ fade_in = np.linspace(0, 1, cross_fade_samples)
 
215
 
216
+ # Cross-faded overlap
217
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
218
 
219
+ # Combine
220
+ new_wave = np.concatenate([
221
+ prev_wave[:-cross_fade_samples],
222
+ cross_faded_overlap,
223
+ next_wave[cross_fade_samples:]
224
+ ])
225
 
226
+ final_wave = new_wave
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  # Remove silence
229
  if remove_silence:
230
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
231
+ sf.write(f.name, final_wave, target_sample_rate)
232
+ aseg = AudioSegment.from_file(f.name)
233
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
234
+ non_silent_wave = AudioSegment.silent(duration=0)
235
+ for non_silent_seg in non_silent_segs:
236
+ non_silent_wave += non_silent_seg
237
+ aseg = non_silent_wave
238
+ aseg.export(f.name, format="wav")
239
  final_wave, _ = torchaudio.load(f.name)
240
  final_wave = final_wave.squeeze().cpu().numpy()
241
 
242
+ # Create a combined spectrogram
243
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
244
+
245
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
246
  spectrogram_path = tmp_spectrogram.name
247
  save_spectrogram(combined_spectrogram, spectrogram_path)
248
 
249
+ return (target_sample_rate, final_wave), spectrogram_path
250
+
251
+ @gpu_decorator
252
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, cross_fade_duration=0.15):
253
+
254
+ print(gen_text)
255
+
256
+ gr.Info("Converting audio...")
257
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
258
+ aseg = AudioSegment.from_file(ref_audio_orig)
259
+
260
+ non_silent_segs = silence.split_on_silence(
261
+ aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
262
+ )
263
+ non_silent_wave = AudioSegment.silent(duration=0)
264
+ for non_silent_seg in non_silent_segs:
265
+ non_silent_wave += non_silent_seg
266
+ aseg = non_silent_wave
267
+
268
+ audio_duration = len(aseg)
269
+ if audio_duration > 15000:
270
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
271
+ aseg = aseg[:15000]
272
+ aseg.export(f.name, format="wav")
273
+ ref_audio = f.name
274
+
275
+ if not ref_text.strip():
276
+ gr.Info("No reference text provided, transcribing reference audio...")
277
+ ref_text = pipe(
278
+ ref_audio,
279
+ chunk_length_s=30,
280
+ batch_size=128,
281
+ generate_kwargs={"task": "transcribe"},
282
+ return_timestamps=False,
283
+ )["text"].strip()
284
+ gr.Info("Finished transcription")
285
+ else:
286
+ gr.Info("Using custom reference text...")
287
+
288
+ # Add the functionality to ensure it ends with ". "
289
+ if not ref_text.endswith(". "):
290
+ if ref_text.endswith("."):
291
+ ref_text += " "
292
+ else:
293
+ ref_text += ". "
294
+
295
+ audio, sr = torchaudio.load(ref_audio)
296
+
297
+ # Use the new chunk_text function to split gen_text
298
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
299
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
300
+ print('ref_text', ref_text)
301
+ for i, batch_text in enumerate(gen_text_batches):
302
+ print(f'gen_text {i}', batch_text)
303
+
304
+ gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
305
+ return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration)
306
+
307
+
308
+ @gpu_decorator
309
+ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
310
+ # Split the script into speaker blocks
311
+ speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
312
+ speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
313
+
314
+ generated_audio_segments = []
315
+
316
+ for i in range(0, len(speaker_blocks), 2):
317
+ speaker = speaker_blocks[i]
318
+ text = speaker_blocks[i+1].strip()
319
+
320
+ # Determine which speaker is talking
321
+ if speaker == speaker1_name:
322
+ ref_audio = ref_audio1
323
+ ref_text = ref_text1
324
+ elif speaker == speaker2_name:
325
+ ref_audio = ref_audio2
326
+ ref_text = ref_text2
327
+ else:
328
+ continue # Skip if the speaker is neither speaker1 nor speaker2
329
+
330
+ # Generate audio for this block
331
+ audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
332
+
333
+ # Convert the generated audio to a numpy array
334
+ sr, audio_data = audio
335
+
336
+ # Save the audio data as a WAV file
337
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
338
+ sf.write(temp_file.name, audio_data, sr)
339
+ audio_segment = AudioSegment.from_wav(temp_file.name)
340
+
341
+ generated_audio_segments.append(audio_segment)
342
+
343
+ # Add a short pause between speakers
344
+ pause = AudioSegment.silent(duration=500) # 500ms pause
345
+ generated_audio_segments.append(pause)
346
+
347
+ # Concatenate all audio segments
348
+ final_podcast = sum(generated_audio_segments)
349
+
350
+ # Export the final podcast
351
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
352
+ podcast_path = temp_file.name
353
+ final_podcast.export(podcast_path, format="wav")
354
+
355
+ return podcast_path
356
+
357
+ def parse_speechtypes_text(gen_text):
358
+ # Pattern to find (Emotion)
359
+ pattern = r'\((.*?)\)'
360
+
361
+ # Split the text by the pattern
362
+ tokens = re.split(pattern, gen_text)
363
+
364
+ segments = []
365
 
366
+ current_emotion = 'Regular'
367
+
368
+ for i in range(len(tokens)):
369
+ if i % 2 == 0:
370
+ # This is text
371
+ text = tokens[i].strip()
372
+ if text:
373
+ segments.append({'emotion': current_emotion, 'text': text})
374
+ else:
375
+ # This is emotion
376
+ emotion = tokens[i].strip()
377
+ current_emotion = emotion
378
+
379
+ return segments
380
+
381
+ def update_speed(new_speed):
382
+ global speed
383
+ speed = new_speed
384
+ return f"Speed set to: {speed}"
385
 
386
  with gr.Blocks() as app_credits:
387
  gr.Markdown("""
388
  # Credits
389
 
390
  * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
391
+ * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
392
+ * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation
393
  """)
394
  with gr.Blocks() as app_tts:
395
  gr.Markdown("# Batched TTS")
396
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
397
  gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
398
+ model_choice = gr.Radio(
399
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
400
+ )
401
  generate_btn = gr.Button("Synthesize", variant="primary")
402
  with gr.Accordion("Advanced Settings", open=False):
403
  ref_text_input = gr.Textbox(
 
414
  label="Speed",
415
  minimum=0.3,
416
  maximum=2.0,
417
+ value=speed,
418
  step=0.1,
419
  info="Adjust the speed of the audio.",
420
  )
 
426
  step=0.01,
427
  info="Set the duration of the cross-fade between audio clips.",
428
  )
429
+ speed_slider.change(update_speed, inputs=speed_slider)
430
 
431
  audio_output = gr.Audio(label="Synthesized Audio")
432
  spectrogram_output = gr.Image(label="Spectrogram")
433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  generate_btn.click(
435
+ infer,
436
  inputs=[
437
  ref_audio_input,
438
  ref_text_input,
439
  gen_text_input,
440
+ model_choice,
441
  remove_silence,
442
  cross_fade_duration_slider,
 
443
  ],
444
+ outputs=[audio_output, spectrogram_output],
445
+ )
446
+
447
+ with gr.Blocks() as app_podcast:
448
+ gr.Markdown("# Podcast Generation")
449
+ speaker1_name = gr.Textbox(label="Speaker 1 Name")
450
+ ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
451
+ ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
452
+
453
+ speaker2_name = gr.Textbox(label="Speaker 2 Name")
454
+ ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
455
+ ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
456
+
457
+ script_input = gr.Textbox(label="Podcast Script", lines=10,
458
+ placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
459
+
460
+ podcast_model_choice = gr.Radio(
461
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
462
+ )
463
+ podcast_remove_silence = gr.Checkbox(
464
+ label="Remove Silences",
465
+ value=True,
466
  )
467
+ generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
468
+ podcast_output = gr.Audio(label="Generated Podcast")
469
 
470
+ def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
471
+ return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
472
 
473
+ generate_podcast_btn.click(
474
+ podcast_generation,
475
+ inputs=[
476
+ script_input,
477
+ speaker1_name,
478
+ ref_audio_input1,
479
+ ref_text_input1,
480
+ speaker2_name,
481
+ ref_audio_input2,
482
+ ref_text_input2,
483
+ podcast_model_choice,
484
+ podcast_remove_silence,
485
+ ],
486
+ outputs=podcast_output,
487
+ )
488
+
489
+ def parse_emotional_text(gen_text):
490
+ # Pattern to find (Emotion)
491
+ pattern = r'\((.*?)\)'
492
 
493
  # Split the text by the pattern
494
  tokens = re.split(pattern, gen_text)
495
 
496
  segments = []
497
 
498
+ current_emotion = 'Regular'
499
 
500
  for i in range(len(tokens)):
501
  if i % 2 == 0:
502
  # This is text
503
  text = tokens[i].strip()
504
  if text:
505
+ segments.append({'emotion': current_emotion, 'text': text})
506
  else:
507
+ # This is emotion
508
+ emotion = tokens[i].strip()
509
+ current_emotion = emotion
510
 
511
  return segments
512
 
513
+ with gr.Blocks() as app_emotional:
514
+ # New section for emotional generation
 
515
  gr.Markdown(
516
  """
517
  # Multiple Speech-Type Generation
518
 
519
+ This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
 
 
520
 
521
+ **Example Input:**
 
 
 
 
 
 
 
 
 
 
 
522
 
523
+ (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
524
+ """
 
 
 
 
 
 
 
 
 
 
525
  )
526
 
527
+ gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
528
+
529
  # Regular speech type (mandatory)
530
  with gr.Row():
531
+ regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
532
+ regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
533
+ regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
 
 
534
 
535
+ # Additional speech types (up to 99 more)
536
  max_speech_types = 100
537
+ speech_type_names = []
538
+ speech_type_audios = []
539
+ speech_type_ref_texts = []
540
+ speech_type_delete_btns = []
541
+
 
 
 
542
  for i in range(max_speech_types - 1):
543
+ with gr.Row():
544
+ name_input = gr.Textbox(label='Speech Type Name', visible=False)
545
+ audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
546
+ ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
547
+ delete_btn = gr.Button("Delete", variant="secondary", visible=False)
 
 
 
548
  speech_type_names.append(name_input)
549
  speech_type_audios.append(audio_input)
550
  speech_type_ref_texts.append(ref_text_input)
551
  speech_type_delete_btns.append(delete_btn)
 
552
 
553
  # Button to add speech type
554
  add_speech_type_btn = gr.Button("Add Speech Type")
555
 
556
  # Keep track of current number of speech types
557
+ speech_type_count = gr.State(value=0)
558
 
559
  # Function to add a speech type
560
  def add_speech_type_fn(speech_type_count):
561
+ if speech_type_count < max_speech_types - 1:
562
  speech_type_count += 1
563
+ # Prepare updates for the components
564
+ name_updates = []
565
+ audio_updates = []
566
+ ref_text_updates = []
567
+ delete_btn_updates = []
568
+ for i in range(max_speech_types - 1):
569
  if i < speech_type_count:
570
+ name_updates.append(gr.update(visible=True))
571
+ audio_updates.append(gr.update(visible=True))
572
+ ref_text_updates.append(gr.update(visible=True))
573
+ delete_btn_updates.append(gr.update(visible=True))
574
  else:
575
+ name_updates.append(gr.update())
576
+ audio_updates.append(gr.update())
577
+ ref_text_updates.append(gr.update())
578
+ delete_btn_updates.append(gr.update())
579
  else:
580
  # Optionally, show a warning
581
+ # gr.Warning("Maximum number of speech types reached.")
582
+ name_updates = [gr.update() for _ in range(max_speech_types - 1)]
583
+ audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
584
+ ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
585
+ delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
586
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
587
 
588
  add_speech_type_btn.click(
589
+ add_speech_type_fn,
590
+ inputs=speech_type_count,
591
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
592
  )
593
 
594
  # Function to delete a speech type
595
  def make_delete_speech_type_fn(index):
596
  def delete_speech_type_fn(speech_type_count):
597
  # Prepare updates
598
+ name_updates = []
599
+ audio_updates = []
600
+ ref_text_updates = []
601
+ delete_btn_updates = []
602
 
603
+ for i in range(max_speech_types - 1):
604
  if i == index:
605
+ name_updates.append(gr.update(visible=False, value=''))
606
+ audio_updates.append(gr.update(visible=False, value=None))
607
+ ref_text_updates.append(gr.update(visible=False, value=''))
608
+ delete_btn_updates.append(gr.update(visible=False))
609
  else:
610
+ name_updates.append(gr.update())
611
+ audio_updates.append(gr.update())
612
+ ref_text_updates.append(gr.update())
613
+ delete_btn_updates.append(gr.update())
614
 
615
+ speech_type_count = max(0, speech_type_count - 1)
616
 
617
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
618
 
619
  return delete_speech_type_fn
620
 
 
621
  for i, delete_btn in enumerate(speech_type_delete_btns):
622
  delete_fn = make_delete_speech_type_fn(i)
623
+ delete_btn.click(
624
+ delete_fn,
625
+ inputs=speech_type_count,
626
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
627
+ )
628
 
629
  # Text input for the prompt
630
+ gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
 
 
 
 
631
 
632
+ # Model choice
633
+ model_choice_emotional = gr.Radio(
634
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
635
+ )
 
 
 
 
 
 
 
 
 
 
 
 
636
 
637
  with gr.Accordion("Advanced Settings", open=False):
638
+ remove_silence_emotional = gr.Checkbox(
639
  label="Remove Silences",
640
  value=True,
641
  )
642
 
643
  # Generate button
644
+ generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
645
 
646
  # Output audio
647
+ audio_output_emotional = gr.Audio(label="Synthesized Audio")
 
648
  @gpu_decorator
649
+ def generate_emotional_speech(
650
+ regular_audio,
651
+ regular_ref_text,
652
  gen_text,
653
  *args,
654
  ):
655
+ num_additional_speech_types = max_speech_types - 1
656
+ speech_type_names_list = args[:num_additional_speech_types]
657
+ speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
658
+ speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
659
+ model_choice = args[3 * num_additional_speech_types]
660
+ remove_silence = args[3 * num_additional_speech_types + 1]
661
+
662
  # Collect the speech types and their audios into a dict
663
+ speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
664
 
665
+ for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
 
 
 
666
  if name_input and audio_input:
667
+ speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
 
 
 
668
 
669
  # Parse the gen_text into segments
670
  segments = parse_speechtypes_text(gen_text)
671
 
672
  # For each segment, generate speech
673
  generated_audio_segments = []
674
+ current_emotion = 'Regular'
675
 
676
  for segment in segments:
677
+ emotion = segment['emotion']
678
+ text = segment['text']
679
 
680
+ if emotion in speech_types:
681
+ current_emotion = emotion
682
  else:
683
+ # If emotion not available, default to Regular
684
+ current_emotion = 'Regular'
685
 
686
+ ref_audio = speech_types[current_emotion]['audio']
687
+ ref_text = speech_types[current_emotion].get('ref_text', '')
688
 
689
  # Generate speech for this segment
690
+ audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
691
+ sr, audio_data = audio
 
 
692
 
693
  generated_audio_segments.append(audio_data)
 
694
 
695
  # Concatenate all audio segments
696
  if generated_audio_segments:
697
  final_audio_data = np.concatenate(generated_audio_segments)
698
+ return (sr, final_audio_data)
 
 
699
  else:
700
  gr.Warning("No audio generated.")
701
+ return None
702
 
703
+ generate_emotional_btn.click(
704
+ generate_emotional_speech,
705
  inputs=[
706
+ regular_audio,
707
+ regular_ref_text,
708
+ gen_text_input_emotional,
709
+ ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
710
+ model_choice_emotional,
711
+ remove_silence_emotional,
 
712
  ],
713
+ outputs=audio_output_emotional,
714
  )
715
 
716
  # Validation function to disable Generate button if speech types are missing
717
+ def validate_speech_types(
718
+ gen_text,
719
+ regular_name,
720
+ *args
721
+ ):
722
+ num_additional_speech_types = max_speech_types - 1
723
+ speech_type_names_list = args[:num_additional_speech_types]
724
 
725
  # Collect the speech types names
726
  speech_types_available = set()
 
731
  speech_types_available.add(name_input)
732
 
733
  # Parse the gen_text to get the speech types used
734
+ segments = parse_emotional_text(gen_text)
735
+ speech_types_in_text = set(segment['emotion'] for segment in segments)
736
 
737
  # Check if all speech types in text are available
738
  missing_speech_types = speech_types_in_text - speech_types_available
 
744
  # Enable the generate button
745
  return gr.update(interactive=True)
746
 
747
+ gen_text_input_emotional.change(
748
  validate_speech_types,
749
+ inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
750
+ outputs=generate_emotional_btn
751
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
  with gr.Blocks() as app:
753
  gr.Markdown(
754
  """
 
759
  * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
760
  * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
761
 
762
+ The checkpoints support English and Chinese.
763
 
764
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
765
 
766
  **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
767
  """
768
  )
769
+ gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
770
 
771
  @click.command()
772
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
 
779
  help="Share the app via Gradio share link",
780
  )
781
  @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
782
+ def main(port, host, share, api):
 
 
 
 
 
 
 
783
  global app
784
+ print(f"Starting app...")
785
+ app.queue(api_open=api).launch(
786
+ server_name=host, server_port=port, share=share, show_api=api
787
+ )
788
 
789
 
790
+
791
+ app.queue().launch()
 
 
 
app_local.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ print("WARNING: You are running this unofficial E2/F5 TTS demo locally, it may not be as up-to-date as the hosted version (https://huggingface.co/spaces/mrfakename/E2-F5-TTS)")
2
+
3
+ import os
4
+ import re
5
+ import torch
6
+ import torchaudio
7
+ import gradio as gr
8
+ import numpy as np
9
+ import tempfile
10
+ from einops import rearrange
11
+ from ema_pytorch import EMA
12
+ from vocos import Vocos
13
+ from pydub import AudioSegment, silence
14
+ from model import CFM, UNetT, DiT, MMDiT
15
+ from cached_path import cached_path
16
+ from model.utils import (
17
+ get_tokenizer,
18
+ convert_char_to_pinyin,
19
+ save_spectrogram,
20
+ )
21
+ from transformers import pipeline
22
+ import librosa
23
+ import soundfile as sf
24
+ from txtsplit import txtsplit
25
+
26
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
27
+
28
+ pipe = pipeline(
29
+ "automatic-speech-recognition",
30
+ model="openai/whisper-large-v3-turbo",
31
+ torch_dtype=torch.float16,
32
+ device=device,
33
+ )
34
+
35
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
36
+
37
+ # --------------------- Settings -------------------- #
38
+
39
+ target_sample_rate = 24000
40
+ n_mel_channels = 100
41
+ hop_length = 256
42
+ target_rms = 0.1
43
+ nfe_step = 32 # 16, 32
44
+ cfg_strength = 2.0
45
+ ode_method = 'euler'
46
+ sway_sampling_coef = -1.0
47
+ speed = 1.0
48
+ # fix_duration = 27 # None or float (duration in seconds)
49
+ fix_duration = None
50
+
51
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
52
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
53
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
54
+ model = CFM(
55
+ transformer=model_cls(
56
+ **model_cfg,
57
+ text_num_embeds=vocab_size,
58
+ mel_dim=n_mel_channels
59
+ ),
60
+ mel_spec_kwargs=dict(
61
+ target_sample_rate=target_sample_rate,
62
+ n_mel_channels=n_mel_channels,
63
+ hop_length=hop_length,
64
+ ),
65
+ odeint_kwargs=dict(
66
+ method=ode_method,
67
+ ),
68
+ vocab_char_map=vocab_char_map,
69
+ ).to(device)
70
+
71
+ ema_model = EMA(model, include_online_model=False).to(device)
72
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
73
+ ema_model.copy_params_from_ema_to_model()
74
+
75
+ return model
76
+
77
+ # load models
78
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
79
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
80
+
81
+ F5TTS_ema_model = load_model("F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
82
+ E2TTS_ema_model = load_model("E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
83
+
84
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
85
+ print(gen_text)
86
+ gr.Info("Converting audio...")
87
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
88
+ aseg = AudioSegment.from_file(ref_audio_orig)
89
+ # remove long silence in reference audio
90
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
91
+ non_silent_wave = AudioSegment.silent(duration=0)
92
+ for non_silent_seg in non_silent_segs:
93
+ non_silent_wave += non_silent_seg
94
+ aseg = non_silent_wave
95
+ # Convert to mono
96
+ aseg = aseg.set_channels(1)
97
+ audio_duration = len(aseg)
98
+ if audio_duration > 15000:
99
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
100
+ aseg = aseg[:15000]
101
+ aseg.export(f.name, format="wav")
102
+ ref_audio = f.name
103
+ if exp_name == "F5-TTS":
104
+ ema_model = F5TTS_ema_model
105
+ elif exp_name == "E2-TTS":
106
+ ema_model = E2TTS_ema_model
107
+
108
+ if not ref_text.strip():
109
+ gr.Info("No reference text provided, transcribing reference audio...")
110
+ ref_text = outputs = pipe(
111
+ ref_audio,
112
+ chunk_length_s=30,
113
+ batch_size=128,
114
+ generate_kwargs={"task": "transcribe"},
115
+ return_timestamps=False,
116
+ )['text'].strip()
117
+ gr.Info("Finished transcription")
118
+ else:
119
+ gr.Info("Using custom reference text...")
120
+ audio, sr = torchaudio.load(ref_audio)
121
+ max_chars = int(len(ref_text) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
122
+ # Audio
123
+ if audio.shape[0] > 1:
124
+ audio = torch.mean(audio, dim=0, keepdim=True)
125
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
126
+ if rms < target_rms:
127
+ audio = audio * target_rms / rms
128
+ if sr != target_sample_rate:
129
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
130
+ audio = resampler(audio)
131
+ audio = audio.to(device)
132
+ # Chunk
133
+ chunks = txtsplit(gen_text, 0.7*max_chars, 0.9*max_chars) # 100 chars preferred, 150 max
134
+ results = []
135
+ generated_mel_specs = []
136
+ for chunk in progress.tqdm(chunks):
137
+ # Prepare the text
138
+ text_list = [ref_text + chunk]
139
+ final_text_list = convert_char_to_pinyin(text_list)
140
+
141
+ # Calculate duration
142
+ ref_audio_len = audio.shape[-1] // hop_length
143
+ # if fix_duration is not None:
144
+ # duration = int(fix_duration * target_sample_rate / hop_length)
145
+ # else:
146
+ zh_pause_punc = r"。,、;:?!"
147
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
148
+ chunk = len(chunk.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
149
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * chunk / speed)
150
+
151
+ # inference
152
+ gr.Info(f"Generating audio using {exp_name}")
153
+ with torch.inference_mode():
154
+ generated, _ = ema_model.sample(
155
+ cond=audio,
156
+ text=final_text_list,
157
+ duration=duration,
158
+ steps=nfe_step,
159
+ cfg_strength=cfg_strength,
160
+ sway_sampling_coef=sway_sampling_coef,
161
+ )
162
+
163
+ generated = generated[:, ref_audio_len:, :]
164
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
165
+ gr.Info("Running vocoder")
166
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
167
+ if rms < target_rms:
168
+ generated_wave = generated_wave * rms / target_rms
169
+
170
+ # wav -> numpy
171
+ generated_wave = generated_wave.squeeze().cpu().numpy()
172
+ results.append(generated_wave)
173
+ generated_wave = np.concatenate(results)
174
+ if remove_silence:
175
+ gr.Info("Removing audio silences... This may take a moment")
176
+ # non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
177
+ # non_silent_wave = np.array([])
178
+ # for interval in non_silent_intervals:
179
+ # start, end = interval
180
+ # non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
181
+ # generated_wave = non_silent_wave
182
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
183
+ sf.write(f.name, generated_wave, target_sample_rate)
184
+ aseg = AudioSegment.from_file(f.name)
185
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
186
+ non_silent_wave = AudioSegment.silent(duration=0)
187
+ for non_silent_seg in non_silent_segs:
188
+ non_silent_wave += non_silent_seg
189
+ aseg = non_silent_wave
190
+ aseg.export(f.name, format="wav")
191
+ generated_wave, _ = torchaudio.load(f.name)
192
+ generated_wave = generated_wave.squeeze().cpu().numpy()
193
+
194
+ # spectogram
195
+ # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
196
+ # spectrogram_path = tmp_spectrogram.name
197
+ # save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
198
+
199
+ return (target_sample_rate, generated_wave)
200
+
201
+ with gr.Blocks() as app:
202
+ gr.Markdown("""
203
+ # E2/F5 TTS
204
+
205
+ This is an unofficial E2/F5 TTS demo. This demo supports the following TTS models:
206
+
207
+ * [E2-TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
208
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
209
+
210
+ This demo is based on the [F5-TTS](https://github.com/SWivid/F5-TTS) codebase, which is based on an [unofficial E2-TTS implementation](https://github.com/lucidrains/e2-tts-pytorch).
211
+
212
+ The checkpoints support English and Chinese.
213
+
214
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt. If you're still running into issues, please open a [community Discussion](https://huggingface.co/spaces/mrfakename/E2-F5-TTS/discussions).
215
+
216
+ Long-form/batched inference + speech editing is coming soon!
217
+
218
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
219
+ """)
220
+
221
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
222
+ gen_text_input = gr.Textbox(label="Text to Generate (longer text will use chunking)", lines=4)
223
+ model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
224
+ generate_btn = gr.Button("Synthesize", variant="primary")
225
+ with gr.Accordion("Advanced Settings", open=False):
226
+ ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
227
+ remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
228
+
229
+ audio_output = gr.Audio(label="Synthesized Audio")
230
+ # spectrogram_output = gr.Image(label="Spectrogram")
231
+
232
+ generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
233
+ gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
234
+
235
+
236
+ app.queue().launch()
cog.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prediction interface for Cog ⚙️
2
+ # https://cog.run/python
3
+
4
+ from cog import BasePredictor, Input, Path
5
+
6
+ import os
7
+ import re
8
+ import torch
9
+ import torchaudio
10
+ import numpy as np
11
+ import tempfile
12
+ from einops import rearrange
13
+ from ema_pytorch import EMA
14
+ from vocos import Vocos
15
+ from pydub import AudioSegment
16
+ from model import CFM, UNetT, DiT, MMDiT
17
+ from cached_path import cached_path
18
+ from model.utils import (
19
+ get_tokenizer,
20
+ convert_char_to_pinyin,
21
+ save_spectrogram,
22
+ )
23
+ from transformers import pipeline
24
+ import librosa
25
+
26
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
27
+
28
+ target_sample_rate = 24000
29
+ n_mel_channels = 100
30
+ hop_length = 256
31
+ target_rms = 0.1
32
+ nfe_step = 32 # 16, 32
33
+ cfg_strength = 2.0
34
+ ode_method = 'euler'
35
+ sway_sampling_coef = -1.0
36
+ speed = 1.0
37
+ # fix_duration = 27 # None or float (duration in seconds)
38
+ fix_duration = None
39
+
40
+
41
+ class Predictor(BasePredictor):
42
+ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
43
+ checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
44
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
45
+ model = CFM(
46
+ transformer=model_cls(
47
+ **model_cfg,
48
+ text_num_embeds=vocab_size,
49
+ mel_dim=n_mel_channels
50
+ ),
51
+ mel_spec_kwargs=dict(
52
+ target_sample_rate=target_sample_rate,
53
+ n_mel_channels=n_mel_channels,
54
+ hop_length=hop_length,
55
+ ),
56
+ odeint_kwargs=dict(
57
+ method=ode_method,
58
+ ),
59
+ vocab_char_map=vocab_char_map,
60
+ ).to(device)
61
+
62
+ ema_model = EMA(model, include_online_model=False).to(device)
63
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
64
+ ema_model.copy_params_from_ema_to_model()
65
+
66
+ return ema_model, model
67
+ def setup(self) -> None:
68
+ """Load the model into memory to make running multiple predictions efficient"""
69
+ # self.model = torch.load("./weights.pth")
70
+ print("Loading Whisper model...")
71
+ self.pipe = pipeline(
72
+ "automatic-speech-recognition",
73
+ model="openai/whisper-large-v3-turbo",
74
+ torch_dtype=torch.float16,
75
+ device=device,
76
+ )
77
+ print("Loading F5-TTS model...")
78
+
79
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
+ self.F5TTS_ema_model, self.F5TTS_base_model = self.load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
81
+
82
+
83
+ def predict(
84
+ self,
85
+ gen_text: str = Input(description="Text to generate"),
86
+ ref_audio_orig: Path = Input(description="Reference audio"),
87
+ remove_silence: bool = Input(description="Remove silences", default=True),
88
+ ) -> Path:
89
+ """Run a single prediction on the model"""
90
+ model_choice = "F5-TTS"
91
+ print(gen_text)
92
+ if len(gen_text) > 200:
93
+ raise gr.Error("Please keep your text under 200 chars.")
94
+ gr.Info("Converting audio...")
95
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
96
+ aseg = AudioSegment.from_file(ref_audio_orig)
97
+ audio_duration = len(aseg)
98
+ if audio_duration > 15000:
99
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
100
+ aseg = aseg[:15000]
101
+ aseg.export(f.name, format="wav")
102
+ ref_audio = f.name
103
+ ema_model = self.F5TTS_ema_model
104
+ base_model = self.F5TTS_base_model
105
+
106
+ if not ref_text.strip():
107
+ gr.Info("No reference text provided, transcribing reference audio...")
108
+ ref_text = outputs = self.pipe(
109
+ ref_audio,
110
+ chunk_length_s=30,
111
+ batch_size=128,
112
+ generate_kwargs={"task": "transcribe"},
113
+ return_timestamps=False,
114
+ )['text'].strip()
115
+ gr.Info("Finished transcription")
116
+ else:
117
+ gr.Info("Using custom reference text...")
118
+ audio, sr = torchaudio.load(ref_audio)
119
+
120
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
121
+ if rms < target_rms:
122
+ audio = audio * target_rms / rms
123
+ if sr != target_sample_rate:
124
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
125
+ audio = resampler(audio)
126
+ audio = audio.to(device)
127
+
128
+ # Prepare the text
129
+ text_list = [ref_text + gen_text]
130
+ final_text_list = convert_char_to_pinyin(text_list)
131
+
132
+ # Calculate duration
133
+ ref_audio_len = audio.shape[-1] // hop_length
134
+ # if fix_duration is not None:
135
+ # duration = int(fix_duration * target_sample_rate / hop_length)
136
+ # else:
137
+ zh_pause_punc = r"。,、;:?!"
138
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
139
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
140
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
141
+
142
+ # inference
143
+ gr.Info(f"Generating audio using F5-TTS")
144
+ with torch.inference_mode():
145
+ generated, _ = base_model.sample(
146
+ cond=audio,
147
+ text=final_text_list,
148
+ duration=duration,
149
+ steps=nfe_step,
150
+ cfg_strength=cfg_strength,
151
+ sway_sampling_coef=sway_sampling_coef,
152
+ )
153
+
154
+ generated = generated[:, ref_audio_len:, :]
155
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
156
+ gr.Info("Running vocoder")
157
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
158
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
159
+ if rms < target_rms:
160
+ generated_wave = generated_wave * rms / target_rms
161
+
162
+ # wav -> numpy
163
+ generated_wave = generated_wave.squeeze().cpu().numpy()
164
+
165
+ if remove_silence:
166
+ gr.Info("Removing audio silences... This may take a moment")
167
+ non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
168
+ non_silent_wave = np.array([])
169
+ for interval in non_silent_intervals:
170
+ start, end = interval
171
+ non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
172
+ generated_wave = non_silent_wave
173
+
174
+
175
+ # spectogram
176
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
177
+ wav_path = tmp_wav.name
178
+ torchaudio.save(wav_path, torch.tensor(generated_wave), target_sample_rate)
179
+
180
+ return wav_path
data/.DS_Store DELETED
Binary file (6.15 kB)
 
data/Emilia_ZH_EN_pinyin/vocab.txt CHANGED
@@ -1,2545 +1,2545 @@
1
-
2
- !
3
- "
4
- #
5
- $
6
- %
7
- &
8
- '
9
- (
10
- )
11
- *
12
- +
13
- ,
14
- -
15
- .
16
- /
17
- 0
18
- 1
19
- 2
20
- 3
21
- 4
22
- 5
23
- 6
24
- 7
25
- 8
26
- 9
27
- :
28
- ;
29
- =
30
- >
31
- ?
32
- @
33
- A
34
- B
35
- C
36
- D
37
- E
38
- F
39
- G
40
- H
41
- I
42
- J
43
- K
44
- L
45
- M
46
- N
47
- O
48
- P
49
- Q
50
- R
51
- S
52
- T
53
- U
54
- V
55
- W
56
- X
57
- Y
58
- Z
59
- [
60
- \
61
- ]
62
- _
63
- a
64
- a1
65
- ai1
66
- ai2
67
- ai3
68
- ai4
69
- an1
70
- an3
71
- an4
72
- ang1
73
- ang2
74
- ang4
75
- ao1
76
- ao2
77
- ao3
78
- ao4
79
- b
80
- ba
81
- ba1
82
- ba2
83
- ba3
84
- ba4
85
- bai1
86
- bai2
87
- bai3
88
- bai4
89
- ban1
90
- ban2
91
- ban3
92
- ban4
93
- bang1
94
- bang2
95
- bang3
96
- bang4
97
- bao1
98
- bao2
99
- bao3
100
- bao4
101
- bei
102
- bei1
103
- bei2
104
- bei3
105
- bei4
106
- ben1
107
- ben2
108
- ben3
109
- ben4
110
- beng
111
- beng1
112
- beng2
113
- beng3
114
- beng4
115
- bi1
116
- bi2
117
- bi3
118
- bi4
119
- bian1
120
- bian2
121
- bian3
122
- bian4
123
- biao1
124
- biao2
125
- biao3
126
- bie1
127
- bie2
128
- bie3
129
- bie4
130
- bin1
131
- bin4
132
- bing1
133
- bing2
134
- bing3
135
- bing4
136
- bo
137
- bo1
138
- bo2
139
- bo3
140
- bo4
141
- bu2
142
- bu3
143
- bu4
144
- c
145
- ca1
146
- cai1
147
- cai2
148
- cai3
149
- cai4
150
- can1
151
- can2
152
- can3
153
- can4
154
- cang1
155
- cang2
156
- cao1
157
- cao2
158
- cao3
159
- ce4
160
- cen1
161
- cen2
162
- ceng1
163
- ceng2
164
- ceng4
165
- cha1
166
- cha2
167
- cha3
168
- cha4
169
- chai1
170
- chai2
171
- chan1
172
- chan2
173
- chan3
174
- chan4
175
- chang1
176
- chang2
177
- chang3
178
- chang4
179
- chao1
180
- chao2
181
- chao3
182
- che1
183
- che2
184
- che3
185
- che4
186
- chen1
187
- chen2
188
- chen3
189
- chen4
190
- cheng1
191
- cheng2
192
- cheng3
193
- cheng4
194
- chi1
195
- chi2
196
- chi3
197
- chi4
198
- chong1
199
- chong2
200
- chong3
201
- chong4
202
- chou1
203
- chou2
204
- chou3
205
- chou4
206
- chu1
207
- chu2
208
- chu3
209
- chu4
210
- chua1
211
- chuai1
212
- chuai2
213
- chuai3
214
- chuai4
215
- chuan1
216
- chuan2
217
- chuan3
218
- chuan4
219
- chuang1
220
- chuang2
221
- chuang3
222
- chuang4
223
- chui1
224
- chui2
225
- chun1
226
- chun2
227
- chun3
228
- chuo1
229
- chuo4
230
- ci1
231
- ci2
232
- ci3
233
- ci4
234
- cong1
235
- cong2
236
- cou4
237
- cu1
238
- cu4
239
- cuan1
240
- cuan2
241
- cuan4
242
- cui1
243
- cui3
244
- cui4
245
- cun1
246
- cun2
247
- cun4
248
- cuo1
249
- cuo2
250
- cuo4
251
- d
252
- da
253
- da1
254
- da2
255
- da3
256
- da4
257
- dai1
258
- dai2
259
- dai3
260
- dai4
261
- dan1
262
- dan2
263
- dan3
264
- dan4
265
- dang1
266
- dang2
267
- dang3
268
- dang4
269
- dao1
270
- dao2
271
- dao3
272
- dao4
273
- de
274
- de1
275
- de2
276
- dei3
277
- den4
278
- deng1
279
- deng2
280
- deng3
281
- deng4
282
- di1
283
- di2
284
- di3
285
- di4
286
- dia3
287
- dian1
288
- dian2
289
- dian3
290
- dian4
291
- diao1
292
- diao3
293
- diao4
294
- die1
295
- die2
296
- die4
297
- ding1
298
- ding2
299
- ding3
300
- ding4
301
- diu1
302
- dong1
303
- dong3
304
- dong4
305
- dou1
306
- dou2
307
- dou3
308
- dou4
309
- du1
310
- du2
311
- du3
312
- du4
313
- duan1
314
- duan2
315
- duan3
316
- duan4
317
- dui1
318
- dui4
319
- dun1
320
- dun3
321
- dun4
322
- duo1
323
- duo2
324
- duo3
325
- duo4
326
- e
327
- e1
328
- e2
329
- e3
330
- e4
331
- ei2
332
- en1
333
- en4
334
- er
335
- er2
336
- er3
337
- er4
338
- f
339
- fa1
340
- fa2
341
- fa3
342
- fa4
343
- fan1
344
- fan2
345
- fan3
346
- fan4
347
- fang1
348
- fang2
349
- fang3
350
- fang4
351
- fei1
352
- fei2
353
- fei3
354
- fei4
355
- fen1
356
- fen2
357
- fen3
358
- fen4
359
- feng1
360
- feng2
361
- feng3
362
- feng4
363
- fo2
364
- fou2
365
- fou3
366
- fu1
367
- fu2
368
- fu3
369
- fu4
370
- g
371
- ga1
372
- ga2
373
- ga3
374
- ga4
375
- gai1
376
- gai2
377
- gai3
378
- gai4
379
- gan1
380
- gan2
381
- gan3
382
- gan4
383
- gang1
384
- gang2
385
- gang3
386
- gang4
387
- gao1
388
- gao2
389
- gao3
390
- gao4
391
- ge1
392
- ge2
393
- ge3
394
- ge4
395
- gei2
396
- gei3
397
- gen1
398
- gen2
399
- gen3
400
- gen4
401
- geng1
402
- geng3
403
- geng4
404
- gong1
405
- gong3
406
- gong4
407
- gou1
408
- gou2
409
- gou3
410
- gou4
411
- gu
412
- gu1
413
- gu2
414
- gu3
415
- gu4
416
- gua1
417
- gua2
418
- gua3
419
- gua4
420
- guai1
421
- guai2
422
- guai3
423
- guai4
424
- guan1
425
- guan2
426
- guan3
427
- guan4
428
- guang1
429
- guang2
430
- guang3
431
- guang4
432
- gui1
433
- gui2
434
- gui3
435
- gui4
436
- gun3
437
- gun4
438
- guo1
439
- guo2
440
- guo3
441
- guo4
442
- h
443
- ha1
444
- ha2
445
- ha3
446
- hai1
447
- hai2
448
- hai3
449
- hai4
450
- han1
451
- han2
452
- han3
453
- han4
454
- hang1
455
- hang2
456
- hang4
457
- hao1
458
- hao2
459
- hao3
460
- hao4
461
- he1
462
- he2
463
- he4
464
- hei1
465
- hen2
466
- hen3
467
- hen4
468
- heng1
469
- heng2
470
- heng4
471
- hong1
472
- hong2
473
- hong3
474
- hong4
475
- hou1
476
- hou2
477
- hou3
478
- hou4
479
- hu1
480
- hu2
481
- hu3
482
- hu4
483
- hua1
484
- hua2
485
- hua4
486
- huai2
487
- huai4
488
- huan1
489
- huan2
490
- huan3
491
- huan4
492
- huang1
493
- huang2
494
- huang3
495
- huang4
496
- hui1
497
- hui2
498
- hui3
499
- hui4
500
- hun1
501
- hun2
502
- hun4
503
- huo
504
- huo1
505
- huo2
506
- huo3
507
- huo4
508
- i
509
- j
510
- ji1
511
- ji2
512
- ji3
513
- ji4
514
- jia
515
- jia1
516
- jia2
517
- jia3
518
- jia4
519
- jian1
520
- jian2
521
- jian3
522
- jian4
523
- jiang1
524
- jiang2
525
- jiang3
526
- jiang4
527
- jiao1
528
- jiao2
529
- jiao3
530
- jiao4
531
- jie1
532
- jie2
533
- jie3
534
- jie4
535
- jin1
536
- jin2
537
- jin3
538
- jin4
539
- jing1
540
- jing2
541
- jing3
542
- jing4
543
- jiong3
544
- jiu1
545
- jiu2
546
- jiu3
547
- jiu4
548
- ju1
549
- ju2
550
- ju3
551
- ju4
552
- juan1
553
- juan2
554
- juan3
555
- juan4
556
- jue1
557
- jue2
558
- jue4
559
- jun1
560
- jun4
561
- k
562
- ka1
563
- ka2
564
- ka3
565
- kai1
566
- kai2
567
- kai3
568
- kai4
569
- kan1
570
- kan2
571
- kan3
572
- kan4
573
- kang1
574
- kang2
575
- kang4
576
- kao1
577
- kao2
578
- kao3
579
- kao4
580
- ke1
581
- ke2
582
- ke3
583
- ke4
584
- ken3
585
- keng1
586
- kong1
587
- kong3
588
- kong4
589
- kou1
590
- kou2
591
- kou3
592
- kou4
593
- ku1
594
- ku2
595
- ku3
596
- ku4
597
- kua1
598
- kua3
599
- kua4
600
- kuai3
601
- kuai4
602
- kuan1
603
- kuan2
604
- kuan3
605
- kuang1
606
- kuang2
607
- kuang4
608
- kui1
609
- kui2
610
- kui3
611
- kui4
612
- kun1
613
- kun3
614
- kun4
615
- kuo4
616
- l
617
- la
618
- la1
619
- la2
620
- la3
621
- la4
622
- lai2
623
- lai4
624
- lan2
625
- lan3
626
- lan4
627
- lang1
628
- lang2
629
- lang3
630
- lang4
631
- lao1
632
- lao2
633
- lao3
634
- lao4
635
- le
636
- le1
637
- le4
638
- lei
639
- lei1
640
- lei2
641
- lei3
642
- lei4
643
- leng1
644
- leng2
645
- leng3
646
- leng4
647
- li
648
- li1
649
- li2
650
- li3
651
- li4
652
- lia3
653
- lian2
654
- lian3
655
- lian4
656
- liang2
657
- liang3
658
- liang4
659
- liao1
660
- liao2
661
- liao3
662
- liao4
663
- lie1
664
- lie2
665
- lie3
666
- lie4
667
- lin1
668
- lin2
669
- lin3
670
- lin4
671
- ling2
672
- ling3
673
- ling4
674
- liu1
675
- liu2
676
- liu3
677
- liu4
678
- long1
679
- long2
680
- long3
681
- long4
682
- lou1
683
- lou2
684
- lou3
685
- lou4
686
- lu1
687
- lu2
688
- lu3
689
- lu4
690
- luan2
691
- luan3
692
- luan4
693
- lun1
694
- lun2
695
- lun4
696
- luo1
697
- luo2
698
- luo3
699
- luo4
700
- lv2
701
- lv3
702
- lv4
703
- lve3
704
- lve4
705
- m
706
- ma
707
- ma1
708
- ma2
709
- ma3
710
- ma4
711
- mai2
712
- mai3
713
- mai4
714
- man1
715
- man2
716
- man3
717
- man4
718
- mang2
719
- mang3
720
- mao1
721
- mao2
722
- mao3
723
- mao4
724
- me
725
- mei2
726
- mei3
727
- mei4
728
- men
729
- men1
730
- men2
731
- men4
732
- meng
733
- meng1
734
- meng2
735
- meng3
736
- meng4
737
- mi1
738
- mi2
739
- mi3
740
- mi4
741
- mian2
742
- mian3
743
- mian4
744
- miao1
745
- miao2
746
- miao3
747
- miao4
748
- mie1
749
- mie4
750
- min2
751
- min3
752
- ming2
753
- ming3
754
- ming4
755
- miu4
756
- mo1
757
- mo2
758
- mo3
759
- mo4
760
- mou1
761
- mou2
762
- mou3
763
- mu2
764
- mu3
765
- mu4
766
- n
767
- n2
768
- na1
769
- na2
770
- na3
771
- na4
772
- nai2
773
- nai3
774
- nai4
775
- nan1
776
- nan2
777
- nan3
778
- nan4
779
- nang1
780
- nang2
781
- nang3
782
- nao1
783
- nao2
784
- nao3
785
- nao4
786
- ne
787
- ne2
788
- ne4
789
- nei3
790
- nei4
791
- nen4
792
- neng2
793
- ni1
794
- ni2
795
- ni3
796
- ni4
797
- nian1
798
- nian2
799
- nian3
800
- nian4
801
- niang2
802
- niang4
803
- niao2
804
- niao3
805
- niao4
806
- nie1
807
- nie4
808
- nin2
809
- ning2
810
- ning3
811
- ning4
812
- niu1
813
- niu2
814
- niu3
815
- niu4
816
- nong2
817
- nong4
818
- nou4
819
- nu2
820
- nu3
821
- nu4
822
- nuan3
823
- nuo2
824
- nuo4
825
- nv2
826
- nv3
827
- nve4
828
- o
829
- o1
830
- o2
831
- ou1
832
- ou2
833
- ou3
834
- ou4
835
- p
836
- pa1
837
- pa2
838
- pa4
839
- pai1
840
- pai2
841
- pai3
842
- pai4
843
- pan1
844
- pan2
845
- pan4
846
- pang1
847
- pang2
848
- pang4
849
- pao1
850
- pao2
851
- pao3
852
- pao4
853
- pei1
854
- pei2
855
- pei4
856
- pen1
857
- pen2
858
- pen4
859
- peng1
860
- peng2
861
- peng3
862
- peng4
863
- pi1
864
- pi2
865
- pi3
866
- pi4
867
- pian1
868
- pian2
869
- pian4
870
- piao1
871
- piao2
872
- piao3
873
- piao4
874
- pie1
875
- pie2
876
- pie3
877
- pin1
878
- pin2
879
- pin3
880
- pin4
881
- ping1
882
- ping2
883
- po1
884
- po2
885
- po3
886
- po4
887
- pou1
888
- pu1
889
- pu2
890
- pu3
891
- pu4
892
- q
893
- qi1
894
- qi2
895
- qi3
896
- qi4
897
- qia1
898
- qia3
899
- qia4
900
- qian1
901
- qian2
902
- qian3
903
- qian4
904
- qiang1
905
- qiang2
906
- qiang3
907
- qiang4
908
- qiao1
909
- qiao2
910
- qiao3
911
- qiao4
912
- qie1
913
- qie2
914
- qie3
915
- qie4
916
- qin1
917
- qin2
918
- qin3
919
- qin4
920
- qing1
921
- qing2
922
- qing3
923
- qing4
924
- qiong1
925
- qiong2
926
- qiu1
927
- qiu2
928
- qiu3
929
- qu1
930
- qu2
931
- qu3
932
- qu4
933
- quan1
934
- quan2
935
- quan3
936
- quan4
937
- que1
938
- que2
939
- que4
940
- qun2
941
- r
942
- ran2
943
- ran3
944
- rang1
945
- rang2
946
- rang3
947
- rang4
948
- rao2
949
- rao3
950
- rao4
951
- re2
952
- re3
953
- re4
954
- ren2
955
- ren3
956
- ren4
957
- reng1
958
- reng2
959
- ri4
960
- rong1
961
- rong2
962
- rong3
963
- rou2
964
- rou4
965
- ru2
966
- ru3
967
- ru4
968
- ruan2
969
- ruan3
970
- rui3
971
- rui4
972
- run4
973
- ruo4
974
- s
975
- sa1
976
- sa2
977
- sa3
978
- sa4
979
- sai1
980
- sai4
981
- san1
982
- san2
983
- san3
984
- san4
985
- sang1
986
- sang3
987
- sang4
988
- sao1
989
- sao2
990
- sao3
991
- sao4
992
- se4
993
- sen1
994
- seng1
995
- sha1
996
- sha2
997
- sha3
998
- sha4
999
- shai1
1000
- shai2
1001
- shai3
1002
- shai4
1003
- shan1
1004
- shan3
1005
- shan4
1006
- shang
1007
- shang1
1008
- shang3
1009
- shang4
1010
- shao1
1011
- shao2
1012
- shao3
1013
- shao4
1014
- she1
1015
- she2
1016
- she3
1017
- she4
1018
- shei2
1019
- shen1
1020
- shen2
1021
- shen3
1022
- shen4
1023
- sheng1
1024
- sheng2
1025
- sheng3
1026
- sheng4
1027
- shi
1028
- shi1
1029
- shi2
1030
- shi3
1031
- shi4
1032
- shou1
1033
- shou2
1034
- shou3
1035
- shou4
1036
- shu1
1037
- shu2
1038
- shu3
1039
- shu4
1040
- shua1
1041
- shua2
1042
- shua3
1043
- shua4
1044
- shuai1
1045
- shuai3
1046
- shuai4
1047
- shuan1
1048
- shuan4
1049
- shuang1
1050
- shuang3
1051
- shui2
1052
- shui3
1053
- shui4
1054
- shun3
1055
- shun4
1056
- shuo1
1057
- shuo4
1058
- si1
1059
- si2
1060
- si3
1061
- si4
1062
- song1
1063
- song3
1064
- song4
1065
- sou1
1066
- sou3
1067
- sou4
1068
- su1
1069
- su2
1070
- su4
1071
- suan1
1072
- suan4
1073
- sui1
1074
- sui2
1075
- sui3
1076
- sui4
1077
- sun1
1078
- sun3
1079
- suo
1080
- suo1
1081
- suo2
1082
- suo3
1083
- t
1084
- ta1
1085
- ta2
1086
- ta3
1087
- ta4
1088
- tai1
1089
- tai2
1090
- tai4
1091
- tan1
1092
- tan2
1093
- tan3
1094
- tan4
1095
- tang1
1096
- tang2
1097
- tang3
1098
- tang4
1099
- tao1
1100
- tao2
1101
- tao3
1102
- tao4
1103
- te4
1104
- teng2
1105
- ti1
1106
- ti2
1107
- ti3
1108
- ti4
1109
- tian1
1110
- tian2
1111
- tian3
1112
- tiao1
1113
- tiao2
1114
- tiao3
1115
- tiao4
1116
- tie1
1117
- tie2
1118
- tie3
1119
- tie4
1120
- ting1
1121
- ting2
1122
- ting3
1123
- tong1
1124
- tong2
1125
- tong3
1126
- tong4
1127
- tou
1128
- tou1
1129
- tou2
1130
- tou4
1131
- tu1
1132
- tu2
1133
- tu3
1134
- tu4
1135
- tuan1
1136
- tuan2
1137
- tui1
1138
- tui2
1139
- tui3
1140
- tui4
1141
- tun1
1142
- tun2
1143
- tun4
1144
- tuo1
1145
- tuo2
1146
- tuo3
1147
- tuo4
1148
- u
1149
- v
1150
- w
1151
- wa
1152
- wa1
1153
- wa2
1154
- wa3
1155
- wa4
1156
- wai1
1157
- wai3
1158
- wai4
1159
- wan1
1160
- wan2
1161
- wan3
1162
- wan4
1163
- wang1
1164
- wang2
1165
- wang3
1166
- wang4
1167
- wei1
1168
- wei2
1169
- wei3
1170
- wei4
1171
- wen1
1172
- wen2
1173
- wen3
1174
- wen4
1175
- weng1
1176
- weng4
1177
- wo1
1178
- wo2
1179
- wo3
1180
- wo4
1181
- wu1
1182
- wu2
1183
- wu3
1184
- wu4
1185
- x
1186
- xi1
1187
- xi2
1188
- xi3
1189
- xi4
1190
- xia1
1191
- xia2
1192
- xia4
1193
- xian1
1194
- xian2
1195
- xian3
1196
- xian4
1197
- xiang1
1198
- xiang2
1199
- xiang3
1200
- xiang4
1201
- xiao1
1202
- xiao2
1203
- xiao3
1204
- xiao4
1205
- xie1
1206
- xie2
1207
- xie3
1208
- xie4
1209
- xin1
1210
- xin2
1211
- xin4
1212
- xing1
1213
- xing2
1214
- xing3
1215
- xing4
1216
- xiong1
1217
- xiong2
1218
- xiu1
1219
- xiu3
1220
- xiu4
1221
- xu
1222
- xu1
1223
- xu2
1224
- xu3
1225
- xu4
1226
- xuan1
1227
- xuan2
1228
- xuan3
1229
- xuan4
1230
- xue1
1231
- xue2
1232
- xue3
1233
- xue4
1234
- xun1
1235
- xun2
1236
- xun4
1237
- y
1238
- ya
1239
- ya1
1240
- ya2
1241
- ya3
1242
- ya4
1243
- yan1
1244
- yan2
1245
- yan3
1246
- yan4
1247
- yang1
1248
- yang2
1249
- yang3
1250
- yang4
1251
- yao1
1252
- yao2
1253
- yao3
1254
- yao4
1255
- ye1
1256
- ye2
1257
- ye3
1258
- ye4
1259
- yi
1260
- yi1
1261
- yi2
1262
- yi3
1263
- yi4
1264
- yin1
1265
- yin2
1266
- yin3
1267
- yin4
1268
- ying1
1269
- ying2
1270
- ying3
1271
- ying4
1272
- yo1
1273
- yong1
1274
- yong2
1275
- yong3
1276
- yong4
1277
- you1
1278
- you2
1279
- you3
1280
- you4
1281
- yu1
1282
- yu2
1283
- yu3
1284
- yu4
1285
- yuan1
1286
- yuan2
1287
- yuan3
1288
- yuan4
1289
- yue1
1290
- yue4
1291
- yun1
1292
- yun2
1293
- yun3
1294
- yun4
1295
- z
1296
- za1
1297
- za2
1298
- za3
1299
- zai1
1300
- zai3
1301
- zai4
1302
- zan1
1303
- zan2
1304
- zan3
1305
- zan4
1306
- zang1
1307
- zang4
1308
- zao1
1309
- zao2
1310
- zao3
1311
- zao4
1312
- ze2
1313
- ze4
1314
- zei2
1315
- zen3
1316
- zeng1
1317
- zeng4
1318
- zha1
1319
- zha2
1320
- zha3
1321
- zha4
1322
- zhai1
1323
- zhai2
1324
- zhai3
1325
- zhai4
1326
- zhan1
1327
- zhan2
1328
- zhan3
1329
- zhan4
1330
- zhang1
1331
- zhang2
1332
- zhang3
1333
- zhang4
1334
- zhao1
1335
- zhao2
1336
- zhao3
1337
- zhao4
1338
- zhe
1339
- zhe1
1340
- zhe2
1341
- zhe3
1342
- zhe4
1343
- zhen1
1344
- zhen2
1345
- zhen3
1346
- zhen4
1347
- zheng1
1348
- zheng2
1349
- zheng3
1350
- zheng4
1351
- zhi1
1352
- zhi2
1353
- zhi3
1354
- zhi4
1355
- zhong1
1356
- zhong2
1357
- zhong3
1358
- zhong4
1359
- zhou1
1360
- zhou2
1361
- zhou3
1362
- zhou4
1363
- zhu1
1364
- zhu2
1365
- zhu3
1366
- zhu4
1367
- zhua1
1368
- zhua2
1369
- zhua3
1370
- zhuai1
1371
- zhuai3
1372
- zhuai4
1373
- zhuan1
1374
- zhuan2
1375
- zhuan3
1376
- zhuan4
1377
- zhuang1
1378
- zhuang4
1379
- zhui1
1380
- zhui4
1381
- zhun1
1382
- zhun2
1383
- zhun3
1384
- zhuo1
1385
- zhuo2
1386
- zi
1387
- zi1
1388
- zi2
1389
- zi3
1390
- zi4
1391
- zong1
1392
- zong2
1393
- zong3
1394
- zong4
1395
- zou1
1396
- zou2
1397
- zou3
1398
- zou4
1399
- zu1
1400
- zu2
1401
- zu3
1402
- zuan1
1403
- zuan3
1404
- zuan4
1405
- zui2
1406
- zui3
1407
- zui4
1408
- zun1
1409
- zuo
1410
- zuo1
1411
- zuo2
1412
- zuo3
1413
- zuo4
1414
- {
1415
- ~
1416
- ¡
1417
- ¢
1418
- £
1419
- ¥
1420
- §
1421
- ¨
1422
- ©
1423
- «
1424
- ®
1425
- ¯
1426
- °
1427
- ±
1428
- ²
1429
- ³
1430
- ´
1431
- µ
1432
- ·
1433
- ¹
1434
- º
1435
- »
1436
- ¼
1437
- ½
1438
- ¾
1439
- ¿
1440
- À
1441
- Á
1442
- Â
1443
- Ã
1444
- Ä
1445
- Å
1446
- Æ
1447
- Ç
1448
- È
1449
- É
1450
- Ê
1451
- Í
1452
- Î
1453
- Ñ
1454
- Ó
1455
- Ö
1456
- ×
1457
- Ø
1458
- Ú
1459
- Ü
1460
- Ý
1461
- Þ
1462
- ß
1463
- à
1464
- á
1465
- â
1466
- ã
1467
- ä
1468
- å
1469
- æ
1470
- ç
1471
- è
1472
- é
1473
- ê
1474
- ë
1475
- ì
1476
- í
1477
- î
1478
- ï
1479
- ð
1480
- ñ
1481
- ò
1482
- ó
1483
- ô
1484
- õ
1485
- ö
1486
- ø
1487
- ù
1488
- ú
1489
- û
1490
- ü
1491
- ý
1492
- Ā
1493
- ā
1494
- ă
1495
- ą
1496
- ć
1497
- Č
1498
- č
1499
- Đ
1500
- đ
1501
- ē
1502
- ė
1503
- ę
1504
- ě
1505
- ĝ
1506
- ğ
1507
- ħ
1508
- ī
1509
- į
1510
- İ
1511
- ı
1512
- Ł
1513
- ł
1514
- ń
1515
- ņ
1516
- ň
1517
- ŋ
1518
- Ō
1519
- ō
1520
- ő
1521
- œ
1522
- ř
1523
- Ś
1524
- ś
1525
- Ş
1526
- ş
1527
- Š
1528
- š
1529
- Ť
1530
- ť
1531
- ũ
1532
- ū
1533
- ź
1534
- Ż
1535
- ż
1536
- Ž
1537
- ž
1538
- ơ
1539
- ư
1540
- ǎ
1541
- ǐ
1542
- ǒ
1543
- ǔ
1544
- ǚ
1545
- ș
1546
- ț
1547
- ɑ
1548
- ɔ
1549
- ɕ
1550
- ə
1551
- ɛ
1552
- ɜ
1553
- ɡ
1554
- ɣ
1555
- ɪ
1556
- ɫ
1557
- ɴ
1558
- ɹ
1559
- ɾ
1560
- ʃ
1561
- ʊ
1562
- ʌ
1563
- ʒ
1564
- ʔ
1565
- ʰ
1566
- ʷ
1567
- ʻ
1568
- ʾ
1569
- ʿ
1570
- ˈ
1571
- ː
1572
- ˙
1573
- ˜
1574
- ˢ
1575
- ́
1576
- ̅
1577
- Α
1578
- Β
1579
- Δ
1580
- Ε
1581
- Θ
1582
- Κ
1583
- Λ
1584
- Μ
1585
- Ξ
1586
- Π
1587
- Σ
1588
- Τ
1589
- Φ
1590
- Χ
1591
- Ψ
1592
- Ω
1593
- ά
1594
- έ
1595
- ή
1596
- ί
1597
- α
1598
- β
1599
- γ
1600
- δ
1601
- ε
1602
- ζ
1603
- η
1604
- θ
1605
- ι
1606
- κ
1607
- λ
1608
- μ
1609
- ν
1610
- ξ
1611
- ο
1612
- π
1613
- ρ
1614
- ς
1615
- σ
1616
- τ
1617
- υ
1618
- φ
1619
- χ
1620
- ψ
1621
- ω
1622
- ϊ
1623
- ό
1624
- ύ
1625
- ώ
1626
- ϕ
1627
- ϵ
1628
- Ё
1629
- А
1630
- Б
1631
- В
1632
- Г
1633
- Д
1634
- Е
1635
- Ж
1636
- З
1637
- И
1638
- Й
1639
- К
1640
- Л
1641
- М
1642
- Н
1643
- О
1644
- П
1645
- Р
1646
- С
1647
- Т
1648
- У
1649
- Ф
1650
- Х
1651
- Ц
1652
- Ч
1653
- Ш
1654
- Щ
1655
- Ы
1656
- Ь
1657
- Э
1658
- Ю
1659
- Я
1660
- а
1661
- б
1662
- в
1663
- г
1664
- д
1665
- е
1666
- ж
1667
- з
1668
- и
1669
- й
1670
- к
1671
- л
1672
- м
1673
- н
1674
- о
1675
- п
1676
- р
1677
- с
1678
- т
1679
- у
1680
- ф
1681
- х
1682
- ц
1683
- ч
1684
- ш
1685
- щ
1686
- ъ
1687
- ы
1688
- ь
1689
- э
1690
- ю
1691
- я
1692
- ё
1693
- і
1694
- ְ
1695
- ִ
1696
- ֵ
1697
- ֶ
1698
- ַ
1699
- ָ
1700
- ֹ
1701
- ּ
1702
- ־
1703
- ׁ
1704
- א
1705
- ב
1706
- ג
1707
- ד
1708
- ה
1709
- ו
1710
- ז
1711
- ח
1712
- ט
1713
- י
1714
- כ
1715
- ל
1716
- ם
1717
- מ
1718
- ן
1719
- נ
1720
- ס
1721
- ע
1722
- פ
1723
- ק
1724
- ר
1725
- ש
1726
- ת
1727
- أ
1728
- ب
1729
- ة
1730
- ت
1731
- ج
1732
- ح
1733
- د
1734
- ر
1735
- ز
1736
- س
1737
- ص
1738
- ط
1739
- ع
1740
- ق
1741
- ك
1742
- ل
1743
- م
1744
- ن
1745
- ه
1746
- و
1747
- ي
1748
- َ
1749
- ُ
1750
- ِ
1751
- ْ
1752
-
1753
-
1754
-
1755
-
1756
-
1757
-
1758
-
1759
-
1760
-
1761
-
1762
-
1763
-
1764
-
1765
-
1766
-
1767
-
1768
-
1769
-
1770
-
1771
-
1772
-
1773
-
1774
-
1775
-
1776
-
1777
-
1778
-
1779
-
1780
-
1781
-
1782
-
1783
-
1784
-
1785
-
1786
-
1787
-
1788
-
1789
-
1790
-
1791
-
1792
-
1793
-
1794
-
1795
-
1796
-
1797
-
1798
-
1799
-
1800
- ế
1801
-
1802
-
1803
-
1804
-
1805
-
1806
-
1807
-
1808
-
1809
-
1810
-
1811
-
1812
-
1813
-
1814
-
1815
-
1816
-
1817
-
1818
-
1819
-
1820
-
1821
-
1822
-
1823
-
1824
-
1825
-
1826
-
1827
-
1828
-
1829
-
1830
-
1831
-
1832
-
1833
-
1834
-
1835
-
1836
-
1837
-
1838
-
1839
-
1840
-
1841
-
1842
-
1843
-
1844
-
1845
-
1846
-
1847
-
1848
-
1849
-
1850
-
1851
-
1852
-
1853
-
1854
-
1855
-
1856
-
1857
-
1858
-
1859
-
1860
-
1861
-
1862
-
1863
-
1864
-
1865
-
1866
-
1867
-
1868
-
1869
-
1870
-
1871
-
1872
-
1873
-
1874
-
1875
-
1876
-
1877
-
1878
-
1879
-
1880
-
1881
-
1882
-
1883
-
1884
-
1885
-
1886
-
1887
-
1888
-
1889
-
1890
-
1891
-
1892
-
1893
-
1894
-
1895
-
1896
-
1897
-
1898
-
1899
-
1900
-
1901
-
1902
-
1903
-
1904
-
1905
-
1906
-
1907
-
1908
-
1909
-
1910
-
1911
-
1912
-
1913
-
1914
-
1915
-
1916
-
1917
-
1918
-
1919
-
1920
-
1921
-
1922
-
1923
-
1924
-
1925
-
1926
-
1927
-
1928
-
1929
-
1930
-
1931
-
1932
-
1933
-
1934
-
1935
-
1936
-
1937
-
1938
-
1939
-
1940
-
1941
-
1942
-
1943
-
1944
-
1945
-
1946
-
1947
-
1948
-
1949
-
1950
-
1951
-
1952
-
1953
-
1954
-
1955
-
1956
-
1957
-
1958
-
1959
-
1960
-
1961
-
1962
-
1963
-
1964
-
1965
-
1966
-
1967
-
1968
-
1969
-
1970
-
1971
-
1972
-
1973
-
1974
-
1975
-
1976
-
1977
-
1978
-
1979
-
1980
-
1981
-
1982
-
1983
-
1984
-
1985
-
1986
-
1987
-
1988
-
1989
-
1990
-
1991
-
1992
-
1993
-
1994
-
1995
-
1996
-
1997
-
1998
-
1999
-
2000
-
2001
-
2002
-
2003
-
2004
-
2005
-
2006
-
2007
-
2008
-
2009
-
2010
-
2011
-
2012
-
2013
-
2014
-
2015
-
2016
-
2017
-
2018
-
2019
-
2020
-
2021
-
2022
-
2023
-
2024
-
2025
-
2026
-
2027
-
2028
-
2029
-
2030
-
2031
-
2032
-
2033
-
2034
-
2035
-
2036
-
2037
-
2038
-
2039
-
2040
-
2041
-
2042
-
2043
-
2044
-
2045
-
2046
-
2047
-
2048
-
2049
-
2050
-
2051
-
2052
-
2053
-
2054
-
2055
-
2056
-
2057
-
2058
-
2059
-
2060
-
2061
-
2062
-
2063
-
2064
-
2065
-
2066
-
2067
-
2068
-
2069
-
2070
-
2071
-
2072
-
2073
-
2074
-
2075
-
2076
-
2077
-
2078
-
2079
-
2080
-
2081
-
2082
-
2083
-
2084
-
2085
-
2086
-
2087
-
2088
-
2089
-
2090
-
2091
-
2092
-
2093
-
2094
-
2095
-
2096
-
2097
-
2098
-
2099
-
2100
-
2101
-
2102
-
2103
-
2104
-
2105
-
2106
-
2107
-
2108
-
2109
-
2110
-
2111
-
2112
-
2113
-
2114
-
2115
-
2116
-
2117
-
2118
-
2119
-
2120
-
2121
-
2122
-
2123
-
2124
-
2125
-
2126
-
2127
-
2128
-
2129
-
2130
-
2131
-
2132
-
2133
-
2134
-
2135
-
2136
-
2137
-
2138
-
2139
-
2140
-
2141
-
2142
-
2143
-
2144
-
2145
-
2146
-
2147
-
2148
-
2149
-
2150
-
2151
-
2152
-
2153
-
2154
-
2155
-
2156
-
2157
-
2158
-
2159
-
2160
-
2161
-
2162
-
2163
-
2164
-
2165
-
2166
-
2167
-
2168
-
2169
-
2170
-
2171
-
2172
-
2173
-
2174
-
2175
-
2176
-
2177
-
2178
-
2179
-
2180
-
2181
-
2182
-
2183
-
2184
-
2185
-
2186
-
2187
-
2188
-
2189
-
2190
-
2191
-
2192
-
2193
-
2194
-
2195
-
2196
-
2197
-
2198
-
2199
-
2200
-
2201
-
2202
-
2203
-
2204
-
2205
-
2206
-
2207
-
2208
-
2209
-
2210
-
2211
-
2212
-
2213
-
2214
-
2215
-
2216
-
2217
-
2218
-
2219
-
2220
-
2221
-
2222
-
2223
-
2224
-
2225
-
2226
-
2227
-
2228
-
2229
-
2230
-
2231
-
2232
-
2233
-
2234
-
2235
-
2236
-
2237
-
2238
-
2239
-
2240
-
2241
-
2242
-
2243
-
2244
-
2245
-
2246
-
2247
-
2248
-
2249
-
2250
-
2251
-
2252
-
2253
-
2254
-
2255
-
2256
-
2257
-
2258
-
2259
-
2260
-
2261
-
2262
-
2263
-
2264
-
2265
-
2266
-
2267
-
2268
-
2269
-
2270
-
2271
-
2272
-
2273
-
2274
-
2275
-
2276
-
2277
-
2278
-
2279
-
2280
-
2281
-
2282
-
2283
-
2284
-
2285
-
2286
-
2287
-
2288
-
2289
-
2290
-
2291
-
2292
-
2293
-
2294
-
2295
-
2296
-
2297
-
2298
-
2299
-
2300
-
2301
-
2302
-
2303
-
2304
-
2305
-
2306
-
2307
-
2308
-
2309
-
2310
-
2311
-
2312
-
2313
-
2314
-
2315
-
2316
-
2317
-
2318
-
2319
-
2320
-
2321
-
2322
-
2323
-
2324
-
2325
-
2326
-
2327
-
2328
-
2329
-
2330
-
2331
-
2332
-
2333
-
2334
-
2335
-
2336
-
2337
-
2338
-
2339
-
2340
-
2341
-
2342
-
2343
-
2344
-
2345
-
2346
-
2347
-
2348
-
2349
-
2350
-
2351
-
2352
-
2353
-
2354
-
2355
-
2356
-
2357
-
2358
-
2359
-
2360
-
2361
-
2362
-
2363
-
2364
-
2365
-
2366
-
2367
-
2368
-
2369
-
2370
-
2371
-
2372
-
2373
-
2374
-
2375
-
2376
-
2377
-
2378
- ���
2379
-
2380
-
2381
-
2382
-
2383
-
2384
-
2385
-
2386
-
2387
-
2388
-
2389
-
2390
-
2391
-
2392
-
2393
-
2394
-
2395
-
2396
-
2397
-
2398
-
2399
-
2400
-
2401
-
2402
-
2403
-
2404
-
2405
-
2406
-
2407
-
2408
-
2409
-
2410
-
2411
-
2412
-
2413
-
2414
-
2415
-
2416
-
2417
-
2418
-
2419
-
2420
-
2421
-
2422
-
2423
-
2424
-
2425
-
2426
-
2427
-
2428
-
2429
-
2430
-
2431
-
2432
-
2433
-
2434
-
2435
-
2436
-
2437
-
2438
-
2439
-
2440
-
2441
-
2442
-
2443
-
2444
-
2445
-
2446
-
2447
-
2448
-
2449
-
2450
-
2451
-
2452
-
2453
-
2454
-
2455
-
2456
-
2457
-
2458
-
2459
-
2460
-
2461
-
2462
-
2463
-
2464
-
2465
-
2466
-
2467
-
2468
-
2469
-
2470
-
2471
-
2472
-
2473
-
2474
-
2475
-
2476
-
2477
-
2478
-
2479
-
2480
-
2481
-
2482
-
2483
-
2484
-
2485
-
2486
-
2487
-
2488
-
2489
-
2490
-
2491
-
2492
-
2493
-
2494
-
2495
-
2496
-
2497
-
2498
-
2499
-
2500
-
2501
-
2502
-
2503
-
2504
-
2505
-
2506
-
2507
-
2508
-
2509
-
2510
-
2511
-
2512
-
2513
-
2514
-
2515
-
2516
-
2517
-
2518
-
2519
-
2520
-
2521
-
2522
-
2523
-
2524
-
2525
-
2526
-
2527
-
2528
-
2529
-
2530
-
2531
-
2532
-
2533
-
2534
-
2535
-
2536
-
2537
-
2538
-
2539
-
2540
-
2541
-
2542
-
2543
-
2544
-
2545
- 𠮶
 
1
+
2
+ !
3
+ "
4
+ #
5
+ $
6
+ %
7
+ &
8
+ '
9
+ (
10
+ )
11
+ *
12
+ +
13
+ ,
14
+ -
15
+ .
16
+ /
17
+ 0
18
+ 1
19
+ 2
20
+ 3
21
+ 4
22
+ 5
23
+ 6
24
+ 7
25
+ 8
26
+ 9
27
+ :
28
+ ;
29
+ =
30
+ >
31
+ ?
32
+ @
33
+ A
34
+ B
35
+ C
36
+ D
37
+ E
38
+ F
39
+ G
40
+ H
41
+ I
42
+ J
43
+ K
44
+ L
45
+ M
46
+ N
47
+ O
48
+ P
49
+ Q
50
+ R
51
+ S
52
+ T
53
+ U
54
+ V
55
+ W
56
+ X
57
+ Y
58
+ Z
59
+ [
60
+ \
61
+ ]
62
+ _
63
+ a
64
+ a1
65
+ ai1
66
+ ai2
67
+ ai3
68
+ ai4
69
+ an1
70
+ an3
71
+ an4
72
+ ang1
73
+ ang2
74
+ ang4
75
+ ao1
76
+ ao2
77
+ ao3
78
+ ao4
79
+ b
80
+ ba
81
+ ba1
82
+ ba2
83
+ ba3
84
+ ba4
85
+ bai1
86
+ bai2
87
+ bai3
88
+ bai4
89
+ ban1
90
+ ban2
91
+ ban3
92
+ ban4
93
+ bang1
94
+ bang2
95
+ bang3
96
+ bang4
97
+ bao1
98
+ bao2
99
+ bao3
100
+ bao4
101
+ bei
102
+ bei1
103
+ bei2
104
+ bei3
105
+ bei4
106
+ ben1
107
+ ben2
108
+ ben3
109
+ ben4
110
+ beng
111
+ beng1
112
+ beng2
113
+ beng3
114
+ beng4
115
+ bi1
116
+ bi2
117
+ bi3
118
+ bi4
119
+ bian1
120
+ bian2
121
+ bian3
122
+ bian4
123
+ biao1
124
+ biao2
125
+ biao3
126
+ bie1
127
+ bie2
128
+ bie3
129
+ bie4
130
+ bin1
131
+ bin4
132
+ bing1
133
+ bing2
134
+ bing3
135
+ bing4
136
+ bo
137
+ bo1
138
+ bo2
139
+ bo3
140
+ bo4
141
+ bu2
142
+ bu3
143
+ bu4
144
+ c
145
+ ca1
146
+ cai1
147
+ cai2
148
+ cai3
149
+ cai4
150
+ can1
151
+ can2
152
+ can3
153
+ can4
154
+ cang1
155
+ cang2
156
+ cao1
157
+ cao2
158
+ cao3
159
+ ce4
160
+ cen1
161
+ cen2
162
+ ceng1
163
+ ceng2
164
+ ceng4
165
+ cha1
166
+ cha2
167
+ cha3
168
+ cha4
169
+ chai1
170
+ chai2
171
+ chan1
172
+ chan2
173
+ chan3
174
+ chan4
175
+ chang1
176
+ chang2
177
+ chang3
178
+ chang4
179
+ chao1
180
+ chao2
181
+ chao3
182
+ che1
183
+ che2
184
+ che3
185
+ che4
186
+ chen1
187
+ chen2
188
+ chen3
189
+ chen4
190
+ cheng1
191
+ cheng2
192
+ cheng3
193
+ cheng4
194
+ chi1
195
+ chi2
196
+ chi3
197
+ chi4
198
+ chong1
199
+ chong2
200
+ chong3
201
+ chong4
202
+ chou1
203
+ chou2
204
+ chou3
205
+ chou4
206
+ chu1
207
+ chu2
208
+ chu3
209
+ chu4
210
+ chua1
211
+ chuai1
212
+ chuai2
213
+ chuai3
214
+ chuai4
215
+ chuan1
216
+ chuan2
217
+ chuan3
218
+ chuan4
219
+ chuang1
220
+ chuang2
221
+ chuang3
222
+ chuang4
223
+ chui1
224
+ chui2
225
+ chun1
226
+ chun2
227
+ chun3
228
+ chuo1
229
+ chuo4
230
+ ci1
231
+ ci2
232
+ ci3
233
+ ci4
234
+ cong1
235
+ cong2
236
+ cou4
237
+ cu1
238
+ cu4
239
+ cuan1
240
+ cuan2
241
+ cuan4
242
+ cui1
243
+ cui3
244
+ cui4
245
+ cun1
246
+ cun2
247
+ cun4
248
+ cuo1
249
+ cuo2
250
+ cuo4
251
+ d
252
+ da
253
+ da1
254
+ da2
255
+ da3
256
+ da4
257
+ dai1
258
+ dai2
259
+ dai3
260
+ dai4
261
+ dan1
262
+ dan2
263
+ dan3
264
+ dan4
265
+ dang1
266
+ dang2
267
+ dang3
268
+ dang4
269
+ dao1
270
+ dao2
271
+ dao3
272
+ dao4
273
+ de
274
+ de1
275
+ de2
276
+ dei3
277
+ den4
278
+ deng1
279
+ deng2
280
+ deng3
281
+ deng4
282
+ di1
283
+ di2
284
+ di3
285
+ di4
286
+ dia3
287
+ dian1
288
+ dian2
289
+ dian3
290
+ dian4
291
+ diao1
292
+ diao3
293
+ diao4
294
+ die1
295
+ die2
296
+ die4
297
+ ding1
298
+ ding2
299
+ ding3
300
+ ding4
301
+ diu1
302
+ dong1
303
+ dong3
304
+ dong4
305
+ dou1
306
+ dou2
307
+ dou3
308
+ dou4
309
+ du1
310
+ du2
311
+ du3
312
+ du4
313
+ duan1
314
+ duan2
315
+ duan3
316
+ duan4
317
+ dui1
318
+ dui4
319
+ dun1
320
+ dun3
321
+ dun4
322
+ duo1
323
+ duo2
324
+ duo3
325
+ duo4
326
+ e
327
+ e1
328
+ e2
329
+ e3
330
+ e4
331
+ ei2
332
+ en1
333
+ en4
334
+ er
335
+ er2
336
+ er3
337
+ er4
338
+ f
339
+ fa1
340
+ fa2
341
+ fa3
342
+ fa4
343
+ fan1
344
+ fan2
345
+ fan3
346
+ fan4
347
+ fang1
348
+ fang2
349
+ fang3
350
+ fang4
351
+ fei1
352
+ fei2
353
+ fei3
354
+ fei4
355
+ fen1
356
+ fen2
357
+ fen3
358
+ fen4
359
+ feng1
360
+ feng2
361
+ feng3
362
+ feng4
363
+ fo2
364
+ fou2
365
+ fou3
366
+ fu1
367
+ fu2
368
+ fu3
369
+ fu4
370
+ g
371
+ ga1
372
+ ga2
373
+ ga3
374
+ ga4
375
+ gai1
376
+ gai2
377
+ gai3
378
+ gai4
379
+ gan1
380
+ gan2
381
+ gan3
382
+ gan4
383
+ gang1
384
+ gang2
385
+ gang3
386
+ gang4
387
+ gao1
388
+ gao2
389
+ gao3
390
+ gao4
391
+ ge1
392
+ ge2
393
+ ge3
394
+ ge4
395
+ gei2
396
+ gei3
397
+ gen1
398
+ gen2
399
+ gen3
400
+ gen4
401
+ geng1
402
+ geng3
403
+ geng4
404
+ gong1
405
+ gong3
406
+ gong4
407
+ gou1
408
+ gou2
409
+ gou3
410
+ gou4
411
+ gu
412
+ gu1
413
+ gu2
414
+ gu3
415
+ gu4
416
+ gua1
417
+ gua2
418
+ gua3
419
+ gua4
420
+ guai1
421
+ guai2
422
+ guai3
423
+ guai4
424
+ guan1
425
+ guan2
426
+ guan3
427
+ guan4
428
+ guang1
429
+ guang2
430
+ guang3
431
+ guang4
432
+ gui1
433
+ gui2
434
+ gui3
435
+ gui4
436
+ gun3
437
+ gun4
438
+ guo1
439
+ guo2
440
+ guo3
441
+ guo4
442
+ h
443
+ ha1
444
+ ha2
445
+ ha3
446
+ hai1
447
+ hai2
448
+ hai3
449
+ hai4
450
+ han1
451
+ han2
452
+ han3
453
+ han4
454
+ hang1
455
+ hang2
456
+ hang4
457
+ hao1
458
+ hao2
459
+ hao3
460
+ hao4
461
+ he1
462
+ he2
463
+ he4
464
+ hei1
465
+ hen2
466
+ hen3
467
+ hen4
468
+ heng1
469
+ heng2
470
+ heng4
471
+ hong1
472
+ hong2
473
+ hong3
474
+ hong4
475
+ hou1
476
+ hou2
477
+ hou3
478
+ hou4
479
+ hu1
480
+ hu2
481
+ hu3
482
+ hu4
483
+ hua1
484
+ hua2
485
+ hua4
486
+ huai2
487
+ huai4
488
+ huan1
489
+ huan2
490
+ huan3
491
+ huan4
492
+ huang1
493
+ huang2
494
+ huang3
495
+ huang4
496
+ hui1
497
+ hui2
498
+ hui3
499
+ hui4
500
+ hun1
501
+ hun2
502
+ hun4
503
+ huo
504
+ huo1
505
+ huo2
506
+ huo3
507
+ huo4
508
+ i
509
+ j
510
+ ji1
511
+ ji2
512
+ ji3
513
+ ji4
514
+ jia
515
+ jia1
516
+ jia2
517
+ jia3
518
+ jia4
519
+ jian1
520
+ jian2
521
+ jian3
522
+ jian4
523
+ jiang1
524
+ jiang2
525
+ jiang3
526
+ jiang4
527
+ jiao1
528
+ jiao2
529
+ jiao3
530
+ jiao4
531
+ jie1
532
+ jie2
533
+ jie3
534
+ jie4
535
+ jin1
536
+ jin2
537
+ jin3
538
+ jin4
539
+ jing1
540
+ jing2
541
+ jing3
542
+ jing4
543
+ jiong3
544
+ jiu1
545
+ jiu2
546
+ jiu3
547
+ jiu4
548
+ ju1
549
+ ju2
550
+ ju3
551
+ ju4
552
+ juan1
553
+ juan2
554
+ juan3
555
+ juan4
556
+ jue1
557
+ jue2
558
+ jue4
559
+ jun1
560
+ jun4
561
+ k
562
+ ka1
563
+ ka2
564
+ ka3
565
+ kai1
566
+ kai2
567
+ kai3
568
+ kai4
569
+ kan1
570
+ kan2
571
+ kan3
572
+ kan4
573
+ kang1
574
+ kang2
575
+ kang4
576
+ kao1
577
+ kao2
578
+ kao3
579
+ kao4
580
+ ke1
581
+ ke2
582
+ ke3
583
+ ke4
584
+ ken3
585
+ keng1
586
+ kong1
587
+ kong3
588
+ kong4
589
+ kou1
590
+ kou2
591
+ kou3
592
+ kou4
593
+ ku1
594
+ ku2
595
+ ku3
596
+ ku4
597
+ kua1
598
+ kua3
599
+ kua4
600
+ kuai3
601
+ kuai4
602
+ kuan1
603
+ kuan2
604
+ kuan3
605
+ kuang1
606
+ kuang2
607
+ kuang4
608
+ kui1
609
+ kui2
610
+ kui3
611
+ kui4
612
+ kun1
613
+ kun3
614
+ kun4
615
+ kuo4
616
+ l
617
+ la
618
+ la1
619
+ la2
620
+ la3
621
+ la4
622
+ lai2
623
+ lai4
624
+ lan2
625
+ lan3
626
+ lan4
627
+ lang1
628
+ lang2
629
+ lang3
630
+ lang4
631
+ lao1
632
+ lao2
633
+ lao3
634
+ lao4
635
+ le
636
+ le1
637
+ le4
638
+ lei
639
+ lei1
640
+ lei2
641
+ lei3
642
+ lei4
643
+ leng1
644
+ leng2
645
+ leng3
646
+ leng4
647
+ li
648
+ li1
649
+ li2
650
+ li3
651
+ li4
652
+ lia3
653
+ lian2
654
+ lian3
655
+ lian4
656
+ liang2
657
+ liang3
658
+ liang4
659
+ liao1
660
+ liao2
661
+ liao3
662
+ liao4
663
+ lie1
664
+ lie2
665
+ lie3
666
+ lie4
667
+ lin1
668
+ lin2
669
+ lin3
670
+ lin4
671
+ ling2
672
+ ling3
673
+ ling4
674
+ liu1
675
+ liu2
676
+ liu3
677
+ liu4
678
+ long1
679
+ long2
680
+ long3
681
+ long4
682
+ lou1
683
+ lou2
684
+ lou3
685
+ lou4
686
+ lu1
687
+ lu2
688
+ lu3
689
+ lu4
690
+ luan2
691
+ luan3
692
+ luan4
693
+ lun1
694
+ lun2
695
+ lun4
696
+ luo1
697
+ luo2
698
+ luo3
699
+ luo4
700
+ lv2
701
+ lv3
702
+ lv4
703
+ lve3
704
+ lve4
705
+ m
706
+ ma
707
+ ma1
708
+ ma2
709
+ ma3
710
+ ma4
711
+ mai2
712
+ mai3
713
+ mai4
714
+ man1
715
+ man2
716
+ man3
717
+ man4
718
+ mang2
719
+ mang3
720
+ mao1
721
+ mao2
722
+ mao3
723
+ mao4
724
+ me
725
+ mei2
726
+ mei3
727
+ mei4
728
+ men
729
+ men1
730
+ men2
731
+ men4
732
+ meng
733
+ meng1
734
+ meng2
735
+ meng3
736
+ meng4
737
+ mi1
738
+ mi2
739
+ mi3
740
+ mi4
741
+ mian2
742
+ mian3
743
+ mian4
744
+ miao1
745
+ miao2
746
+ miao3
747
+ miao4
748
+ mie1
749
+ mie4
750
+ min2
751
+ min3
752
+ ming2
753
+ ming3
754
+ ming4
755
+ miu4
756
+ mo1
757
+ mo2
758
+ mo3
759
+ mo4
760
+ mou1
761
+ mou2
762
+ mou3
763
+ mu2
764
+ mu3
765
+ mu4
766
+ n
767
+ n2
768
+ na1
769
+ na2
770
+ na3
771
+ na4
772
+ nai2
773
+ nai3
774
+ nai4
775
+ nan1
776
+ nan2
777
+ nan3
778
+ nan4
779
+ nang1
780
+ nang2
781
+ nang3
782
+ nao1
783
+ nao2
784
+ nao3
785
+ nao4
786
+ ne
787
+ ne2
788
+ ne4
789
+ nei3
790
+ nei4
791
+ nen4
792
+ neng2
793
+ ni1
794
+ ni2
795
+ ni3
796
+ ni4
797
+ nian1
798
+ nian2
799
+ nian3
800
+ nian4
801
+ niang2
802
+ niang4
803
+ niao2
804
+ niao3
805
+ niao4
806
+ nie1
807
+ nie4
808
+ nin2
809
+ ning2
810
+ ning3
811
+ ning4
812
+ niu1
813
+ niu2
814
+ niu3
815
+ niu4
816
+ nong2
817
+ nong4
818
+ nou4
819
+ nu2
820
+ nu3
821
+ nu4
822
+ nuan3
823
+ nuo2
824
+ nuo4
825
+ nv2
826
+ nv3
827
+ nve4
828
+ o
829
+ o1
830
+ o2
831
+ ou1
832
+ ou2
833
+ ou3
834
+ ou4
835
+ p
836
+ pa1
837
+ pa2
838
+ pa4
839
+ pai1
840
+ pai2
841
+ pai3
842
+ pai4
843
+ pan1
844
+ pan2
845
+ pan4
846
+ pang1
847
+ pang2
848
+ pang4
849
+ pao1
850
+ pao2
851
+ pao3
852
+ pao4
853
+ pei1
854
+ pei2
855
+ pei4
856
+ pen1
857
+ pen2
858
+ pen4
859
+ peng1
860
+ peng2
861
+ peng3
862
+ peng4
863
+ pi1
864
+ pi2
865
+ pi3
866
+ pi4
867
+ pian1
868
+ pian2
869
+ pian4
870
+ piao1
871
+ piao2
872
+ piao3
873
+ piao4
874
+ pie1
875
+ pie2
876
+ pie3
877
+ pin1
878
+ pin2
879
+ pin3
880
+ pin4
881
+ ping1
882
+ ping2
883
+ po1
884
+ po2
885
+ po3
886
+ po4
887
+ pou1
888
+ pu1
889
+ pu2
890
+ pu3
891
+ pu4
892
+ q
893
+ qi1
894
+ qi2
895
+ qi3
896
+ qi4
897
+ qia1
898
+ qia3
899
+ qia4
900
+ qian1
901
+ qian2
902
+ qian3
903
+ qian4
904
+ qiang1
905
+ qiang2
906
+ qiang3
907
+ qiang4
908
+ qiao1
909
+ qiao2
910
+ qiao3
911
+ qiao4
912
+ qie1
913
+ qie2
914
+ qie3
915
+ qie4
916
+ qin1
917
+ qin2
918
+ qin3
919
+ qin4
920
+ qing1
921
+ qing2
922
+ qing3
923
+ qing4
924
+ qiong1
925
+ qiong2
926
+ qiu1
927
+ qiu2
928
+ qiu3
929
+ qu1
930
+ qu2
931
+ qu3
932
+ qu4
933
+ quan1
934
+ quan2
935
+ quan3
936
+ quan4
937
+ que1
938
+ que2
939
+ que4
940
+ qun2
941
+ r
942
+ ran2
943
+ ran3
944
+ rang1
945
+ rang2
946
+ rang3
947
+ rang4
948
+ rao2
949
+ rao3
950
+ rao4
951
+ re2
952
+ re3
953
+ re4
954
+ ren2
955
+ ren3
956
+ ren4
957
+ reng1
958
+ reng2
959
+ ri4
960
+ rong1
961
+ rong2
962
+ rong3
963
+ rou2
964
+ rou4
965
+ ru2
966
+ ru3
967
+ ru4
968
+ ruan2
969
+ ruan3
970
+ rui3
971
+ rui4
972
+ run4
973
+ ruo4
974
+ s
975
+ sa1
976
+ sa2
977
+ sa3
978
+ sa4
979
+ sai1
980
+ sai4
981
+ san1
982
+ san2
983
+ san3
984
+ san4
985
+ sang1
986
+ sang3
987
+ sang4
988
+ sao1
989
+ sao2
990
+ sao3
991
+ sao4
992
+ se4
993
+ sen1
994
+ seng1
995
+ sha1
996
+ sha2
997
+ sha3
998
+ sha4
999
+ shai1
1000
+ shai2
1001
+ shai3
1002
+ shai4
1003
+ shan1
1004
+ shan3
1005
+ shan4
1006
+ shang
1007
+ shang1
1008
+ shang3
1009
+ shang4
1010
+ shao1
1011
+ shao2
1012
+ shao3
1013
+ shao4
1014
+ she1
1015
+ she2
1016
+ she3
1017
+ she4
1018
+ shei2
1019
+ shen1
1020
+ shen2
1021
+ shen3
1022
+ shen4
1023
+ sheng1
1024
+ sheng2
1025
+ sheng3
1026
+ sheng4
1027
+ shi
1028
+ shi1
1029
+ shi2
1030
+ shi3
1031
+ shi4
1032
+ shou1
1033
+ shou2
1034
+ shou3
1035
+ shou4
1036
+ shu1
1037
+ shu2
1038
+ shu3
1039
+ shu4
1040
+ shua1
1041
+ shua2
1042
+ shua3
1043
+ shua4
1044
+ shuai1
1045
+ shuai3
1046
+ shuai4
1047
+ shuan1
1048
+ shuan4
1049
+ shuang1
1050
+ shuang3
1051
+ shui2
1052
+ shui3
1053
+ shui4
1054
+ shun3
1055
+ shun4
1056
+ shuo1
1057
+ shuo4
1058
+ si1
1059
+ si2
1060
+ si3
1061
+ si4
1062
+ song1
1063
+ song3
1064
+ song4
1065
+ sou1
1066
+ sou3
1067
+ sou4
1068
+ su1
1069
+ su2
1070
+ su4
1071
+ suan1
1072
+ suan4
1073
+ sui1
1074
+ sui2
1075
+ sui3
1076
+ sui4
1077
+ sun1
1078
+ sun3
1079
+ suo
1080
+ suo1
1081
+ suo2
1082
+ suo3
1083
+ t
1084
+ ta1
1085
+ ta2
1086
+ ta3
1087
+ ta4
1088
+ tai1
1089
+ tai2
1090
+ tai4
1091
+ tan1
1092
+ tan2
1093
+ tan3
1094
+ tan4
1095
+ tang1
1096
+ tang2
1097
+ tang3
1098
+ tang4
1099
+ tao1
1100
+ tao2
1101
+ tao3
1102
+ tao4
1103
+ te4
1104
+ teng2
1105
+ ti1
1106
+ ti2
1107
+ ti3
1108
+ ti4
1109
+ tian1
1110
+ tian2
1111
+ tian3
1112
+ tiao1
1113
+ tiao2
1114
+ tiao3
1115
+ tiao4
1116
+ tie1
1117
+ tie2
1118
+ tie3
1119
+ tie4
1120
+ ting1
1121
+ ting2
1122
+ ting3
1123
+ tong1
1124
+ tong2
1125
+ tong3
1126
+ tong4
1127
+ tou
1128
+ tou1
1129
+ tou2
1130
+ tou4
1131
+ tu1
1132
+ tu2
1133
+ tu3
1134
+ tu4
1135
+ tuan1
1136
+ tuan2
1137
+ tui1
1138
+ tui2
1139
+ tui3
1140
+ tui4
1141
+ tun1
1142
+ tun2
1143
+ tun4
1144
+ tuo1
1145
+ tuo2
1146
+ tuo3
1147
+ tuo4
1148
+ u
1149
+ v
1150
+ w
1151
+ wa
1152
+ wa1
1153
+ wa2
1154
+ wa3
1155
+ wa4
1156
+ wai1
1157
+ wai3
1158
+ wai4
1159
+ wan1
1160
+ wan2
1161
+ wan3
1162
+ wan4
1163
+ wang1
1164
+ wang2
1165
+ wang3
1166
+ wang4
1167
+ wei1
1168
+ wei2
1169
+ wei3
1170
+ wei4
1171
+ wen1
1172
+ wen2
1173
+ wen3
1174
+ wen4
1175
+ weng1
1176
+ weng4
1177
+ wo1
1178
+ wo2
1179
+ wo3
1180
+ wo4
1181
+ wu1
1182
+ wu2
1183
+ wu3
1184
+ wu4
1185
+ x
1186
+ xi1
1187
+ xi2
1188
+ xi3
1189
+ xi4
1190
+ xia1
1191
+ xia2
1192
+ xia4
1193
+ xian1
1194
+ xian2
1195
+ xian3
1196
+ xian4
1197
+ xiang1
1198
+ xiang2
1199
+ xiang3
1200
+ xiang4
1201
+ xiao1
1202
+ xiao2
1203
+ xiao3
1204
+ xiao4
1205
+ xie1
1206
+ xie2
1207
+ xie3
1208
+ xie4
1209
+ xin1
1210
+ xin2
1211
+ xin4
1212
+ xing1
1213
+ xing2
1214
+ xing3
1215
+ xing4
1216
+ xiong1
1217
+ xiong2
1218
+ xiu1
1219
+ xiu3
1220
+ xiu4
1221
+ xu
1222
+ xu1
1223
+ xu2
1224
+ xu3
1225
+ xu4
1226
+ xuan1
1227
+ xuan2
1228
+ xuan3
1229
+ xuan4
1230
+ xue1
1231
+ xue2
1232
+ xue3
1233
+ xue4
1234
+ xun1
1235
+ xun2
1236
+ xun4
1237
+ y
1238
+ ya
1239
+ ya1
1240
+ ya2
1241
+ ya3
1242
+ ya4
1243
+ yan1
1244
+ yan2
1245
+ yan3
1246
+ yan4
1247
+ yang1
1248
+ yang2
1249
+ yang3
1250
+ yang4
1251
+ yao1
1252
+ yao2
1253
+ yao3
1254
+ yao4
1255
+ ye1
1256
+ ye2
1257
+ ye3
1258
+ ye4
1259
+ yi
1260
+ yi1
1261
+ yi2
1262
+ yi3
1263
+ yi4
1264
+ yin1
1265
+ yin2
1266
+ yin3
1267
+ yin4
1268
+ ying1
1269
+ ying2
1270
+ ying3
1271
+ ying4
1272
+ yo1
1273
+ yong1
1274
+ yong2
1275
+ yong3
1276
+ yong4
1277
+ you1
1278
+ you2
1279
+ you3
1280
+ you4
1281
+ yu1
1282
+ yu2
1283
+ yu3
1284
+ yu4
1285
+ yuan1
1286
+ yuan2
1287
+ yuan3
1288
+ yuan4
1289
+ yue1
1290
+ yue4
1291
+ yun1
1292
+ yun2
1293
+ yun3
1294
+ yun4
1295
+ z
1296
+ za1
1297
+ za2
1298
+ za3
1299
+ zai1
1300
+ zai3
1301
+ zai4
1302
+ zan1
1303
+ zan2
1304
+ zan3
1305
+ zan4
1306
+ zang1
1307
+ zang4
1308
+ zao1
1309
+ zao2
1310
+ zao3
1311
+ zao4
1312
+ ze2
1313
+ ze4
1314
+ zei2
1315
+ zen3
1316
+ zeng1
1317
+ zeng4
1318
+ zha1
1319
+ zha2
1320
+ zha3
1321
+ zha4
1322
+ zhai1
1323
+ zhai2
1324
+ zhai3
1325
+ zhai4
1326
+ zhan1
1327
+ zhan2
1328
+ zhan3
1329
+ zhan4
1330
+ zhang1
1331
+ zhang2
1332
+ zhang3
1333
+ zhang4
1334
+ zhao1
1335
+ zhao2
1336
+ zhao3
1337
+ zhao4
1338
+ zhe
1339
+ zhe1
1340
+ zhe2
1341
+ zhe3
1342
+ zhe4
1343
+ zhen1
1344
+ zhen2
1345
+ zhen3
1346
+ zhen4
1347
+ zheng1
1348
+ zheng2
1349
+ zheng3
1350
+ zheng4
1351
+ zhi1
1352
+ zhi2
1353
+ zhi3
1354
+ zhi4
1355
+ zhong1
1356
+ zhong2
1357
+ zhong3
1358
+ zhong4
1359
+ zhou1
1360
+ zhou2
1361
+ zhou3
1362
+ zhou4
1363
+ zhu1
1364
+ zhu2
1365
+ zhu3
1366
+ zhu4
1367
+ zhua1
1368
+ zhua2
1369
+ zhua3
1370
+ zhuai1
1371
+ zhuai3
1372
+ zhuai4
1373
+ zhuan1
1374
+ zhuan2
1375
+ zhuan3
1376
+ zhuan4
1377
+ zhuang1
1378
+ zhuang4
1379
+ zhui1
1380
+ zhui4
1381
+ zhun1
1382
+ zhun2
1383
+ zhun3
1384
+ zhuo1
1385
+ zhuo2
1386
+ zi
1387
+ zi1
1388
+ zi2
1389
+ zi3
1390
+ zi4
1391
+ zong1
1392
+ zong2
1393
+ zong3
1394
+ zong4
1395
+ zou1
1396
+ zou2
1397
+ zou3
1398
+ zou4
1399
+ zu1
1400
+ zu2
1401
+ zu3
1402
+ zuan1
1403
+ zuan3
1404
+ zuan4
1405
+ zui2
1406
+ zui3
1407
+ zui4
1408
+ zun1
1409
+ zuo
1410
+ zuo1
1411
+ zuo2
1412
+ zuo3
1413
+ zuo4
1414
+ {
1415
+ ~
1416
+ ¡
1417
+ ¢
1418
+ £
1419
+ ¥
1420
+ §
1421
+ ¨
1422
+ ©
1423
+ «
1424
+ ®
1425
+ ¯
1426
+ °
1427
+ ±
1428
+ ²
1429
+ ³
1430
+ ´
1431
+ µ
1432
+ ·
1433
+ ¹
1434
+ º
1435
+ »
1436
+ ¼
1437
+ ½
1438
+ ¾
1439
+ ¿
1440
+ À
1441
+ Á
1442
+ Â
1443
+ Ã
1444
+ Ä
1445
+ Å
1446
+ Æ
1447
+ Ç
1448
+ È
1449
+ É
1450
+ Ê
1451
+ Í
1452
+ Î
1453
+ Ñ
1454
+ Ó
1455
+ Ö
1456
+ ×
1457
+ Ø
1458
+ Ú
1459
+ Ü
1460
+ Ý
1461
+ Þ
1462
+ ß
1463
+ à
1464
+ á
1465
+ â
1466
+ ã
1467
+ ä
1468
+ å
1469
+ æ
1470
+ ç
1471
+ è
1472
+ é
1473
+ ê
1474
+ ë
1475
+ ì
1476
+ í
1477
+ î
1478
+ ï
1479
+ ð
1480
+ ñ
1481
+ ò
1482
+ ó
1483
+ ô
1484
+ õ
1485
+ ö
1486
+ ø
1487
+ ù
1488
+ ú
1489
+ û
1490
+ ü
1491
+ ý
1492
+ Ā
1493
+ ā
1494
+ ă
1495
+ ą
1496
+ ć
1497
+ Č
1498
+ č
1499
+ Đ
1500
+ đ
1501
+ ē
1502
+ ė
1503
+ ę
1504
+ ě
1505
+ ĝ
1506
+ ğ
1507
+ ħ
1508
+ ī
1509
+ į
1510
+ İ
1511
+ ı
1512
+ Ł
1513
+ ł
1514
+ ń
1515
+ ņ
1516
+ ň
1517
+ ŋ
1518
+ Ō
1519
+ ō
1520
+ ő
1521
+ œ
1522
+ ř
1523
+ Ś
1524
+ ś
1525
+ Ş
1526
+ ş
1527
+ Š
1528
+ š
1529
+ Ť
1530
+ ť
1531
+ ũ
1532
+ ū
1533
+ ź
1534
+ Ż
1535
+ ż
1536
+ Ž
1537
+ ž
1538
+ ơ
1539
+ ư
1540
+ ǎ
1541
+ ǐ
1542
+ ǒ
1543
+ ǔ
1544
+ ǚ
1545
+ ș
1546
+ ț
1547
+ ɑ
1548
+ ɔ
1549
+ ɕ
1550
+ ə
1551
+ ɛ
1552
+ ɜ
1553
+ ɡ
1554
+ ɣ
1555
+ ɪ
1556
+ ɫ
1557
+ ɴ
1558
+ ɹ
1559
+ ɾ
1560
+ ʃ
1561
+ ʊ
1562
+ ʌ
1563
+ ʒ
1564
+ ʔ
1565
+ ʰ
1566
+ ʷ
1567
+ ʻ
1568
+ ʾ
1569
+ ʿ
1570
+ ˈ
1571
+ ː
1572
+ ˙
1573
+ ˜
1574
+ ˢ
1575
+ ́
1576
+ ̅
1577
+ Α
1578
+ Β
1579
+ Δ
1580
+ Ε
1581
+ Θ
1582
+ Κ
1583
+ Λ
1584
+ Μ
1585
+ Ξ
1586
+ Π
1587
+ Σ
1588
+ Τ
1589
+ Φ
1590
+ Χ
1591
+ Ψ
1592
+ Ω
1593
+ ά
1594
+ έ
1595
+ ή
1596
+ ί
1597
+ α
1598
+ β
1599
+ γ
1600
+ δ
1601
+ ε
1602
+ ζ
1603
+ η
1604
+ θ
1605
+ ι
1606
+ κ
1607
+ λ
1608
+ μ
1609
+ ν
1610
+ ξ
1611
+ ο
1612
+ π
1613
+ ρ
1614
+ ς
1615
+ σ
1616
+ τ
1617
+ υ
1618
+ φ
1619
+ χ
1620
+ ψ
1621
+ ω
1622
+ ϊ
1623
+ ό
1624
+ ύ
1625
+ ώ
1626
+ ϕ
1627
+ ϵ
1628
+ Ё
1629
+ А
1630
+ Б
1631
+ В
1632
+ Г
1633
+ Д
1634
+ Е
1635
+ Ж
1636
+ З
1637
+ И
1638
+ Й
1639
+ К
1640
+ Л
1641
+ М
1642
+ Н
1643
+ О
1644
+ П
1645
+ Р
1646
+ С
1647
+ Т
1648
+ У
1649
+ Ф
1650
+ Х
1651
+ Ц
1652
+ Ч
1653
+ Ш
1654
+ Щ
1655
+ Ы
1656
+ Ь
1657
+ Э
1658
+ Ю
1659
+ Я
1660
+ а
1661
+ б
1662
+ в
1663
+ г
1664
+ д
1665
+ е
1666
+ ж
1667
+ з
1668
+ и
1669
+ й
1670
+ к
1671
+ л
1672
+ м
1673
+ н
1674
+ о
1675
+ п
1676
+ р
1677
+ с
1678
+ т
1679
+ у
1680
+ ф
1681
+ х
1682
+ ц
1683
+ ч
1684
+ ш
1685
+ щ
1686
+ ъ
1687
+ ы
1688
+ ь
1689
+ э
1690
+ ю
1691
+ я
1692
+ ё
1693
+ і
1694
+ ְ
1695
+ ִ
1696
+ ֵ
1697
+ ֶ
1698
+ ַ
1699
+ ָ
1700
+ ֹ
1701
+ ּ
1702
+ ־
1703
+ ׁ
1704
+ א
1705
+ ב
1706
+ ג
1707
+ ד
1708
+ ה
1709
+ ו
1710
+ ז
1711
+ ח
1712
+ ט
1713
+ י
1714
+ כ
1715
+ ל
1716
+ ם
1717
+ מ
1718
+ ן
1719
+ נ
1720
+ ס
1721
+ ע
1722
+ פ
1723
+ ק
1724
+ ר
1725
+ ש
1726
+ ת
1727
+ أ
1728
+ ب
1729
+ ة
1730
+ ت
1731
+ ج
1732
+ ح
1733
+ د
1734
+ ر
1735
+ ز
1736
+ س
1737
+ ص
1738
+ ط
1739
+ ع
1740
+ ق
1741
+ ك
1742
+ ل
1743
+ م
1744
+ ن
1745
+ ه
1746
+ و
1747
+ ي
1748
+ َ
1749
+ ُ
1750
+ ِ
1751
+ ْ
1752
+
1753
+
1754
+
1755
+
1756
+
1757
+
1758
+
1759
+
1760
+
1761
+
1762
+
1763
+
1764
+
1765
+
1766
+
1767
+
1768
+
1769
+
1770
+
1771
+
1772
+
1773
+
1774
+
1775
+
1776
+
1777
+
1778
+
1779
+
1780
+
1781
+
1782
+
1783
+
1784
+
1785
+
1786
+
1787
+
1788
+
1789
+
1790
+
1791
+
1792
+
1793
+
1794
+
1795
+
1796
+
1797
+
1798
+
1799
+
1800
+ ế
1801
+
1802
+
1803
+
1804
+
1805
+
1806
+
1807
+
1808
+
1809
+
1810
+
1811
+
1812
+
1813
+
1814
+
1815
+
1816
+
1817
+
1818
+
1819
+
1820
+
1821
+
1822
+
1823
+
1824
+
1825
+
1826
+
1827
+
1828
+
1829
+
1830
+
1831
+
1832
+
1833
+
1834
+
1835
+
1836
+
1837
+
1838
+
1839
+
1840
+
1841
+
1842
+
1843
+
1844
+
1845
+
1846
+
1847
+
1848
+
1849
+
1850
+
1851
+
1852
+
1853
+
1854
+
1855
+
1856
+
1857
+
1858
+
1859
+
1860
+
1861
+
1862
+
1863
+
1864
+
1865
+
1866
+
1867
+
1868
+
1869
+
1870
+
1871
+
1872
+
1873
+
1874
+
1875
+
1876
+
1877
+
1878
+
1879
+
1880
+
1881
+
1882
+
1883
+
1884
+
1885
+
1886
+
1887
+
1888
+
1889
+
1890
+
1891
+
1892
+
1893
+
1894
+
1895
+
1896
+
1897
+
1898
+
1899
+
1900
+
1901
+
1902
+
1903
+
1904
+
1905
+
1906
+
1907
+
1908
+
1909
+
1910
+
1911
+
1912
+
1913
+
1914
+
1915
+
1916
+
1917
+
1918
+
1919
+
1920
+
1921
+
1922
+
1923
+
1924
+
1925
+
1926
+
1927
+
1928
+
1929
+
1930
+
1931
+
1932
+
1933
+
1934
+
1935
+
1936
+
1937
+
1938
+
1939
+
1940
+
1941
+
1942
+
1943
+
1944
+
1945
+
1946
+
1947
+
1948
+
1949
+
1950
+
1951
+
1952
+
1953
+
1954
+
1955
+
1956
+
1957
+
1958
+
1959
+
1960
+
1961
+
1962
+
1963
+
1964
+
1965
+
1966
+
1967
+
1968
+
1969
+
1970
+
1971
+
1972
+
1973
+
1974
+
1975
+
1976
+
1977
+
1978
+
1979
+
1980
+
1981
+
1982
+
1983
+
1984
+
1985
+
1986
+
1987
+
1988
+
1989
+
1990
+
1991
+
1992
+
1993
+
1994
+
1995
+
1996
+
1997
+
1998
+
1999
+
2000
+
2001
+
2002
+
2003
+
2004
+
2005
+
2006
+
2007
+
2008
+
2009
+
2010
+
2011
+
2012
+
2013
+
2014
+
2015
+
2016
+
2017
+
2018
+
2019
+
2020
+
2021
+
2022
+
2023
+
2024
+
2025
+
2026
+
2027
+
2028
+
2029
+
2030
+
2031
+
2032
+
2033
+
2034
+
2035
+
2036
+
2037
+
2038
+
2039
+
2040
+
2041
+
2042
+
2043
+
2044
+
2045
+
2046
+
2047
+
2048
+
2049
+
2050
+
2051
+
2052
+
2053
+
2054
+
2055
+
2056
+
2057
+
2058
+
2059
+
2060
+
2061
+
2062
+
2063
+
2064
+
2065
+
2066
+
2067
+
2068
+
2069
+
2070
+
2071
+
2072
+
2073
+
2074
+
2075
+
2076
+
2077
+
2078
+
2079
+
2080
+
2081
+
2082
+
2083
+
2084
+
2085
+
2086
+
2087
+
2088
+
2089
+
2090
+
2091
+
2092
+
2093
+
2094
+
2095
+
2096
+
2097
+
2098
+
2099
+
2100
+
2101
+
2102
+
2103
+
2104
+
2105
+
2106
+
2107
+
2108
+
2109
+
2110
+
2111
+
2112
+
2113
+
2114
+
2115
+
2116
+
2117
+
2118
+
2119
+
2120
+
2121
+
2122
+
2123
+
2124
+
2125
+
2126
+
2127
+
2128
+
2129
+
2130
+
2131
+
2132
+
2133
+
2134
+
2135
+
2136
+
2137
+
2138
+
2139
+
2140
+
2141
+
2142
+
2143
+
2144
+
2145
+
2146
+
2147
+
2148
+
2149
+
2150
+
2151
+
2152
+
2153
+
2154
+
2155
+
2156
+
2157
+
2158
+
2159
+
2160
+
2161
+
2162
+
2163
+
2164
+
2165
+
2166
+
2167
+
2168
+
2169
+
2170
+
2171
+
2172
+
2173
+
2174
+
2175
+
2176
+
2177
+
2178
+
2179
+
2180
+
2181
+
2182
+
2183
+
2184
+
2185
+
2186
+
2187
+
2188
+
2189
+
2190
+
2191
+
2192
+
2193
+
2194
+
2195
+
2196
+
2197
+
2198
+
2199
+
2200
+
2201
+
2202
+
2203
+
2204
+
2205
+
2206
+
2207
+
2208
+
2209
+
2210
+
2211
+
2212
+
2213
+
2214
+
2215
+
2216
+
2217
+
2218
+
2219
+
2220
+
2221
+
2222
+
2223
+
2224
+
2225
+
2226
+
2227
+
2228
+
2229
+
2230
+
2231
+
2232
+
2233
+
2234
+
2235
+
2236
+
2237
+
2238
+
2239
+
2240
+
2241
+
2242
+
2243
+
2244
+
2245
+
2246
+
2247
+
2248
+
2249
+
2250
+
2251
+
2252
+
2253
+
2254
+
2255
+
2256
+
2257
+
2258
+
2259
+
2260
+
2261
+
2262
+
2263
+
2264
+
2265
+
2266
+
2267
+
2268
+
2269
+
2270
+
2271
+
2272
+
2273
+
2274
+
2275
+
2276
+
2277
+
2278
+
2279
+
2280
+
2281
+
2282
+
2283
+
2284
+
2285
+
2286
+
2287
+
2288
+
2289
+
2290
+
2291
+
2292
+
2293
+
2294
+
2295
+
2296
+
2297
+
2298
+
2299
+
2300
+
2301
+
2302
+
2303
+
2304
+
2305
+
2306
+
2307
+
2308
+
2309
+
2310
+
2311
+
2312
+
2313
+
2314
+
2315
+
2316
+
2317
+
2318
+
2319
+
2320
+
2321
+
2322
+
2323
+
2324
+
2325
+
2326
+
2327
+
2328
+
2329
+
2330
+
2331
+
2332
+
2333
+
2334
+
2335
+
2336
+
2337
+
2338
+
2339
+
2340
+
2341
+
2342
+
2343
+
2344
+
2345
+
2346
+
2347
+
2348
+
2349
+
2350
+
2351
+
2352
+
2353
+
2354
+
2355
+
2356
+
2357
+
2358
+
2359
+
2360
+
2361
+
2362
+
2363
+
2364
+
2365
+
2366
+
2367
+
2368
+
2369
+
2370
+
2371
+
2372
+
2373
+
2374
+
2375
+
2376
+
2377
+
2378
+
2379
+
2380
+
2381
+
2382
+
2383
+
2384
+
2385
+
2386
+
2387
+
2388
+
2389
+
2390
+
2391
+
2392
+
2393
+
2394
+
2395
+
2396
+
2397
+
2398
+
2399
+
2400
+
2401
+
2402
+
2403
+
2404
+
2405
+
2406
+
2407
+
2408
+
2409
+
2410
+
2411
+
2412
+
2413
+
2414
+
2415
+
2416
+
2417
+
2418
+
2419
+
2420
+
2421
+
2422
+
2423
+
2424
+
2425
+
2426
+
2427
+
2428
+
2429
+
2430
+
2431
+
2432
+
2433
+
2434
+
2435
+
2436
+
2437
+
2438
+
2439
+
2440
+
2441
+
2442
+
2443
+
2444
+
2445
+
2446
+
2447
+
2448
+
2449
+
2450
+
2451
+
2452
+
2453
+
2454
+
2455
+
2456
+
2457
+
2458
+
2459
+
2460
+
2461
+
2462
+
2463
+
2464
+
2465
+
2466
+
2467
+
2468
+
2469
+
2470
+
2471
+
2472
+
2473
+
2474
+
2475
+
2476
+
2477
+
2478
+
2479
+
2480
+
2481
+
2482
+
2483
+
2484
+
2485
+
2486
+
2487
+
2488
+
2489
+
2490
+
2491
+
2492
+
2493
+
2494
+
2495
+
2496
+
2497
+
2498
+
2499
+
2500
+
2501
+
2502
+
2503
+
2504
+
2505
+
2506
+
2507
+
2508
+
2509
+
2510
+
2511
+
2512
+
2513
+
2514
+
2515
+
2516
+
2517
+
2518
+
2519
+
2520
+
2521
+
2522
+
2523
+
2524
+
2525
+
2526
+
2527
+
2528
+
2529
+
2530
+
2531
+
2532
+
2533
+
2534
+
2535
+
2536
+
2537
+
2538
+
2539
+
2540
+
2541
+
2542
+
2543
+
2544
+
2545
+ 𠮶
data/librispeech_pc_test_clean_cross_sentence.lst CHANGED
The diff for this file is too large to render. See raw diff
 
finetune-cli.py CHANGED
@@ -1,57 +1,42 @@
1
  import argparse
2
- from model import CFM, UNetT, DiT, Trainer
3
  from model.utils import get_tokenizer
4
  from model.dataset import load_dataset
5
  from cached_path import cached_path
6
- import shutil
7
- import os
8
-
9
  # -------------------------- Dataset Settings --------------------------- #
10
  target_sample_rate = 24000
11
  n_mel_channels = 100
12
  hop_length = 256
13
 
 
 
14
 
15
  # -------------------------- Argument Parsing --------------------------- #
16
  def parse_args():
17
- parser = argparse.ArgumentParser(description="Train CFM Model")
18
-
19
- parser.add_argument(
20
- "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
21
- )
22
- parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
23
- parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
24
- parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
25
- parser.add_argument(
26
- "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
27
- )
28
- parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
29
- parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
30
- parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
31
- parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
32
- parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
33
- parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
34
- parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
35
- parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
36
-
37
- parser.add_argument(
38
- "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
39
- )
40
- parser.add_argument(
41
- "--tokenizer_path",
42
- type=str,
43
- default=None,
44
- help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
45
- )
46
-
47
  return parser.parse_args()
48
 
49
-
50
  # -------------------------- Training Settings -------------------------- #
51
 
52
-
53
  def main():
54
  args = parse_args()
 
55
 
56
  # Model parameters based on experiment name
57
  if args.exp_name == "F5TTS_Base":
@@ -59,31 +44,24 @@ def main():
59
  model_cls = DiT
60
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
61
  if args.finetune:
62
- ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
63
  elif args.exp_name == "E2TTS_Base":
64
  wandb_resume_id = None
65
  model_cls = UNetT
66
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
67
  if args.finetune:
68
- ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
69
-
70
  if args.finetune:
71
- path_ckpt = os.path.join("ckpts", args.dataset_name)
72
- if not os.path.isdir(path_ckpt):
73
- os.makedirs(path_ckpt, exist_ok=True)
74
- shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
75
-
76
- checkpoint_path = os.path.join("ckpts", args.dataset_name)
77
-
78
- # Use the tokenizer and tokenizer_path provided in the command line arguments
79
- tokenizer = args.tokenizer
80
- if tokenizer == "custom":
81
- if not args.tokenizer_path:
82
- raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
83
- tokenizer_path = args.tokenizer_path
84
- else:
85
- tokenizer_path = args.dataset_name
86
-
87
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
88
 
89
  mel_spec_kwargs = dict(
@@ -93,7 +71,11 @@ def main():
93
  )
94
 
95
  e2tts = CFM(
96
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
 
 
 
 
97
  mel_spec_kwargs=mel_spec_kwargs,
98
  vocab_char_map=vocab_char_map,
99
  )
@@ -117,11 +99,10 @@ def main():
117
  )
118
 
119
  train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
120
- trainer.train(
121
- train_dataset,
122
- resumable_with_seed=666, # seed for shuffling dataset
123
- )
124
 
125
 
126
- if __name__ == "__main__":
127
  main()
 
1
  import argparse
2
+ from model import CFM, UNetT, DiT, MMDiT, Trainer
3
  from model.utils import get_tokenizer
4
  from model.dataset import load_dataset
5
  from cached_path import cached_path
6
+ import shutil,os
 
 
7
  # -------------------------- Dataset Settings --------------------------- #
8
  target_sample_rate = 24000
9
  n_mel_channels = 100
10
  hop_length = 256
11
 
12
+ tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
13
+ tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
14
 
15
  # -------------------------- Argument Parsing --------------------------- #
16
  def parse_args():
17
+ parser = argparse.ArgumentParser(description='Train CFM Model')
18
+
19
+ parser.add_argument('--exp_name', type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"],help='Experiment name')
20
+ parser.add_argument('--dataset_name', type=str, default="Emilia_ZH_EN", help='Name of the dataset to use')
21
+ parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for training')
22
+ parser.add_argument('--batch_size_per_gpu', type=int, default=256, help='Batch size per GPU')
23
+ parser.add_argument('--batch_size_type', type=str, default="frame", choices=["frame", "sample"],help='Batch size type')
24
+ parser.add_argument('--max_samples', type=int, default=16, help='Max sequences per batch')
25
+ parser.add_argument('--grad_accumulation_steps', type=int, default=1,help='Gradient accumulation steps')
26
+ parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
27
+ parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
28
+ parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
29
+ parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
30
+ parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
31
+ parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
32
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return parser.parse_args()
34
 
 
35
  # -------------------------- Training Settings -------------------------- #
36
 
 
37
  def main():
38
  args = parse_args()
39
+
40
 
41
  # Model parameters based on experiment name
42
  if args.exp_name == "F5TTS_Base":
 
44
  model_cls = DiT
45
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
46
  if args.finetune:
47
+ ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
48
  elif args.exp_name == "E2TTS_Base":
49
  wandb_resume_id = None
50
  model_cls = UNetT
51
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
52
  if args.finetune:
53
+ ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
54
+
55
  if args.finetune:
56
+ path_ckpt = os.path.join("ckpts",args.dataset_name)
57
+ if os.path.isdir(path_ckpt)==False:
58
+ os.makedirs(path_ckpt,exist_ok=True)
59
+ shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
60
+
61
+ checkpoint_path=os.path.join("ckpts",args.dataset_name)
62
+
63
+ # Use the dataset_name provided in the command line
64
+ tokenizer_path = args.dataset_name if tokenizer != "custom" else tokenizer_path
 
 
 
 
 
 
 
65
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
66
 
67
  mel_spec_kwargs = dict(
 
71
  )
72
 
73
  e2tts = CFM(
74
+ transformer=model_cls(
75
+ **model_cfg,
76
+ text_num_embeds=vocab_size,
77
+ mel_dim=n_mel_channels
78
+ ),
79
  mel_spec_kwargs=mel_spec_kwargs,
80
  vocab_char_map=vocab_char_map,
81
  )
 
99
  )
100
 
101
  train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
102
+ trainer.train(train_dataset,
103
+ resumable_with_seed=666 # seed for shuffling dataset
104
+ )
 
105
 
106
 
107
+ if __name__ == '__main__':
108
  main()
finetune_gradio.py CHANGED
@@ -1,12 +1,8 @@
1
- import os
2
- import sys
3
 
4
- import tempfile
5
- import random
6
  from transformers import pipeline
7
  import gradio as gr
8
  import torch
9
- import gc
10
  import click
11
  import torchaudio
12
  from glob import glob
@@ -23,43 +19,35 @@ import psutil
23
  import platform
24
  import subprocess
25
  from datasets.arrow_writer import ArrowWriter
26
- from datasets import Dataset as Dataset_
27
- from api import F5TTS
28
 
 
29
 
30
- training_process = None
31
  system = platform.system()
32
  python_executable = sys.executable or "python"
33
- tts_api = None
34
- last_checkpoint = ""
35
- last_device = ""
36
 
37
- path_data = "data"
38
 
39
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
 
40
 
41
  pipe = None
42
 
43
-
44
  # Load metadata
45
  def get_audio_duration(audio_path):
46
  """Calculate the duration of an audio file."""
47
  audio, sample_rate = torchaudio.load(audio_path)
48
- num_channels = audio.shape[0]
49
  return audio.shape[1] / (sample_rate * num_channels)
50
 
51
-
52
  def clear_text(text):
53
  """Clean and prepare text by lowering the case and stripping whitespace."""
54
  return text.lower().strip()
55
 
56
-
57
- def get_rms(
58
- y,
59
- frame_length=2048,
60
- hop_length=512,
61
- pad_mode="constant",
62
- ): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
63
  padding = (int(frame_length // 2), int(frame_length // 2))
64
  y = np.pad(y, padding, mode=pad_mode)
65
 
@@ -86,8 +74,7 @@ def get_rms(
86
 
87
  return np.sqrt(power)
88
 
89
-
90
- class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
91
  def __init__(
92
  self,
93
  sr: int,
@@ -98,9 +85,13 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
98
  max_sil_kept: int = 2000,
99
  ):
100
  if not min_length >= min_interval >= hop_size:
101
- raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
 
 
102
  if not max_sil_kept >= hop_size:
103
- raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
 
 
104
  min_interval = sr * min_interval / 1000
105
  self.threshold = 10 ** (threshold / 20.0)
106
  self.hop_size = round(sr * hop_size / 1000)
@@ -111,9 +102,13 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
111
 
112
  def _apply_slice(self, waveform, begin, end):
113
  if len(waveform.shape) > 1:
114
- return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
 
 
115
  else:
116
- return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]
 
 
117
 
118
  # @timeit
119
  def slice(self, waveform):
@@ -123,7 +118,9 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
123
  samples = waveform
124
  if samples.shape[0] <= self.min_length:
125
  return [waveform]
126
- rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
 
 
127
  sil_tags = []
128
  silence_start = None
129
  clip_start = 0
@@ -139,7 +136,10 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
139
  continue
140
  # Clear recorded silence start if interval is not enough or clip is too short
141
  is_leading_silence = silence_start == 0 and i > self.max_sil_kept
142
- need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
 
 
 
143
  if not is_leading_silence and not need_slice_middle:
144
  silence_start = None
145
  continue
@@ -152,10 +152,21 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
152
  sil_tags.append((pos, pos))
153
  clip_start = pos
154
  elif i - silence_start <= self.max_sil_kept * 2:
155
- pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
 
 
156
  pos += i - self.max_sil_kept
157
- pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
158
- pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
 
 
 
 
 
 
 
 
 
159
  if silence_start == 0:
160
  sil_tags.append((0, pos_r))
161
  clip_start = pos_r
@@ -163,8 +174,17 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
163
  sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
164
  clip_start = max(pos_r, pos)
165
  else:
166
- pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
167
- pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
 
 
 
 
 
 
 
 
 
168
  if silence_start == 0:
169
  sil_tags.append((0, pos_r))
170
  else:
@@ -173,39 +193,33 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
173
  silence_start = None
174
  # Deal with trailing silence.
175
  total_frames = rms_list.shape[0]
176
- if silence_start is not None and total_frames - silence_start >= self.min_interval:
 
 
 
177
  silence_end = min(total_frames, silence_start + self.max_sil_kept)
178
  pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
179
  sil_tags.append((pos, total_frames + 1))
180
  # Apply and return slices.
181
  ####音频+起始时间+终止时间
182
  if len(sil_tags) == 0:
183
- return [[waveform, 0, int(total_frames * self.hop_size)]]
184
  else:
185
  chunks = []
186
  if sil_tags[0][0] > 0:
187
- chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
188
  for i in range(len(sil_tags) - 1):
189
  chunks.append(
190
- [
191
- self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
192
- int(sil_tags[i][1] * self.hop_size),
193
- int(sil_tags[i + 1][0] * self.hop_size),
194
- ]
195
  )
196
  if sil_tags[-1][1] < total_frames:
197
  chunks.append(
198
- [
199
- self._apply_slice(waveform, sil_tags[-1][1], total_frames),
200
- int(sil_tags[-1][1] * self.hop_size),
201
- int(total_frames * self.hop_size),
202
- ]
203
  )
204
  return chunks
205
 
206
-
207
- # terminal
208
- def terminate_process_tree(pid, including_parent=True):
209
  try:
210
  parent = psutil.Process(pid)
211
  except psutil.NoSuchProcess:
@@ -224,7 +238,6 @@ def terminate_process_tree(pid, including_parent=True):
224
  except OSError:
225
  pass
226
 
227
-
228
  def terminate_process(pid):
229
  if system == "Windows":
230
  cmd = f"taskkill /t /f /pid {pid}"
@@ -232,160 +245,130 @@ def terminate_process(pid):
232
  else:
233
  terminate_process_tree(pid)
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
- def start_training(
237
- dataset_name="",
238
- exp_name="F5TTS_Base",
239
- learning_rate=1e-4,
240
- batch_size_per_gpu=400,
241
- batch_size_type="frame",
242
- max_samples=64,
243
- grad_accumulation_steps=1,
244
- max_grad_norm=1.0,
245
- epochs=11,
246
- num_warmup_updates=200,
247
- save_per_updates=400,
248
- last_per_steps=800,
249
- finetune=True,
250
- ):
251
- global training_process, tts_api
252
-
253
- if tts_api is not None:
254
- del tts_api
255
- gc.collect()
256
- torch.cuda.empty_cache()
257
- tts_api = None
258
 
259
  path_project = os.path.join(path_data, dataset_name + "_pinyin")
260
 
261
- if not os.path.isdir(path_project):
262
- yield (
263
- f"There is not project with name {dataset_name}",
264
- gr.update(interactive=True),
265
- gr.update(interactive=False),
266
- )
267
  return
268
 
269
- file_raw = os.path.join(path_project, "raw.arrow")
270
- if not os.path.isfile(file_raw):
271
- yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
272
- return
273
 
274
  # Check if a training process is already running
275
  if training_process is not None:
276
- return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
277
 
278
- yield "start train", gr.update(interactive=False), gr.update(interactive=False)
279
 
280
  # Command to run the training script with the specified arguments
281
- cmd = (
282
- f"accelerate launch finetune-cli.py --exp_name {exp_name} "
283
- f"--learning_rate {learning_rate} "
284
- f"--batch_size_per_gpu {batch_size_per_gpu} "
285
- f"--batch_size_type {batch_size_type} "
286
- f"--max_samples {max_samples} "
287
- f"--grad_accumulation_steps {grad_accumulation_steps} "
288
- f"--max_grad_norm {max_grad_norm} "
289
- f"--epochs {epochs} "
290
- f"--num_warmup_updates {num_warmup_updates} "
291
- f"--save_per_updates {save_per_updates} "
292
- f"--last_per_steps {last_per_steps} "
293
- f"--dataset_name {dataset_name}"
294
- )
295
- if finetune:
296
- cmd += f" --finetune {finetune}"
297
-
298
- print(cmd)
299
-
300
  try:
301
- # Start the training process
302
- training_process = subprocess.Popen(cmd, shell=True)
303
-
304
- time.sleep(5)
305
- yield "train start", gr.update(interactive=False), gr.update(interactive=True)
306
 
307
- # Wait for the training process to finish
308
- training_process.wait()
309
- time.sleep(1)
310
-
311
- if training_process is None:
312
- text_info = "train stop"
313
- else:
314
- text_info = "train complete !"
 
 
 
315
 
316
  except Exception as e: # Catch all exceptions
317
  # Ensure that we reset the training process variable in case of an error
318
- text_info = f"An error occurred: {str(e)}"
319
-
320
- training_process = None
321
-
322
- yield text_info, gr.update(interactive=True), gr.update(interactive=False)
323
 
 
324
 
325
  def stop_training():
326
  global training_process
327
- if training_process is None:
328
- return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
329
  terminate_process_tree(training_process.pid)
330
  training_process = None
331
- return "train stop", gr.update(interactive=True), gr.update(interactive=False)
332
-
333
 
334
  def create_data_project(name):
335
- name += "_pinyin"
336
- os.makedirs(os.path.join(path_data, name), exist_ok=True)
337
- os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
338
-
339
-
340
- def transcribe(file_audio, language="english"):
341
  global pipe
342
 
343
  if pipe is None:
344
- pipe = pipeline(
345
- "automatic-speech-recognition",
346
- model="openai/whisper-large-v3-turbo",
347
- torch_dtype=torch.float16,
348
- device=device,
349
- )
350
 
351
  text_transcribe = pipe(
352
  file_audio,
353
  chunk_length_s=30,
354
  batch_size=128,
355
- generate_kwargs={"task": "transcribe", "language": language},
356
  return_timestamps=False,
357
  )["text"].strip()
358
  return text_transcribe
359
 
 
 
 
 
 
 
360
 
361
- def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
362
- name_project += "_pinyin"
363
- path_project = os.path.join(path_data, name_project)
364
- path_dataset = os.path.join(path_project, "dataset")
365
- path_project_wavs = os.path.join(path_project, "wavs")
366
- file_metadata = os.path.join(path_project, "metadata.csv")
367
-
368
- if audio_files is None:
369
- return "You need to load an audio file."
370
 
371
  if os.path.isdir(path_project_wavs):
372
- shutil.rmtree(path_project_wavs)
373
 
374
  if os.path.isfile(file_metadata):
375
- os.remove(file_metadata)
376
-
377
- os.makedirs(path_project_wavs, exist_ok=True)
378
 
 
 
379
  if user:
380
- file_audios = [
381
- file
382
- for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
383
- for file in glob(os.path.join(path_dataset, format))
384
- ]
385
- if file_audios == []:
386
- return "No audio file was found in the dataset."
387
  else:
388
- file_audios = audio_files
 
389
 
390
  alpha = 0.5
391
  _max = 1.0
@@ -393,213 +376,179 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
393
 
394
  num = 0
395
  error_num = 0
396
- data = ""
397
- for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))):
398
- audio, _ = librosa.load(file_audio, sr=24000, mono=True)
399
-
400
- list_slicer = slicer.slice(audio)
401
- for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
 
 
402
  name_segment = os.path.join(f"segment_{num}")
403
- file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
404
-
405
  tmp_max = np.abs(chunk).max()
406
- if tmp_max > 1:
407
- chunk /= tmp_max
408
  chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
409
- wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
410
-
411
  try:
412
- text = transcribe(file_segment, language)
413
- text = text.lower().strip().replace('"', "")
414
 
415
- data += f"{name_segment}|{text}\n"
416
 
417
- num += 1
418
- except: # noqa: E722
419
- error_num += 1
420
 
421
- with open(file_metadata, "w", encoding="utf-8") as f:
422
  f.write(data)
423
-
424
- if error_num != []:
425
- error_text = f"\nerror files : {error_num}"
426
  else:
427
- error_text = ""
428
-
429
  return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
430
 
431
-
432
  def format_seconds_to_hms(seconds):
433
  hours = int(seconds / 3600)
434
  minutes = int((seconds % 3600) / 60)
435
  seconds = seconds % 60
436
  return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
437
 
438
-
439
- def create_metadata(name_project, progress=gr.Progress()):
440
- name_project += "_pinyin"
441
- path_project = os.path.join(path_data, name_project)
442
- path_project_wavs = os.path.join(path_project, "wavs")
443
- file_metadata = os.path.join(path_project, "metadata.csv")
444
- file_raw = os.path.join(path_project, "raw.arrow")
445
- file_duration = os.path.join(path_project, "duration.json")
446
- file_vocab = os.path.join(path_project, "vocab.txt")
447
-
448
- if not os.path.isfile(file_metadata):
449
- return "The file was not found in " + file_metadata
450
-
451
- with open(file_metadata, "r", encoding="utf-8") as f:
452
- data = f.read()
453
-
454
- audio_path_list = []
455
- text_list = []
456
- duration_list = []
457
-
458
- count = data.split("\n")
459
- lenght = 0
460
- result = []
461
- error_files = []
462
- for line in progress.tqdm(data.split("\n"), total=count):
463
- sp_line = line.split("|")
464
- if len(sp_line) != 2:
465
- continue
466
- name_audio, text = sp_line[:2]
467
 
468
  file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
469
 
470
- if not os.path.isfile(file_audio):
471
  error_files.append(file_audio)
472
  continue
473
 
474
  duraction = get_audio_duration(file_audio)
475
- if duraction < 2 and duraction > 15:
476
- continue
477
- if len(text) < 4:
478
- continue
479
 
480
  text = clear_text(text)
481
- text = convert_char_to_pinyin([text], polyphone=True)[0]
482
 
483
  audio_path_list.append(file_audio)
484
  duration_list.append(duraction)
485
  text_list.append(text)
486
-
487
  result.append({"audio_path": file_audio, "text": text, "duration": duraction})
488
 
489
- lenght += duraction
490
 
491
- if duration_list == []:
492
- error_files_text = "\n".join(error_files)
493
  return f"Error: No audio files found in the specified path : \n{error_files_text}"
494
-
495
- min_second = round(min(duration_list), 2)
496
- max_second = round(max(duration_list), 2)
497
 
498
  with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
499
- for line in progress.tqdm(result, total=len(result), desc="prepare data"):
500
  writer.write(line)
501
 
502
- with open(file_duration, "w", encoding="utf-8") as f:
503
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
504
-
505
- file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
506
- if not os.path.isfile(file_vocab_finetune):
507
- return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
508
  shutil.copy2(file_vocab_finetune, file_vocab)
509
-
510
- if error_files != []:
511
- error_text = "error files\n" + "\n".join(error_files)
512
  else:
513
- error_text = ""
514
-
515
  return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
516
 
517
-
518
  def check_user(value):
519
- return gr.update(visible=not value), gr.update(visible=value)
520
-
521
-
522
- def calculate_train(
523
- name_project,
524
- batch_size_type,
525
- max_samples,
526
- learning_rate,
527
- num_warmup_updates,
528
- save_per_updates,
529
- last_per_steps,
530
- finetune,
531
- ):
532
- name_project += "_pinyin"
533
- path_project = os.path.join(path_data, name_project)
534
- file_duraction = os.path.join(path_project, "duration.json")
535
-
536
- if not os.path.isfile(file_duraction):
537
- return (
538
- 1000,
539
- max_samples,
540
- num_warmup_updates,
541
- save_per_updates,
542
- last_per_steps,
543
- "project not found !",
544
- learning_rate,
545
- )
546
 
547
- with open(file_duraction, "r") as file:
548
- data = json.load(file)
 
 
549
 
550
- duration_list = data["duration"]
 
 
 
551
 
552
  samples = len(duration_list)
553
 
554
  if torch.cuda.is_available():
555
  gpu_properties = torch.cuda.get_device_properties(0)
556
- total_memory = gpu_properties.total_memory / (1024**3)
557
  elif torch.backends.mps.is_available():
558
- total_memory = psutil.virtual_memory().available / (1024**3)
559
-
560
- if batch_size_type == "frame":
561
- batch = int(total_memory * 0.5)
562
- batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
563
- batch_size_per_gpu = int(38400 / batch)
564
- else:
565
- batch_size_per_gpu = int(total_memory / 8)
566
- batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
567
- batch = batch_size_per_gpu
568
-
569
- if batch_size_per_gpu <= 0:
570
- batch_size_per_gpu = 1
571
-
572
- if samples < 64:
573
- max_samples = int(samples * 0.25)
574
  else:
575
- max_samples = 64
576
-
577
- num_warmup_updates = int(samples * 0.05)
578
- save_per_updates = int(samples * 0.10)
579
- last_per_steps = int(save_per_updates * 5)
580
-
 
 
 
 
 
 
 
581
  max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
582
  num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
583
  save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
584
  last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
585
 
586
- if finetune:
587
- learning_rate = 1e-5
588
- else:
589
- learning_rate = 7.5e-5
590
-
591
- return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
592
 
 
593
 
594
  def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
595
  try:
596
  checkpoint = torch.load(checkpoint_path)
597
  print("Original Checkpoint Keys:", checkpoint.keys())
598
-
599
- ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
600
 
601
  if ema_model_state_dict is not None:
602
- new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
603
  torch.save(new_checkpoint, new_checkpoint_path)
604
  return f"New checkpoint saved at: {new_checkpoint_path}"
605
  else:
@@ -608,136 +557,65 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
608
  except Exception as e:
609
  return f"An error occurred: {e}"
610
 
611
-
612
  def vocab_check(project_name):
613
  name_project = project_name + "_pinyin"
614
  path_project = os.path.join(path_data, name_project)
615
 
616
  file_metadata = os.path.join(path_project, "metadata.csv")
617
-
618
- file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
619
- if not os.path.isfile(file_vocab):
620
  return f"the file {file_vocab} not found !"
621
-
622
- with open(file_vocab, "r", encoding="utf-8") as f:
623
- data = f.read()
624
 
625
  vocab = data.split("\n")
626
 
627
- if not os.path.isfile(file_metadata):
628
  return f"the file {file_metadata} not found !"
629
 
630
- with open(file_metadata, "r", encoding="utf-8") as f:
631
- data = f.read()
632
 
633
- miss_symbols = []
634
- miss_symbols_keep = {}
635
  for item in data.split("\n"):
636
- sp = item.split("|")
637
- if len(sp) != 2:
638
- continue
639
 
640
- text = sp[1].lower().strip()
 
 
 
641
 
642
- for t in text:
643
- if t not in vocab and t not in miss_symbols_keep:
644
- miss_symbols.append(t)
645
- miss_symbols_keep[t] = t
646
- if miss_symbols == []:
647
- info = "You can train using your language !"
648
- else:
649
- info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
650
 
651
  return info
652
 
653
 
654
- def get_random_sample_prepare(project_name):
655
- name_project = project_name + "_pinyin"
656
- path_project = os.path.join(path_data, name_project)
657
- file_arrow = os.path.join(path_project, "raw.arrow")
658
- if not os.path.isfile(file_arrow):
659
- return "", None
660
- dataset = Dataset_.from_file(file_arrow)
661
- random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
662
- text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
663
- audio_path = random_sample["audio_path"][0]
664
- return text, audio_path
665
-
666
-
667
- def get_random_sample_transcribe(project_name):
668
- name_project = project_name + "_pinyin"
669
- path_project = os.path.join(path_data, name_project)
670
- file_metadata = os.path.join(path_project, "metadata.csv")
671
- if not os.path.isfile(file_metadata):
672
- return "", None
673
-
674
- data = ""
675
- with open(file_metadata, "r", encoding="utf-8") as f:
676
- data = f.read()
677
-
678
- list_data = []
679
- for item in data.split("\n"):
680
- sp = item.split("|")
681
- if len(sp) != 2:
682
- continue
683
- list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
684
 
685
- if list_data == []:
686
- return "", None
687
-
688
- random_item = random.choice(list_data)
689
-
690
- return random_item[1], random_item[0]
691
-
692
-
693
- def get_random_sample_infer(project_name):
694
- text, audio = get_random_sample_transcribe(project_name)
695
- return (
696
- text,
697
- text,
698
- audio,
699
- )
700
-
701
-
702
- def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
703
- global last_checkpoint, last_device, tts_api
704
-
705
- if not os.path.isfile(file_checkpoint):
706
- return None
707
-
708
- if training_process is not None:
709
- device_test = "cpu"
710
- else:
711
- device_test = None
712
-
713
- if last_checkpoint != file_checkpoint or last_device != device_test:
714
- if last_checkpoint != file_checkpoint:
715
- last_checkpoint = file_checkpoint
716
- if last_device != device_test:
717
- last_device = device_test
718
-
719
- tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
720
-
721
- print("update", device_test, file_checkpoint)
722
 
723
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
724
- tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
725
- return f.name
 
 
726
 
 
 
727
 
728
- with gr.Blocks() as app:
729
- with gr.Row():
730
- project_name = gr.Textbox(label="project name", value="my_speak")
731
- bt_create = gr.Button("create new project")
732
 
733
- bt_create.click(fn=create_data_project, inputs=[project_name])
734
 
735
- with gr.Tabs():
736
- with gr.TabItem("transcribe Data"):
737
- ch_manual = gr.Checkbox(label="user", value=False)
738
 
739
- mark_info_transcribe = gr.Markdown(
740
- """```plaintext
741
  Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
742
 
743
  my_speak/
@@ -746,36 +624,18 @@ with gr.Blocks() as app:
746
  ├── audio1.wav
747
  └── audio2.wav
748
  ...
749
- ```""",
750
- visible=False,
751
- )
752
-
753
- audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple")
754
- txt_lang = gr.Text(label="Language", value="english")
755
- bt_transcribe = bt_create = gr.Button("transcribe")
756
- txt_info_transcribe = gr.Text(label="info", value="")
757
- bt_transcribe.click(
758
- fn=transcribe_all,
759
- inputs=[project_name, audio_speaker, txt_lang, ch_manual],
760
- outputs=[txt_info_transcribe],
761
- )
762
- ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
763
-
764
- random_sample_transcribe = gr.Button("random sample")
765
-
766
- with gr.Row():
767
- random_text_transcribe = gr.Text(label="Text")
768
- random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
769
-
770
- random_sample_transcribe.click(
771
- fn=get_random_sample_transcribe,
772
- inputs=[project_name],
773
- outputs=[random_text_transcribe, random_audio_transcribe],
774
- )
775
-
776
- with gr.TabItem("prepare Data"):
777
- gr.Markdown(
778
- """```plaintext
779
  place all your wavs folder and your metadata.csv file in {your name project}
780
  my_speak/
781
 
@@ -792,136 +652,61 @@ with gr.Blocks() as app:
792
  audio2|text1
793
  ...
794
 
795
- ```"""
796
- )
797
-
798
- bt_prepare = bt_create = gr.Button("prepare")
799
- txt_info_prepare = gr.Text(label="info", value="")
800
- bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
801
-
802
- random_sample_prepare = gr.Button("random sample")
803
-
804
- with gr.Row():
805
- random_text_prepare = gr.Text(label="Pinyin")
806
- random_audio_prepare = gr.Audio(label="Audio", type="filepath")
807
-
808
- random_sample_prepare.click(
809
- fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
810
- )
811
-
812
- with gr.TabItem("train Data"):
813
- with gr.Row():
814
- bt_calculate = bt_create = gr.Button("Auto Settings")
815
- ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
816
- lb_samples = gr.Label(label="samples")
817
- batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
818
-
819
- with gr.Row():
820
- exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
821
- learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
822
-
823
- with gr.Row():
824
- batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
825
- max_samples = gr.Number(label="Max Samples", value=64)
826
-
827
- with gr.Row():
828
- grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
829
- max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
830
-
831
- with gr.Row():
832
- epochs = gr.Number(label="Epochs", value=10)
833
- num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
834
-
835
- with gr.Row():
836
- save_per_updates = gr.Number(label="Save per Updates", value=10)
837
- last_per_steps = gr.Number(label="Last per Steps", value=50)
838
-
839
- with gr.Row():
840
- start_button = gr.Button("Start Training")
841
- stop_button = gr.Button("Stop Training", interactive=False)
842
-
843
- txt_info_train = gr.Text(label="info", value="")
844
- start_button.click(
845
- fn=start_training,
846
- inputs=[
847
- project_name,
848
- exp_name,
849
- learning_rate,
850
- batch_size_per_gpu,
851
- batch_size_type,
852
- max_samples,
853
- grad_accumulation_steps,
854
- max_grad_norm,
855
- epochs,
856
- num_warmup_updates,
857
- save_per_updates,
858
- last_per_steps,
859
- ch_finetune,
860
- ],
861
- outputs=[txt_info_train, start_button, stop_button],
862
- )
863
- stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
864
- bt_calculate.click(
865
- fn=calculate_train,
866
- inputs=[
867
- project_name,
868
- batch_size_type,
869
- max_samples,
870
- learning_rate,
871
- num_warmup_updates,
872
- save_per_updates,
873
- last_per_steps,
874
- ch_finetune,
875
- ],
876
- outputs=[
877
- batch_size_per_gpu,
878
- max_samples,
879
- num_warmup_updates,
880
- save_per_updates,
881
- last_per_steps,
882
- lb_samples,
883
- learning_rate,
884
- ],
885
- )
886
-
887
- with gr.TabItem("reduse checkpoint"):
888
- txt_path_checkpoint = gr.Text(label="path checkpoint :")
889
- txt_path_checkpoint_small = gr.Text(label="path output :")
890
- txt_info_reduse = gr.Text(label="info", value="")
891
- reduse_button = gr.Button("reduse")
892
- reduse_button.click(
893
- fn=extract_and_save_ema_model,
894
- inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
895
- outputs=[txt_info_reduse],
896
- )
897
-
898
- with gr.TabItem("vocab check experiment"):
899
- check_button = gr.Button("check vocab")
900
- txt_info_check = gr.Text(label="info", value="")
901
- check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
902
-
903
- with gr.TabItem("test model"):
904
- exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
905
- nfe_step = gr.Number(label="n_step", value=32)
906
- file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
907
-
908
- random_sample_infer = gr.Button("random sample")
909
-
910
- ref_text = gr.Textbox(label="ref text")
911
- ref_audio = gr.Audio(label="audio ref", type="filepath")
912
- gen_text = gr.Textbox(label="gen text")
913
- random_sample_infer.click(
914
- fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
915
- )
916
- check_button_infer = gr.Button("infer")
917
- gen_audio = gr.Audio(label="audio gen", type="filepath")
918
-
919
- check_button_infer.click(
920
- fn=infer,
921
- inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
922
- outputs=[gen_audio],
923
- )
924
-
925
 
926
  @click.command()
927
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@@ -936,9 +721,10 @@ with gr.Blocks() as app:
936
  @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
937
  def main(port, host, share, api):
938
  global app
939
- print("Starting app...")
940
- app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
941
-
 
942
 
943
  if __name__ == "__main__":
944
  main()
 
1
+ import os,sys
 
2
 
 
 
3
  from transformers import pipeline
4
  import gradio as gr
5
  import torch
 
6
  import click
7
  import torchaudio
8
  from glob import glob
 
19
  import platform
20
  import subprocess
21
  from datasets.arrow_writer import ArrowWriter
 
 
22
 
23
+ import json
24
 
25
+ training_process = None
26
  system = platform.system()
27
  python_executable = sys.executable or "python"
 
 
 
28
 
29
+ path_data="data"
30
 
31
+ device = (
32
+ "cuda"
33
+ if torch.cuda.is_available()
34
+ else "mps" if torch.backends.mps.is_available() else "cpu"
35
+ )
36
 
37
  pipe = None
38
 
 
39
  # Load metadata
40
  def get_audio_duration(audio_path):
41
  """Calculate the duration of an audio file."""
42
  audio, sample_rate = torchaudio.load(audio_path)
43
+ num_channels = audio.shape[0]
44
  return audio.shape[1] / (sample_rate * num_channels)
45
 
 
46
  def clear_text(text):
47
  """Clean and prepare text by lowering the case and stripping whitespace."""
48
  return text.lower().strip()
49
 
50
+ def get_rms(y,frame_length=2048,hop_length=512,pad_mode="constant",): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
 
 
 
 
 
 
51
  padding = (int(frame_length // 2), int(frame_length // 2))
52
  y = np.pad(y, padding, mode=pad_mode)
53
 
 
74
 
75
  return np.sqrt(power)
76
 
77
+ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
 
78
  def __init__(
79
  self,
80
  sr: int,
 
85
  max_sil_kept: int = 2000,
86
  ):
87
  if not min_length >= min_interval >= hop_size:
88
+ raise ValueError(
89
+ "The following condition must be satisfied: min_length >= min_interval >= hop_size"
90
+ )
91
  if not max_sil_kept >= hop_size:
92
+ raise ValueError(
93
+ "The following condition must be satisfied: max_sil_kept >= hop_size"
94
+ )
95
  min_interval = sr * min_interval / 1000
96
  self.threshold = 10 ** (threshold / 20.0)
97
  self.hop_size = round(sr * hop_size / 1000)
 
102
 
103
  def _apply_slice(self, waveform, begin, end):
104
  if len(waveform.shape) > 1:
105
+ return waveform[
106
+ :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
107
+ ]
108
  else:
109
+ return waveform[
110
+ begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
111
+ ]
112
 
113
  # @timeit
114
  def slice(self, waveform):
 
118
  samples = waveform
119
  if samples.shape[0] <= self.min_length:
120
  return [waveform]
121
+ rms_list = get_rms(
122
+ y=samples, frame_length=self.win_size, hop_length=self.hop_size
123
+ ).squeeze(0)
124
  sil_tags = []
125
  silence_start = None
126
  clip_start = 0
 
136
  continue
137
  # Clear recorded silence start if interval is not enough or clip is too short
138
  is_leading_silence = silence_start == 0 and i > self.max_sil_kept
139
+ need_slice_middle = (
140
+ i - silence_start >= self.min_interval
141
+ and i - clip_start >= self.min_length
142
+ )
143
  if not is_leading_silence and not need_slice_middle:
144
  silence_start = None
145
  continue
 
152
  sil_tags.append((pos, pos))
153
  clip_start = pos
154
  elif i - silence_start <= self.max_sil_kept * 2:
155
+ pos = rms_list[
156
+ i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
157
+ ].argmin()
158
  pos += i - self.max_sil_kept
159
+ pos_l = (
160
+ rms_list[
161
+ silence_start : silence_start + self.max_sil_kept + 1
162
+ ].argmin()
163
+ + silence_start
164
+ )
165
+ pos_r = (
166
+ rms_list[i - self.max_sil_kept : i + 1].argmin()
167
+ + i
168
+ - self.max_sil_kept
169
+ )
170
  if silence_start == 0:
171
  sil_tags.append((0, pos_r))
172
  clip_start = pos_r
 
174
  sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
175
  clip_start = max(pos_r, pos)
176
  else:
177
+ pos_l = (
178
+ rms_list[
179
+ silence_start : silence_start + self.max_sil_kept + 1
180
+ ].argmin()
181
+ + silence_start
182
+ )
183
+ pos_r = (
184
+ rms_list[i - self.max_sil_kept : i + 1].argmin()
185
+ + i
186
+ - self.max_sil_kept
187
+ )
188
  if silence_start == 0:
189
  sil_tags.append((0, pos_r))
190
  else:
 
193
  silence_start = None
194
  # Deal with trailing silence.
195
  total_frames = rms_list.shape[0]
196
+ if (
197
+ silence_start is not None
198
+ and total_frames - silence_start >= self.min_interval
199
+ ):
200
  silence_end = min(total_frames, silence_start + self.max_sil_kept)
201
  pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
202
  sil_tags.append((pos, total_frames + 1))
203
  # Apply and return slices.
204
  ####音频+起始时间+终止时间
205
  if len(sil_tags) == 0:
206
+ return [[waveform,0,int(total_frames*self.hop_size)]]
207
  else:
208
  chunks = []
209
  if sil_tags[0][0] > 0:
210
+ chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]),0,int(sil_tags[0][0]*self.hop_size)])
211
  for i in range(len(sil_tags) - 1):
212
  chunks.append(
213
+ [self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),int(sil_tags[i][1]*self.hop_size),int(sil_tags[i + 1][0]*self.hop_size)]
 
 
 
 
214
  )
215
  if sil_tags[-1][1] < total_frames:
216
  chunks.append(
217
+ [self._apply_slice(waveform, sil_tags[-1][1], total_frames),int(sil_tags[-1][1]*self.hop_size),int(total_frames*self.hop_size)]
 
 
 
 
218
  )
219
  return chunks
220
 
221
+ #terminal
222
+ def terminate_process_tree(pid, including_parent=True):
 
223
  try:
224
  parent = psutil.Process(pid)
225
  except psutil.NoSuchProcess:
 
238
  except OSError:
239
  pass
240
 
 
241
  def terminate_process(pid):
242
  if system == "Windows":
243
  cmd = f"taskkill /t /f /pid {pid}"
 
245
  else:
246
  terminate_process_tree(pid)
247
 
248
+ def start_training(dataset_name="",
249
+ exp_name="F5TTS_Base",
250
+ learning_rate=1e-4,
251
+ batch_size_per_gpu=400,
252
+ batch_size_type="frame",
253
+ max_samples=64,
254
+ grad_accumulation_steps=1,
255
+ max_grad_norm=1.0,
256
+ epochs=11,
257
+ num_warmup_updates=200,
258
+ save_per_updates=400,
259
+ last_per_steps=800,
260
+ finetune=True,
261
+ ):
262
 
263
+
264
+ global training_process
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  path_project = os.path.join(path_data, dataset_name + "_pinyin")
267
 
268
+ if os.path.isdir(path_project)==False:
269
+ yield f"There is not project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False)
 
 
 
 
270
  return
271
 
272
+ file_raw = os.path.join(path_project,"raw.arrow")
273
+ if os.path.isfile(file_raw)==False:
274
+ yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False)
275
+ return
276
 
277
  # Check if a training process is already running
278
  if training_process is not None:
279
+ return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
280
 
281
+ yield "start train",gr.update(interactive=False),gr.update(interactive=False)
282
 
283
  # Command to run the training script with the specified arguments
284
+ cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \
285
+ f"--learning_rate {learning_rate} " \
286
+ f"--batch_size_per_gpu {batch_size_per_gpu} " \
287
+ f"--batch_size_type {batch_size_type} " \
288
+ f"--max_samples {max_samples} " \
289
+ f"--grad_accumulation_steps {grad_accumulation_steps} " \
290
+ f"--max_grad_norm {max_grad_norm} " \
291
+ f"--epochs {epochs} " \
292
+ f"--num_warmup_updates {num_warmup_updates} " \
293
+ f"--save_per_updates {save_per_updates} " \
294
+ f"--last_per_steps {last_per_steps} " \
295
+ f"--dataset_name {dataset_name}"
296
+ if finetune:cmd += f" --finetune {finetune}"
297
+ print(cmd)
 
 
 
 
 
298
  try:
299
+ # Start the training process
300
+ training_process = subprocess.Popen(cmd, shell=True)
 
 
 
301
 
302
+ time.sleep(5)
303
+ yield "check terminal for wandb",gr.update(interactive=False),gr.update(interactive=True)
304
+
305
+ # Wait for the training process to finish
306
+ training_process.wait()
307
+ time.sleep(1)
308
+
309
+ if training_process is None:
310
+ text_info = 'train stop'
311
+ else:
312
+ text_info = "train complete !"
313
 
314
  except Exception as e: # Catch all exceptions
315
  # Ensure that we reset the training process variable in case of an error
316
+ text_info=f"An error occurred: {str(e)}"
317
+
318
+ training_process=None
 
 
319
 
320
+ yield text_info,gr.update(interactive=True),gr.update(interactive=False)
321
 
322
  def stop_training():
323
  global training_process
324
+ if training_process is None:return f"Train not run !",gr.update(interactive=True),gr.update(interactive=False)
 
325
  terminate_process_tree(training_process.pid)
326
  training_process = None
327
+ return 'train stop',gr.update(interactive=True),gr.update(interactive=False)
 
328
 
329
  def create_data_project(name):
330
+ name+="_pinyin"
331
+ os.makedirs(os.path.join(path_data,name),exist_ok=True)
332
+ os.makedirs(os.path.join(path_data,name,"dataset"),exist_ok=True)
333
+
334
+ def transcribe(file_audio,language="english"):
 
335
  global pipe
336
 
337
  if pipe is None:
338
+ pipe = pipeline("automatic-speech-recognition",model="openai/whisper-large-v3-turbo", torch_dtype=torch.float16,device=device)
 
 
 
 
 
339
 
340
  text_transcribe = pipe(
341
  file_audio,
342
  chunk_length_s=30,
343
  batch_size=128,
344
+ generate_kwargs={"task": "transcribe","language": language},
345
  return_timestamps=False,
346
  )["text"].strip()
347
  return text_transcribe
348
 
349
+ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
350
+ name_project+="_pinyin"
351
+ path_project= os.path.join(path_data,name_project)
352
+ path_dataset = os.path.join(path_project,"dataset")
353
+ path_project_wavs = os.path.join(path_project,"wavs")
354
+ file_metadata = os.path.join(path_project,"metadata.csv")
355
 
356
+ if audio_files is None:return "You need to load an audio file."
 
 
 
 
 
 
 
 
357
 
358
  if os.path.isdir(path_project_wavs):
359
+ shutil.rmtree(path_project_wavs)
360
 
361
  if os.path.isfile(file_metadata):
362
+ os.remove(file_metadata)
 
 
363
 
364
+ os.makedirs(path_project_wavs,exist_ok=True)
365
+
366
  if user:
367
+ file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
368
+ if file_audios==[]:return "No audio file was found in the dataset."
 
 
 
 
 
369
  else:
370
+ file_audios = audio_files
371
+
372
 
373
  alpha = 0.5
374
  _max = 1.0
 
376
 
377
  num = 0
378
  error_num = 0
379
+ data=""
380
+ for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
381
+
382
+ audio, _ = librosa.load(file_audio, sr=24000, mono=True)
383
+
384
+ list_slicer=slicer.slice(audio)
385
+ for chunk, start, end in progress.tqdm(list_slicer,total=len(list_slicer), desc="slicer files"):
386
+
387
  name_segment = os.path.join(f"segment_{num}")
388
+ file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
389
+
390
  tmp_max = np.abs(chunk).max()
391
+ if(tmp_max>1):chunk/=tmp_max
 
392
  chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
393
+ wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
394
+
395
  try:
396
+ text=transcribe(file_segment,language)
397
+ text = text.lower().strip().replace('"',"")
398
 
399
+ data+= f"{name_segment}|{text}\n"
400
 
401
+ num+=1
402
+ except:
403
+ error_num +=1
404
 
405
+ with open(file_metadata,"w",encoding="utf-8") as f:
406
  f.write(data)
407
+
408
+ if error_num!=[]:
409
+ error_text=f"\nerror files : {error_num}"
410
  else:
411
+ error_text=""
412
+
413
  return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
414
 
 
415
  def format_seconds_to_hms(seconds):
416
  hours = int(seconds / 3600)
417
  minutes = int((seconds % 3600) / 60)
418
  seconds = seconds % 60
419
  return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
420
 
421
+ def create_metadata(name_project,progress=gr.Progress()):
422
+ name_project+="_pinyin"
423
+ path_project= os.path.join(path_data,name_project)
424
+ path_project_wavs = os.path.join(path_project,"wavs")
425
+ file_metadata = os.path.join(path_project,"metadata.csv")
426
+ file_raw = os.path.join(path_project,"raw.arrow")
427
+ file_duration = os.path.join(path_project,"duration.json")
428
+ file_vocab = os.path.join(path_project,"vocab.txt")
429
+
430
+ if os.path.isfile(file_metadata)==False: return "The file was not found in " + file_metadata
431
+
432
+ with open(file_metadata,"r",encoding="utf-8") as f:
433
+ data=f.read()
434
+
435
+ audio_path_list=[]
436
+ text_list=[]
437
+ duration_list=[]
438
+
439
+ count=data.split("\n")
440
+ lenght=0
441
+ result=[]
442
+ error_files=[]
443
+ for line in progress.tqdm(data.split("\n"),total=count):
444
+ sp_line=line.split("|")
445
+ if len(sp_line)!=2:continue
446
+ name_audio,text = sp_line[:2]
 
 
 
447
 
448
  file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
449
 
450
+ if os.path.isfile(file_audio)==False:
451
  error_files.append(file_audio)
452
  continue
453
 
454
  duraction = get_audio_duration(file_audio)
455
+ if duraction<2 and duraction>15:continue
456
+ if len(text)<4:continue
 
 
457
 
458
  text = clear_text(text)
459
+ text = convert_char_to_pinyin([text], polyphone = True)[0]
460
 
461
  audio_path_list.append(file_audio)
462
  duration_list.append(duraction)
463
  text_list.append(text)
464
+
465
  result.append({"audio_path": file_audio, "text": text, "duration": duraction})
466
 
467
+ lenght+=duraction
468
 
469
+ if duration_list==[]:
470
+ error_files_text="\n".join(error_files)
471
  return f"Error: No audio files found in the specified path : \n{error_files_text}"
472
+
473
+ min_second = round(min(duration_list),2)
474
+ max_second = round(max(duration_list),2)
475
 
476
  with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
477
+ for line in progress.tqdm(result,total=len(result), desc=f"prepare data"):
478
  writer.write(line)
479
 
480
+ with open(file_duration, 'w', encoding='utf-8') as f:
481
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
482
+
483
+ file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
484
+ if os.path.isfile(file_vocab_finetune==False):return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
 
485
  shutil.copy2(file_vocab_finetune, file_vocab)
486
+
487
+ if error_files!=[]:
488
+ error_text="error files\n" + "\n".join(error_files)
489
  else:
490
+ error_text=""
491
+
492
  return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
493
 
 
494
  def check_user(value):
495
+ return gr.update(visible=not value),gr.update(visible=value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
 
497
+ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,finetune):
498
+ name_project+="_pinyin"
499
+ path_project= os.path.join(path_data,name_project)
500
+ file_duraction = os.path.join(path_project,"duration.json")
501
 
502
+ with open(file_duraction, 'r') as file:
503
+ data = json.load(file)
504
+
505
+ duration_list = data['duration']
506
 
507
  samples = len(duration_list)
508
 
509
  if torch.cuda.is_available():
510
  gpu_properties = torch.cuda.get_device_properties(0)
511
+ total_memory = gpu_properties.total_memory / (1024 ** 3)
512
  elif torch.backends.mps.is_available():
513
+ total_memory = psutil.virtual_memory().available / (1024 ** 3)
514
+
515
+ if batch_size_type=="frame":
516
+ batch = int(total_memory * 0.5)
517
+ batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
518
+ batch_size_per_gpu = int(38400 / batch )
 
 
 
 
 
 
 
 
 
 
519
  else:
520
+ batch_size_per_gpu = int(total_memory / 8)
521
+ batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
522
+ batch = batch_size_per_gpu
523
+
524
+ if batch_size_per_gpu<=0:batch_size_per_gpu=1
525
+
526
+ if samples<64:
527
+ max_samples = int(samples * 0.25)
528
+
529
+ num_warmup_updates = int(samples * 0.10)
530
+ save_per_updates = int(samples * 0.25)
531
+ last_per_steps =int(save_per_updates * 5)
532
+
533
  max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
534
  num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
535
  save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
536
  last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
537
 
538
+ if finetune:learning_rate=1e-4
539
+ else:learning_rate=7.5e-5
 
 
 
 
540
 
541
+ return batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,samples,learning_rate
542
 
543
  def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
544
  try:
545
  checkpoint = torch.load(checkpoint_path)
546
  print("Original Checkpoint Keys:", checkpoint.keys())
547
+
548
+ ema_model_state_dict = checkpoint.get('ema_model_state_dict', None)
549
 
550
  if ema_model_state_dict is not None:
551
+ new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
552
  torch.save(new_checkpoint, new_checkpoint_path)
553
  return f"New checkpoint saved at: {new_checkpoint_path}"
554
  else:
 
557
  except Exception as e:
558
  return f"An error occurred: {e}"
559
 
 
560
  def vocab_check(project_name):
561
  name_project = project_name + "_pinyin"
562
  path_project = os.path.join(path_data, name_project)
563
 
564
  file_metadata = os.path.join(path_project, "metadata.csv")
565
+
566
+ file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
567
+ if os.path.isfile(file_vocab)==False:
568
  return f"the file {file_vocab} not found !"
569
+
570
+ with open(file_vocab,"r",encoding="utf-8") as f:
571
+ data=f.read()
572
 
573
  vocab = data.split("\n")
574
 
575
+ if os.path.isfile(file_metadata)==False:
576
  return f"the file {file_metadata} not found !"
577
 
578
+ with open(file_metadata,"r",encoding="utf-8") as f:
579
+ data=f.read()
580
 
581
+ miss_symbols=[]
582
+ miss_symbols_keep={}
583
  for item in data.split("\n"):
584
+ sp=item.split("|")
585
+ if len(sp)!=2:continue
586
+ text=sp[1].lower().strip()
587
 
588
+ for t in text:
589
+ if (t in vocab)==False and (t in miss_symbols_keep)==False:
590
+ miss_symbols.append(t)
591
+ miss_symbols_keep[t]=t
592
 
593
+
594
+ if miss_symbols==[]:info ="You can train using your language !"
595
+ else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
 
 
 
 
 
596
 
597
  return info
598
 
599
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
 
601
+ with gr.Blocks() as app:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
 
603
+ with gr.Row():
604
+ project_name=gr.Textbox(label="project name",value="my_speak")
605
+ bt_create=gr.Button("create new project")
606
+
607
+ bt_create.click(fn=create_data_project,inputs=[project_name])
608
 
609
+ with gr.Tabs():
610
+
611
 
612
+ with gr.TabItem("transcribe Data"):
 
 
 
613
 
 
614
 
615
+ ch_manual = gr.Checkbox(label="user",value=False)
 
 
616
 
617
+ mark_info_transcribe=gr.Markdown(
618
+ """```plaintext
619
  Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
620
 
621
  my_speak/
 
624
  ├── audio1.wav
625
  └── audio2.wav
626
  ...
627
+ ```""",visible=False)
628
+
629
+ audio_speaker = gr.File(label="voice",type="filepath",file_count="multiple")
630
+ txt_lang = gr.Text(label="Language",value="english")
631
+ bt_transcribe=bt_create=gr.Button("transcribe")
632
+ txt_info_transcribe=gr.Text(label="info",value="")
633
+ bt_transcribe.click(fn=transcribe_all,inputs=[project_name,audio_speaker,txt_lang,ch_manual],outputs=[txt_info_transcribe])
634
+ ch_manual.change(fn=check_user,inputs=[ch_manual],outputs=[audio_speaker,mark_info_transcribe])
635
+
636
+ with gr.TabItem("prepare Data"):
637
+ gr.Markdown(
638
+ """```plaintext
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
  place all your wavs folder and your metadata.csv file in {your name project}
640
  my_speak/
641
 
 
652
  audio2|text1
653
  ...
654
 
655
+ ```""")
656
+
657
+ bt_prepare=bt_create=gr.Button("prepare")
658
+ txt_info_prepare=gr.Text(label="info",value="")
659
+ bt_prepare.click(fn=create_metadata,inputs=[project_name],outputs=[txt_info_prepare])
660
+
661
+ with gr.TabItem("train Data"):
662
+
663
+ with gr.Row():
664
+ bt_calculate=bt_create=gr.Button("Auto Settings")
665
+ ch_finetune=bt_create=gr.Checkbox(label="finetune",value=True)
666
+ lb_samples = gr.Label(label="samples")
667
+ batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
668
+
669
+ with gr.Row():
670
+ exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
671
+ learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
672
+
673
+ with gr.Row():
674
+ batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
675
+ max_samples = gr.Number(label="Max Samples", value=16)
676
+
677
+ with gr.Row():
678
+ grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
679
+ max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
680
+
681
+ with gr.Row():
682
+ epochs = gr.Number(label="Epochs", value=10)
683
+ num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
684
+
685
+ with gr.Row():
686
+ save_per_updates = gr.Number(label="Save per Updates", value=10)
687
+ last_per_steps = gr.Number(label="Last per Steps", value=50)
688
+
689
+ with gr.Row():
690
+ start_button = gr.Button("Start Training")
691
+ stop_button = gr.Button("Stop Training",interactive=False)
692
+
693
+ txt_info_train=gr.Text(label="info",value="")
694
+ start_button.click(fn=start_training,inputs=[project_name,exp_name,learning_rate,batch_size_per_gpu,batch_size_type,max_samples,grad_accumulation_steps,max_grad_norm,epochs,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[txt_info_train,start_button,stop_button])
695
+ stop_button.click(fn=stop_training,outputs=[txt_info_train,start_button,stop_button])
696
+ bt_calculate.click(fn=calculate_train,inputs=[project_name,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,lb_samples,learning_rate])
697
+
698
+ with gr.TabItem("reduse checkpoint"):
699
+ txt_path_checkpoint = gr.Text(label="path checkpoint :")
700
+ txt_path_checkpoint_small = gr.Text(label="path output :")
701
+ txt_info_reduse = gr.Text(label="info",value="")
702
+ reduse_button = gr.Button("reduse")
703
+ reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse])
704
+
705
+ with gr.TabItem("vocab check experiment"):
706
+ check_button = gr.Button("check vocab")
707
+ txt_info_check=gr.Text(label="info",value="")
708
+ check_button.click(fn=vocab_check,inputs=[project_name],outputs=[txt_info_check])
709
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
710
 
711
  @click.command()
712
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
 
721
  @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
722
  def main(port, host, share, api):
723
  global app
724
+ print(f"Starting app...")
725
+ app.queue(api_open=api).launch(
726
+ server_name=host, server_port=port, share=share, show_api=api
727
+ )
728
 
729
  if __name__ == "__main__":
730
  main()
gradio_app.py DELETED
@@ -1,824 +0,0 @@
1
- import os
2
- import re
3
- import torch
4
- import torchaudio
5
- import gradio as gr
6
- import numpy as np
7
- import tempfile
8
- from einops import rearrange
9
- from vocos import Vocos
10
- from pydub import AudioSegment, silence
11
- from model import CFM, UNetT, DiT, MMDiT
12
- from cached_path import cached_path
13
- from model.utils import (
14
- load_checkpoint,
15
- get_tokenizer,
16
- convert_char_to_pinyin,
17
- save_spectrogram,
18
- )
19
- from transformers import pipeline
20
- import librosa
21
- import click
22
- import soundfile as sf
23
-
24
- try:
25
- import spaces
26
- USING_SPACES = True
27
- except ImportError:
28
- USING_SPACES = False
29
-
30
- def gpu_decorator(func):
31
- if USING_SPACES:
32
- return spaces.GPU(func)
33
- else:
34
- return func
35
-
36
-
37
-
38
- SPLIT_WORDS = [
39
- "but", "however", "nevertheless", "yet", "still",
40
- "therefore", "thus", "hence", "consequently",
41
- "moreover", "furthermore", "additionally",
42
- "meanwhile", "alternatively", "otherwise",
43
- "namely", "specifically", "for example", "such as",
44
- "in fact", "indeed", "notably",
45
- "in contrast", "on the other hand", "conversely",
46
- "in conclusion", "to summarize", "finally"
47
- ]
48
-
49
- device = (
50
- "cuda"
51
- if torch.cuda.is_available()
52
- else "mps" if torch.backends.mps.is_available() else "cpu"
53
- )
54
-
55
- print(f"Using {device} device")
56
-
57
- pipe = pipeline(
58
- "automatic-speech-recognition",
59
- model="openai/whisper-large-v3-turbo",
60
- torch_dtype=torch.float16,
61
- device=device,
62
- )
63
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
64
-
65
- # --------------------- Settings -------------------- #
66
-
67
- target_sample_rate = 24000
68
- n_mel_channels = 100
69
- hop_length = 256
70
- target_rms = 0.1
71
- nfe_step = 32 # 16, 32
72
- cfg_strength = 2.0
73
- ode_method = "euler"
74
- sway_sampling_coef = -1.0
75
- speed = 1.0
76
- # fix_duration = 27 # None or float (duration in seconds)
77
- fix_duration = None
78
-
79
-
80
- def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
81
- ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
82
- # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
83
- vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
84
- model = CFM(
85
- transformer=model_cls(
86
- **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
87
- ),
88
- mel_spec_kwargs=dict(
89
- target_sample_rate=target_sample_rate,
90
- n_mel_channels=n_mel_channels,
91
- hop_length=hop_length,
92
- ),
93
- odeint_kwargs=dict(
94
- method=ode_method,
95
- ),
96
- vocab_char_map=vocab_char_map,
97
- ).to(device)
98
-
99
- model = load_checkpoint(model, ckpt_path, device, use_ema = True)
100
-
101
- return model
102
-
103
-
104
- # load models
105
- F5TTS_model_cfg = dict(
106
- dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
107
- )
108
- E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
109
-
110
- F5TTS_ema_model = load_model(
111
- "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
112
- )
113
- E2TTS_ema_model = load_model(
114
- "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
115
- )
116
-
117
- def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
118
- if len(text.encode('utf-8')) <= max_chars:
119
- return [text]
120
- if text[-1] not in ['。', '.', '!', '!', '?', '?']:
121
- text += '.'
122
-
123
- sentences = re.split('([。.!?!?])', text)
124
- sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
125
-
126
- batches = []
127
- current_batch = ""
128
-
129
- def split_by_words(text):
130
- words = text.split()
131
- current_word_part = ""
132
- word_batches = []
133
- for word in words:
134
- if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
135
- current_word_part += word + ' '
136
- else:
137
- if current_word_part:
138
- # Try to find a suitable split word
139
- for split_word in split_words:
140
- split_index = current_word_part.rfind(' ' + split_word + ' ')
141
- if split_index != -1:
142
- word_batches.append(current_word_part[:split_index].strip())
143
- current_word_part = current_word_part[split_index:].strip() + ' '
144
- break
145
- else:
146
- # If no suitable split word found, just append the current part
147
- word_batches.append(current_word_part.strip())
148
- current_word_part = ""
149
- current_word_part += word + ' '
150
- if current_word_part:
151
- word_batches.append(current_word_part.strip())
152
- return word_batches
153
-
154
- for sentence in sentences:
155
- if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
156
- current_batch += sentence
157
- else:
158
- # If adding this sentence would exceed the limit
159
- if current_batch:
160
- batches.append(current_batch)
161
- current_batch = ""
162
-
163
- # If the sentence itself is longer than max_chars, split it
164
- if len(sentence.encode('utf-8')) > max_chars:
165
- # First, try to split by colon
166
- colon_parts = sentence.split(':')
167
- if len(colon_parts) > 1:
168
- for part in colon_parts:
169
- if len(part.encode('utf-8')) <= max_chars:
170
- batches.append(part)
171
- else:
172
- # If colon part is still too long, split by comma
173
- comma_parts = re.split('[,,]', part)
174
- if len(comma_parts) > 1:
175
- current_comma_part = ""
176
- for comma_part in comma_parts:
177
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
178
- current_comma_part += comma_part + ','
179
- else:
180
- if current_comma_part:
181
- batches.append(current_comma_part.rstrip(','))
182
- current_comma_part = comma_part + ','
183
- if current_comma_part:
184
- batches.append(current_comma_part.rstrip(','))
185
- else:
186
- # If no comma, split by words
187
- batches.extend(split_by_words(part))
188
- else:
189
- # If no colon, split by comma
190
- comma_parts = re.split('[,,]', sentence)
191
- if len(comma_parts) > 1:
192
- current_comma_part = ""
193
- for comma_part in comma_parts:
194
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
195
- current_comma_part += comma_part + ','
196
- else:
197
- if current_comma_part:
198
- batches.append(current_comma_part.rstrip(','))
199
- current_comma_part = comma_part + ','
200
- if current_comma_part:
201
- batches.append(current_comma_part.rstrip(','))
202
- else:
203
- # If no comma, split by words
204
- batches.extend(split_by_words(sentence))
205
- else:
206
- current_batch = sentence
207
-
208
- if current_batch:
209
- batches.append(current_batch)
210
-
211
- return batches
212
-
213
- def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
214
- if exp_name == "F5-TTS":
215
- ema_model = F5TTS_ema_model
216
- elif exp_name == "E2-TTS":
217
- ema_model = E2TTS_ema_model
218
-
219
- audio, sr = ref_audio
220
- if audio.shape[0] > 1:
221
- audio = torch.mean(audio, dim=0, keepdim=True)
222
-
223
- rms = torch.sqrt(torch.mean(torch.square(audio)))
224
- if rms < target_rms:
225
- audio = audio * target_rms / rms
226
- if sr != target_sample_rate:
227
- resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
228
- audio = resampler(audio)
229
- audio = audio.to(device)
230
-
231
- generated_waves = []
232
- spectrograms = []
233
-
234
- for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
235
- # Prepare the text
236
- if len(ref_text[-1].encode('utf-8')) == 1:
237
- ref_text = ref_text + " "
238
- text_list = [ref_text + gen_text]
239
- final_text_list = convert_char_to_pinyin(text_list)
240
-
241
- # Calculate duration
242
- ref_audio_len = audio.shape[-1] // hop_length
243
- zh_pause_punc = r"。,、;:?!"
244
- ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
245
- gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
246
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
247
-
248
- # inference
249
- with torch.inference_mode():
250
- generated, _ = ema_model.sample(
251
- cond=audio,
252
- text=final_text_list,
253
- duration=duration,
254
- steps=nfe_step,
255
- cfg_strength=cfg_strength,
256
- sway_sampling_coef=sway_sampling_coef,
257
- )
258
-
259
- generated = generated[:, ref_audio_len:, :]
260
- generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
261
- generated_wave = vocos.decode(generated_mel_spec.cpu())
262
- if rms < target_rms:
263
- generated_wave = generated_wave * rms / target_rms
264
-
265
- # wav -> numpy
266
- generated_wave = generated_wave.squeeze().cpu().numpy()
267
-
268
- generated_waves.append(generated_wave)
269
- spectrograms.append(generated_mel_spec[0].cpu().numpy())
270
-
271
- # Combine all generated waves
272
- final_wave = np.concatenate(generated_waves)
273
-
274
- # Remove silence
275
- if remove_silence:
276
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
277
- sf.write(f.name, final_wave, target_sample_rate)
278
- aseg = AudioSegment.from_file(f.name)
279
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
280
- non_silent_wave = AudioSegment.silent(duration=0)
281
- for non_silent_seg in non_silent_segs:
282
- non_silent_wave += non_silent_seg
283
- aseg = non_silent_wave
284
- aseg.export(f.name, format="wav")
285
- final_wave, _ = torchaudio.load(f.name)
286
- final_wave = final_wave.squeeze().cpu().numpy()
287
-
288
- # Create a combined spectrogram
289
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
290
-
291
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
292
- spectrogram_path = tmp_spectrogram.name
293
- save_spectrogram(combined_spectrogram, spectrogram_path)
294
-
295
- return (target_sample_rate, final_wave), spectrogram_path
296
-
297
- def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
298
- if not custom_split_words.strip():
299
- custom_words = [word.strip() for word in custom_split_words.split(',')]
300
- global SPLIT_WORDS
301
- SPLIT_WORDS = custom_words
302
-
303
- print(gen_text)
304
-
305
- gr.Info("Converting audio...")
306
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
307
- aseg = AudioSegment.from_file(ref_audio_orig)
308
-
309
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
310
- non_silent_wave = AudioSegment.silent(duration=0)
311
- for non_silent_seg in non_silent_segs:
312
- non_silent_wave += non_silent_seg
313
- aseg = non_silent_wave
314
-
315
- audio_duration = len(aseg)
316
- if audio_duration > 15000:
317
- gr.Warning("Audio is over 15s, clipping to only first 15s.")
318
- aseg = aseg[:15000]
319
- aseg.export(f.name, format="wav")
320
- ref_audio = f.name
321
-
322
- if not ref_text.strip():
323
- gr.Info("No reference text provided, transcribing reference audio...")
324
- ref_text = pipe(
325
- ref_audio,
326
- chunk_length_s=30,
327
- batch_size=128,
328
- generate_kwargs={"task": "transcribe"},
329
- return_timestamps=False,
330
- )["text"].strip()
331
- gr.Info("Finished transcription")
332
- else:
333
- gr.Info("Using custom reference text...")
334
-
335
- # Split the input text into batches
336
- audio, sr = torchaudio.load(ref_audio)
337
- max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
338
- gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
339
- print('ref_text', ref_text)
340
- for i, gen_text in enumerate(gen_text_batches):
341
- print(f'gen_text {i}', gen_text)
342
-
343
- gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
344
- return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
345
-
346
- def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
347
- # Split the script into speaker blocks
348
- speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
349
- speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
350
-
351
- generated_audio_segments = []
352
-
353
- for i in range(0, len(speaker_blocks), 2):
354
- speaker = speaker_blocks[i]
355
- text = speaker_blocks[i+1].strip()
356
-
357
- # Determine which speaker is talking
358
- if speaker == speaker1_name:
359
- ref_audio = ref_audio1
360
- ref_text = ref_text1
361
- elif speaker == speaker2_name:
362
- ref_audio = ref_audio2
363
- ref_text = ref_text2
364
- else:
365
- continue # Skip if the speaker is neither speaker1 nor speaker2
366
-
367
- # Generate audio for this block
368
- audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
369
-
370
- # Convert the generated audio to a numpy array
371
- sr, audio_data = audio
372
-
373
- # Save the audio data as a WAV file
374
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
375
- sf.write(temp_file.name, audio_data, sr)
376
- audio_segment = AudioSegment.from_wav(temp_file.name)
377
-
378
- generated_audio_segments.append(audio_segment)
379
-
380
- # Add a short pause between speakers
381
- pause = AudioSegment.silent(duration=500) # 500ms pause
382
- generated_audio_segments.append(pause)
383
-
384
- # Concatenate all audio segments
385
- final_podcast = sum(generated_audio_segments)
386
-
387
- # Export the final podcast
388
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
389
- podcast_path = temp_file.name
390
- final_podcast.export(podcast_path, format="wav")
391
-
392
- return podcast_path
393
-
394
- def parse_speechtypes_text(gen_text):
395
- # Pattern to find (Emotion)
396
- pattern = r'\((.*?)\)'
397
-
398
- # Split the text by the pattern
399
- tokens = re.split(pattern, gen_text)
400
-
401
- segments = []
402
-
403
- current_emotion = 'Regular'
404
-
405
- for i in range(len(tokens)):
406
- if i % 2 == 0:
407
- # This is text
408
- text = tokens[i].strip()
409
- if text:
410
- segments.append({'emotion': current_emotion, 'text': text})
411
- else:
412
- # This is emotion
413
- emotion = tokens[i].strip()
414
- current_emotion = emotion
415
-
416
- return segments
417
-
418
- def update_speed(new_speed):
419
- global speed
420
- speed = new_speed
421
- return f"Speed set to: {speed}"
422
-
423
- with gr.Blocks() as app_credits:
424
- gr.Markdown("""
425
- # Credits
426
-
427
- * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
428
- * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
429
- """)
430
- with gr.Blocks() as app_tts:
431
- gr.Markdown("# Batched TTS")
432
- ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
433
- gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
434
- model_choice = gr.Radio(
435
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
436
- )
437
- generate_btn = gr.Button("Synthesize", variant="primary")
438
- with gr.Accordion("Advanced Settings", open=False):
439
- ref_text_input = gr.Textbox(
440
- label="Reference Text",
441
- info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
442
- lines=2,
443
- )
444
- remove_silence = gr.Checkbox(
445
- label="Remove Silences",
446
- info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
447
- value=True,
448
- )
449
- split_words_input = gr.Textbox(
450
- label="Custom Split Words",
451
- info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
452
- lines=2,
453
- )
454
- speed_slider = gr.Slider(
455
- label="Speed",
456
- minimum=0.3,
457
- maximum=2.0,
458
- value=speed,
459
- step=0.1,
460
- info="Adjust the speed of the audio.",
461
- )
462
- speed_slider.change(update_speed, inputs=speed_slider)
463
-
464
- audio_output = gr.Audio(label="Synthesized Audio")
465
- spectrogram_output = gr.Image(label="Spectrogram")
466
-
467
- generate_btn.click(
468
- infer,
469
- inputs=[
470
- ref_audio_input,
471
- ref_text_input,
472
- gen_text_input,
473
- model_choice,
474
- remove_silence,
475
- split_words_input,
476
- ],
477
- outputs=[audio_output, spectrogram_output],
478
- )
479
-
480
- with gr.Blocks() as app_podcast:
481
- gr.Markdown("# Podcast Generation")
482
- speaker1_name = gr.Textbox(label="Speaker 1 Name")
483
- ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
484
- ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
485
-
486
- speaker2_name = gr.Textbox(label="Speaker 2 Name")
487
- ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
488
- ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
489
-
490
- script_input = gr.Textbox(label="Podcast Script", lines=10,
491
- placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
492
-
493
- podcast_model_choice = gr.Radio(
494
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
495
- )
496
- podcast_remove_silence = gr.Checkbox(
497
- label="Remove Silences",
498
- value=True,
499
- )
500
- generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
501
- podcast_output = gr.Audio(label="Generated Podcast")
502
-
503
- def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
504
- return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
505
-
506
- generate_podcast_btn.click(
507
- podcast_generation,
508
- inputs=[
509
- script_input,
510
- speaker1_name,
511
- ref_audio_input1,
512
- ref_text_input1,
513
- speaker2_name,
514
- ref_audio_input2,
515
- ref_text_input2,
516
- podcast_model_choice,
517
- podcast_remove_silence,
518
- ],
519
- outputs=podcast_output,
520
- )
521
-
522
- def parse_emotional_text(gen_text):
523
- # Pattern to find (Emotion)
524
- pattern = r'\((.*?)\)'
525
-
526
- # Split the text by the pattern
527
- tokens = re.split(pattern, gen_text)
528
-
529
- segments = []
530
-
531
- current_emotion = 'Regular'
532
-
533
- for i in range(len(tokens)):
534
- if i % 2 == 0:
535
- # This is text
536
- text = tokens[i].strip()
537
- if text:
538
- segments.append({'emotion': current_emotion, 'text': text})
539
- else:
540
- # This is emotion
541
- emotion = tokens[i].strip()
542
- current_emotion = emotion
543
-
544
- return segments
545
-
546
- with gr.Blocks() as app_emotional:
547
- # New section for emotional generation
548
- gr.Markdown(
549
- """
550
- # Multiple Speech-Type Generation
551
-
552
- This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
553
-
554
- **Example Input:**
555
-
556
- (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
557
- """
558
- )
559
-
560
- gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
561
-
562
- # Regular speech type (mandatory)
563
- with gr.Row():
564
- regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
565
- regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
566
- regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
567
-
568
- # Additional speech types (up to 9 more)
569
- max_speech_types = 10
570
- speech_type_names = []
571
- speech_type_audios = []
572
- speech_type_ref_texts = []
573
- speech_type_delete_btns = []
574
-
575
- for i in range(max_speech_types - 1):
576
- with gr.Row():
577
- name_input = gr.Textbox(label='Speech Type Name', visible=False)
578
- audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
579
- ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
580
- delete_btn = gr.Button("Delete", variant="secondary", visible=False)
581
- speech_type_names.append(name_input)
582
- speech_type_audios.append(audio_input)
583
- speech_type_ref_texts.append(ref_text_input)
584
- speech_type_delete_btns.append(delete_btn)
585
-
586
- # Button to add speech type
587
- add_speech_type_btn = gr.Button("Add Speech Type")
588
-
589
- # Keep track of current number of speech types
590
- speech_type_count = gr.State(value=0)
591
-
592
- # Function to add a speech type
593
- def add_speech_type_fn(speech_type_count):
594
- if speech_type_count < max_speech_types - 1:
595
- speech_type_count += 1
596
- # Prepare updates for the components
597
- name_updates = []
598
- audio_updates = []
599
- ref_text_updates = []
600
- delete_btn_updates = []
601
- for i in range(max_speech_types - 1):
602
- if i < speech_type_count:
603
- name_updates.append(gr.update(visible=True))
604
- audio_updates.append(gr.update(visible=True))
605
- ref_text_updates.append(gr.update(visible=True))
606
- delete_btn_updates.append(gr.update(visible=True))
607
- else:
608
- name_updates.append(gr.update())
609
- audio_updates.append(gr.update())
610
- ref_text_updates.append(gr.update())
611
- delete_btn_updates.append(gr.update())
612
- else:
613
- # Optionally, show a warning
614
- # gr.Warning("Maximum number of speech types reached.")
615
- name_updates = [gr.update() for _ in range(max_speech_types - 1)]
616
- audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
617
- ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
618
- delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
619
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
620
-
621
- add_speech_type_btn.click(
622
- add_speech_type_fn,
623
- inputs=speech_type_count,
624
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
625
- )
626
-
627
- # Function to delete a speech type
628
- def make_delete_speech_type_fn(index):
629
- def delete_speech_type_fn(speech_type_count):
630
- # Prepare updates
631
- name_updates = []
632
- audio_updates = []
633
- ref_text_updates = []
634
- delete_btn_updates = []
635
-
636
- for i in range(max_speech_types - 1):
637
- if i == index:
638
- name_updates.append(gr.update(visible=False, value=''))
639
- audio_updates.append(gr.update(visible=False, value=None))
640
- ref_text_updates.append(gr.update(visible=False, value=''))
641
- delete_btn_updates.append(gr.update(visible=False))
642
- else:
643
- name_updates.append(gr.update())
644
- audio_updates.append(gr.update())
645
- ref_text_updates.append(gr.update())
646
- delete_btn_updates.append(gr.update())
647
-
648
- speech_type_count = max(0, speech_type_count - 1)
649
-
650
- return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
651
-
652
- return delete_speech_type_fn
653
-
654
- for i, delete_btn in enumerate(speech_type_delete_btns):
655
- delete_fn = make_delete_speech_type_fn(i)
656
- delete_btn.click(
657
- delete_fn,
658
- inputs=speech_type_count,
659
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
660
- )
661
-
662
- # Text input for the prompt
663
- gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
664
-
665
- # Model choice
666
- model_choice_emotional = gr.Radio(
667
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
668
- )
669
-
670
- with gr.Accordion("Advanced Settings", open=False):
671
- remove_silence_emotional = gr.Checkbox(
672
- label="Remove Silences",
673
- value=True,
674
- )
675
-
676
- # Generate button
677
- generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
678
-
679
- # Output audio
680
- audio_output_emotional = gr.Audio(label="Synthesized Audio")
681
-
682
- def generate_emotional_speech(
683
- regular_audio,
684
- regular_ref_text,
685
- gen_text,
686
- *args,
687
- ):
688
- num_additional_speech_types = max_speech_types - 1
689
- speech_type_names_list = args[:num_additional_speech_types]
690
- speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
691
- speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
692
- model_choice = args[3 * num_additional_speech_types]
693
- remove_silence = args[3 * num_additional_speech_types + 1]
694
-
695
- # Collect the speech types and their audios into a dict
696
- speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
697
-
698
- for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
699
- if name_input and audio_input:
700
- speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
701
-
702
- # Parse the gen_text into segments
703
- segments = parse_speechtypes_text(gen_text)
704
-
705
- # For each segment, generate speech
706
- generated_audio_segments = []
707
- current_emotion = 'Regular'
708
-
709
- for segment in segments:
710
- emotion = segment['emotion']
711
- text = segment['text']
712
-
713
- if emotion in speech_types:
714
- current_emotion = emotion
715
- else:
716
- # If emotion not available, default to Regular
717
- current_emotion = 'Regular'
718
-
719
- ref_audio = speech_types[current_emotion]['audio']
720
- ref_text = speech_types[current_emotion].get('ref_text', '')
721
-
722
- # Generate speech for this segment
723
- audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
724
- sr, audio_data = audio
725
-
726
- generated_audio_segments.append(audio_data)
727
-
728
- # Concatenate all audio segments
729
- if generated_audio_segments:
730
- final_audio_data = np.concatenate(generated_audio_segments)
731
- return (sr, final_audio_data)
732
- else:
733
- gr.Warning("No audio generated.")
734
- return None
735
-
736
- generate_emotional_btn.click(
737
- generate_emotional_speech,
738
- inputs=[
739
- regular_audio,
740
- regular_ref_text,
741
- gen_text_input_emotional,
742
- ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
743
- model_choice_emotional,
744
- remove_silence_emotional,
745
- ],
746
- outputs=audio_output_emotional,
747
- )
748
-
749
- # Validation function to disable Generate button if speech types are missing
750
- def validate_speech_types(
751
- gen_text,
752
- regular_name,
753
- *args
754
- ):
755
- num_additional_speech_types = max_speech_types - 1
756
- speech_type_names_list = args[:num_additional_speech_types]
757
-
758
- # Collect the speech types names
759
- speech_types_available = set()
760
- if regular_name:
761
- speech_types_available.add(regular_name)
762
- for name_input in speech_type_names_list:
763
- if name_input:
764
- speech_types_available.add(name_input)
765
-
766
- # Parse the gen_text to get the speech types used
767
- segments = parse_emotional_text(gen_text)
768
- speech_types_in_text = set(segment['emotion'] for segment in segments)
769
-
770
- # Check if all speech types in text are available
771
- missing_speech_types = speech_types_in_text - speech_types_available
772
-
773
- if missing_speech_types:
774
- # Disable the generate button
775
- return gr.update(interactive=False)
776
- else:
777
- # Enable the generate button
778
- return gr.update(interactive=True)
779
-
780
- gen_text_input_emotional.change(
781
- validate_speech_types,
782
- inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
783
- outputs=generate_emotional_btn
784
- )
785
- with gr.Blocks() as app:
786
- gr.Markdown(
787
- """
788
- # E2/F5 TTS
789
-
790
- This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
791
-
792
- * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
793
- * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
794
-
795
- The checkpoints support English and Chinese.
796
-
797
- If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
798
-
799
- **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
800
- """
801
- )
802
- gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
803
-
804
- @click.command()
805
- @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
806
- @click.option("--host", "-H", default=None, help="Host to run the app on")
807
- @click.option(
808
- "--share",
809
- "-s",
810
- default=False,
811
- is_flag=True,
812
- help="Share the app via Gradio share link",
813
- )
814
- @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
815
- def main(port, host, share, api):
816
- global app
817
- print(f"Starting app...")
818
- app.queue(api_open=api).launch(
819
- server_name=host, server_port=port, share=share, show_api=api
820
- )
821
-
822
-
823
- if __name__ == "__main__":
824
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference-cli.py CHANGED
@@ -1,22 +1,24 @@
1
  import argparse
2
  import codecs
3
  import re
 
4
  from pathlib import Path
5
 
6
  import numpy as np
7
  import soundfile as sf
8
  import tomli
 
 
 
9
  from cached_path import cached_path
 
 
 
 
10
 
11
- from model import DiT, UNetT
12
- from model.utils_infer import (
13
- load_vocoder,
14
- load_model,
15
- preprocess_ref_audio_text,
16
- infer_process,
17
- remove_silence_for_generated_wav,
18
- )
19
-
20
 
21
  parser = argparse.ArgumentParser(
22
  prog="python3 inference-cli.py",
@@ -35,17 +37,18 @@ parser.add_argument(
35
  help="F5-TTS | E2-TTS",
36
  )
37
  parser.add_argument(
38
- "-p",
39
- "--ckpt_file",
40
- help="The Checkpoint .pt",
 
41
  )
42
  parser.add_argument(
43
- "-v",
44
- "--vocab_file",
45
- help="The vocab .txt",
 
 
46
  )
47
- parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
48
- parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
49
  parser.add_argument(
50
  "-t",
51
  "--gen_text",
@@ -85,86 +88,305 @@ if gen_file:
85
  gen_text = codecs.open(gen_file, "r", "utf-8").read()
86
  output_dir = args.output_dir if args.output_dir else config["output_dir"]
87
  model = args.model if args.model else config["model"]
88
- ckpt_file = args.ckpt_file if args.ckpt_file else ""
89
- vocab_file = args.vocab_file if args.vocab_file else ""
90
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
91
- wave_path = Path(output_dir) / "out.wav"
92
- spectrogram_path = Path(output_dir) / "out.png"
93
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
94
 
95
- vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  # load models
99
- if model == "F5-TTS":
100
- model_cls = DiT
101
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
102
- if ckpt_file == "":
103
- repo_name = "F5-TTS"
104
- exp_name = "F5TTS_Base"
105
- ckpt_step = 1200000
106
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
107
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
108
-
109
- elif model == "E2-TTS":
110
- model_cls = UNetT
111
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
112
- if ckpt_file == "":
113
- repo_name = "E2-TTS"
114
- exp_name = "E2TTS_Base"
115
- ckpt_step = 1200000
116
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
117
- # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
118
-
119
- print(f"Using {model}...")
120
- ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
121
-
122
-
123
- def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
124
- main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  if "voices" not in config:
126
  voices = {"main": main_voice}
127
  else:
128
  voices = config["voices"]
129
  voices["main"] = main_voice
130
  for voice in voices:
131
- voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
132
- voices[voice]["ref_audio"], voices[voice]["ref_text"]
133
- )
134
- print("Voice:", voice)
135
- print("Ref_audio:", voices[voice]["ref_audio"])
136
- print("Ref_text:", voices[voice]["ref_text"])
137
 
138
  generated_audio_segments = []
139
- reg1 = r"(?=\[\w+\])"
140
  chunks = re.split(reg1, text_gen)
141
- reg2 = r"\[(\w+)\]"
142
  for text in chunks:
143
  match = re.match(reg2, text)
144
- if match:
145
- voice = match[1]
146
- else:
147
- print("No voice tag found, using main.")
148
- voice = "main"
149
- if voice not in voices:
150
- print(f"Voice {voice} not found, using main.")
151
  voice = "main"
 
 
152
  text = re.sub(reg2, "", text)
153
  gen_text = text.strip()
154
- ref_audio = voices[voice]["ref_audio"]
155
- ref_text = voices[voice]["ref_text"]
156
  print(f"Voice: {voice}")
157
- audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
158
  generated_audio_segments.append(audio)
159
 
160
  if generated_audio_segments:
161
  final_wave = np.concatenate(generated_audio_segments)
162
  with open(wave_path, "wb") as f:
163
- sf.write(f.name, final_wave, final_sample_rate)
164
  # Remove silence
165
  if remove_silence:
166
- remove_silence_for_generated_wav(f.name)
 
 
 
 
 
 
167
  print(f.name)
168
 
169
-
170
- main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
 
1
  import argparse
2
  import codecs
3
  import re
4
+ import tempfile
5
  from pathlib import Path
6
 
7
  import numpy as np
8
  import soundfile as sf
9
  import tomli
10
+ import torch
11
+ import torchaudio
12
+ import tqdm
13
  from cached_path import cached_path
14
+ from einops import rearrange
15
+ from pydub import AudioSegment, silence
16
+ from transformers import pipeline
17
+ from vocos import Vocos
18
 
19
+ from model import CFM, DiT, MMDiT, UNetT
20
+ from model.utils import (convert_char_to_pinyin, get_tokenizer,
21
+ load_checkpoint, save_spectrogram)
 
 
 
 
 
 
22
 
23
  parser = argparse.ArgumentParser(
24
  prog="python3 inference-cli.py",
 
37
  help="F5-TTS | E2-TTS",
38
  )
39
  parser.add_argument(
40
+ "-r",
41
+ "--ref_audio",
42
+ type=str,
43
+ help="Reference audio file < 15 seconds."
44
  )
45
  parser.add_argument(
46
+ "-s",
47
+ "--ref_text",
48
+ type=str,
49
+ default="666",
50
+ help="Subtitle for the reference audio."
51
  )
 
 
52
  parser.add_argument(
53
  "-t",
54
  "--gen_text",
 
88
  gen_text = codecs.open(gen_file, "r", "utf-8").read()
89
  output_dir = args.output_dir if args.output_dir else config["output_dir"]
90
  model = args.model if args.model else config["model"]
 
 
91
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
92
+ wave_path = Path(output_dir)/"out.wav"
93
+ spectrogram_path = Path(output_dir)/"out.png"
94
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
95
 
96
+ device = (
97
+ "cuda"
98
+ if torch.cuda.is_available()
99
+ else "mps" if torch.backends.mps.is_available() else "cpu"
100
+ )
101
+
102
+ if args.load_vocoder_from_local:
103
+ print(f"Load vocos from local path {vocos_local_path}")
104
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
105
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
106
+ vocos.load_state_dict(state_dict)
107
+ vocos.eval()
108
+ else:
109
+ print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
110
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
111
+
112
+ print(f"Using {device} device")
113
+
114
+ # --------------------- Settings -------------------- #
115
+
116
+ target_sample_rate = 24000
117
+ n_mel_channels = 100
118
+ hop_length = 256
119
+ target_rms = 0.1
120
+ nfe_step = 32 # 16, 32
121
+ cfg_strength = 2.0
122
+ ode_method = "euler"
123
+ sway_sampling_coef = -1.0
124
+ speed = 1.0
125
+ # fix_duration = 27 # None or float (duration in seconds)
126
+ fix_duration = None
127
+
128
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
129
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
130
+ if not Path(ckpt_path).exists():
131
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
132
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
133
+ model = CFM(
134
+ transformer=model_cls(
135
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
136
+ ),
137
+ mel_spec_kwargs=dict(
138
+ target_sample_rate=target_sample_rate,
139
+ n_mel_channels=n_mel_channels,
140
+ hop_length=hop_length,
141
+ ),
142
+ odeint_kwargs=dict(
143
+ method=ode_method,
144
+ ),
145
+ vocab_char_map=vocab_char_map,
146
+ ).to(device)
147
+
148
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
149
+
150
+ return model
151
 
152
 
153
  # load models
154
+ F5TTS_model_cfg = dict(
155
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
156
+ )
157
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
158
+
159
+
160
+ def chunk_text(text, max_chars=135):
161
+ """
162
+ Splits the input text into chunks, each with a maximum number of characters.
163
+ Args:
164
+ text (str): The text to be split.
165
+ max_chars (int): The maximum number of characters per chunk.
166
+ Returns:
167
+ List[str]: A list of text chunks.
168
+ """
169
+ chunks = []
170
+ current_chunk = ""
171
+ # Split the text into sentences based on punctuation followed by whitespace
172
+ sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
173
+
174
+ for sentence in sentences:
175
+ if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
176
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
177
+ else:
178
+ if current_chunk:
179
+ chunks.append(current_chunk.strip())
180
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
181
+
182
+ if current_chunk:
183
+ chunks.append(current_chunk.strip())
184
+
185
+ return chunks
186
+
187
+
188
+ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
189
+ if model == "F5-TTS":
190
+ ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
191
+ elif model == "E2-TTS":
192
+ ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
193
+
194
+ audio, sr = ref_audio
195
+ if audio.shape[0] > 1:
196
+ audio = torch.mean(audio, dim=0, keepdim=True)
197
+
198
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
199
+ if rms < target_rms:
200
+ audio = audio * target_rms / rms
201
+ if sr != target_sample_rate:
202
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
203
+ audio = resampler(audio)
204
+ audio = audio.to(device)
205
+
206
+ generated_waves = []
207
+ spectrograms = []
208
+
209
+ for i, gen_text in enumerate(tqdm.tqdm(gen_text_batches)):
210
+ # Prepare the text
211
+ if len(ref_text[-1].encode('utf-8')) == 1:
212
+ ref_text = ref_text + " "
213
+ text_list = [ref_text + gen_text]
214
+ final_text_list = convert_char_to_pinyin(text_list)
215
+
216
+ # Calculate duration
217
+ ref_audio_len = audio.shape[-1] // hop_length
218
+ zh_pause_punc = r"。,、;:?!"
219
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
220
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
221
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
222
+
223
+ # inference
224
+ with torch.inference_mode():
225
+ generated, _ = ema_model.sample(
226
+ cond=audio,
227
+ text=final_text_list,
228
+ duration=duration,
229
+ steps=nfe_step,
230
+ cfg_strength=cfg_strength,
231
+ sway_sampling_coef=sway_sampling_coef,
232
+ )
233
+
234
+ generated = generated[:, ref_audio_len:, :]
235
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
236
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
237
+ if rms < target_rms:
238
+ generated_wave = generated_wave * rms / target_rms
239
+
240
+ # wav -> numpy
241
+ generated_wave = generated_wave.squeeze().cpu().numpy()
242
+
243
+ generated_waves.append(generated_wave)
244
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
245
+
246
+ # Combine all generated waves with cross-fading
247
+ if cross_fade_duration <= 0:
248
+ # Simply concatenate
249
+ final_wave = np.concatenate(generated_waves)
250
+ else:
251
+ final_wave = generated_waves[0]
252
+ for i in range(1, len(generated_waves)):
253
+ prev_wave = final_wave
254
+ next_wave = generated_waves[i]
255
+
256
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
257
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
258
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
259
+
260
+ if cross_fade_samples <= 0:
261
+ # No overlap possible, concatenate
262
+ final_wave = np.concatenate([prev_wave, next_wave])
263
+ continue
264
+
265
+ # Overlapping parts
266
+ prev_overlap = prev_wave[-cross_fade_samples:]
267
+ next_overlap = next_wave[:cross_fade_samples]
268
+
269
+ # Fade out and fade in
270
+ fade_out = np.linspace(1, 0, cross_fade_samples)
271
+ fade_in = np.linspace(0, 1, cross_fade_samples)
272
+
273
+ # Cross-faded overlap
274
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
275
+
276
+ # Combine
277
+ new_wave = np.concatenate([
278
+ prev_wave[:-cross_fade_samples],
279
+ cross_faded_overlap,
280
+ next_wave[cross_fade_samples:]
281
+ ])
282
+
283
+ final_wave = new_wave
284
+
285
+ # Create a combined spectrogram
286
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
287
+
288
+ return final_wave, combined_spectrogram
289
+
290
+ def process_voice(ref_audio_orig, ref_text):
291
+ print("Converting audio...")
292
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
293
+ aseg = AudioSegment.from_file(ref_audio_orig)
294
+
295
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
296
+ non_silent_wave = AudioSegment.silent(duration=0)
297
+ for non_silent_seg in non_silent_segs:
298
+ non_silent_wave += non_silent_seg
299
+ aseg = non_silent_wave
300
+
301
+ audio_duration = len(aseg)
302
+ if audio_duration > 15000:
303
+ print("Audio is over 15s, clipping to only first 15s.")
304
+ aseg = aseg[:15000]
305
+ aseg.export(f.name, format="wav")
306
+ ref_audio = f.name
307
+
308
+ if not ref_text.strip():
309
+ print("No reference text provided, transcribing reference audio...")
310
+ pipe = pipeline(
311
+ "automatic-speech-recognition",
312
+ model="openai/whisper-large-v3-turbo",
313
+ torch_dtype=torch.float16,
314
+ device=device,
315
+ )
316
+ ref_text = pipe(
317
+ ref_audio,
318
+ chunk_length_s=30,
319
+ batch_size=128,
320
+ generate_kwargs={"task": "transcribe"},
321
+ return_timestamps=False,
322
+ )["text"].strip()
323
+ print("Finished transcription")
324
+ else:
325
+ print("Using custom reference text...")
326
+ return ref_audio, ref_text
327
+
328
+ def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
329
+ print(gen_text)
330
+ # Add the functionality to ensure it ends with ". "
331
+ if not ref_text.endswith(". ") and not ref_text.endswith("。"):
332
+ if ref_text.endswith("."):
333
+ ref_text += " "
334
+ else:
335
+ ref_text += ". "
336
+
337
+ # Split the input text into batches
338
+ audio, sr = torchaudio.load(ref_audio)
339
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
340
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
341
+ print('ref_text', ref_text)
342
+ for i, gen_text in enumerate(gen_text_batches):
343
+ print(f'gen_text {i}', gen_text)
344
+
345
+ print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
346
+ return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
347
+
348
+
349
+ def process(ref_audio, ref_text, text_gen, model, remove_silence):
350
+ main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
351
  if "voices" not in config:
352
  voices = {"main": main_voice}
353
  else:
354
  voices = config["voices"]
355
  voices["main"] = main_voice
356
  for voice in voices:
357
+ voices[voice]['ref_audio'], voices[voice]['ref_text'] = process_voice(voices[voice]['ref_audio'], voices[voice]['ref_text'])
 
 
 
 
 
358
 
359
  generated_audio_segments = []
360
+ reg1 = r'(?=\[\w+\])'
361
  chunks = re.split(reg1, text_gen)
362
+ reg2 = r'\[(\w+)\]'
363
  for text in chunks:
364
  match = re.match(reg2, text)
365
+ if not match or voice not in voices:
 
 
 
 
 
 
366
  voice = "main"
367
+ else:
368
+ voice = match[1]
369
  text = re.sub(reg2, "", text)
370
  gen_text = text.strip()
371
+ ref_audio = voices[voice]['ref_audio']
372
+ ref_text = voices[voice]['ref_text']
373
  print(f"Voice: {voice}")
374
+ audio, spectragram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
375
  generated_audio_segments.append(audio)
376
 
377
  if generated_audio_segments:
378
  final_wave = np.concatenate(generated_audio_segments)
379
  with open(wave_path, "wb") as f:
380
+ sf.write(f.name, final_wave, target_sample_rate)
381
  # Remove silence
382
  if remove_silence:
383
+ aseg = AudioSegment.from_file(f.name)
384
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
385
+ non_silent_wave = AudioSegment.silent(duration=0)
386
+ for non_silent_seg in non_silent_segs:
387
+ non_silent_wave += non_silent_seg
388
+ aseg = non_silent_wave
389
+ aseg.export(f.name, format="wav")
390
  print(f.name)
391
 
392
+ process(ref_audio, ref_text, gen_text, model, remove_silence)
 
model/__init__.py CHANGED
@@ -5,6 +5,3 @@ from model.backbones.dit import DiT
5
  from model.backbones.mmdit import MMDiT
6
 
7
  from model.trainer import Trainer
8
-
9
-
10
- __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
 
5
  from model.backbones.mmdit import MMDiT
6
 
7
  from model.trainer import Trainer
 
 
 
model/backbones/dit.py CHANGED
@@ -13,6 +13,8 @@ import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
 
 
 
16
  from x_transformers.x_transformers import RotaryEmbedding
17
 
18
  from model.modules import (
@@ -21,16 +23,14 @@ from model.modules import (
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
  AdaLayerNormZero_Final,
24
- precompute_freqs_cis,
25
- get_pos_embed_indices,
26
  )
27
 
28
 
29
  # Text embedding
30
 
31
-
32
  class TextEmbedding(nn.Module):
33
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
@@ -38,22 +38,20 @@ class TextEmbedding(nn.Module):
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
- self.text_blocks = nn.Sequential(
42
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
43
- )
44
  else:
45
  self.extra_modeling = False
46
 
47
- def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
48
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
49
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
- batch, text_len = text.shape[0], text.shape[1]
51
- text = F.pad(text, (0, seq_len - text_len), value=0)
52
 
53
  if drop_text: # cfg for text
54
  text = torch.zeros_like(text)
55
 
56
- text = self.text_embed(text) # b n -> b n d
57
 
58
  # possible extra modeling
59
  if self.extra_modeling:
@@ -71,91 +69,88 @@ class TextEmbedding(nn.Module):
71
 
72
  # noised input audio and context mixing embedding
73
 
74
-
75
  class InputEmbedding(nn.Module):
76
  def __init__(self, mel_dim, text_dim, out_dim):
77
  super().__init__()
78
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
- self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
80
 
81
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
82
  if drop_audio_cond: # cfg for cond audio
83
  cond = torch.zeros_like(cond)
84
 
85
- x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
86
  x = self.conv_pos_embed(x) + x
87
  return x
88
-
89
 
90
  # Transformer backbone using DiT blocks
91
 
92
-
93
  class DiT(nn.Module):
94
- def __init__(
95
- self,
96
- *,
97
- dim,
98
- depth=8,
99
- heads=8,
100
- dim_head=64,
101
- dropout=0.1,
102
- ff_mult=4,
103
- mel_dim=100,
104
- text_num_embeds=256,
105
- text_dim=None,
106
- conv_layers=0,
107
- long_skip_connection=False,
108
  ):
109
  super().__init__()
110
 
111
  self.time_embed = TimestepEmbedding(dim)
112
  if text_dim is None:
113
  text_dim = mel_dim
114
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
115
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
116
 
117
  self.rotary_embed = RotaryEmbedding(dim_head)
118
 
119
  self.dim = dim
120
  self.depth = depth
121
-
122
  self.transformer_blocks = nn.ModuleList(
123
- [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
 
 
 
 
 
 
 
 
 
124
  )
125
- self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
126
-
127
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
128
  self.proj_out = nn.Linear(dim, mel_dim)
129
 
130
  def forward(
131
  self,
132
- x: float["b n d"], # nosied input audio # noqa: F722
133
- cond: float["b n d"], # masked cond audio # noqa: F722
134
- text: int["b nt"], # text # noqa: F722
135
- time: float["b"] | float[""], # time step # noqa: F821 F722
136
  drop_audio_cond, # cfg for cond audio
137
- drop_text, # cfg for text
138
- mask: bool["b n"] | None = None, # noqa: F722
139
  ):
140
  batch, seq_len = x.shape[0], x.shape[1]
141
  if time.ndim == 0:
142
- time = time.repeat(batch)
143
-
144
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
145
  t = self.time_embed(time)
146
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
147
- x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
148
-
149
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
150
 
151
  if self.long_skip_connection is not None:
152
  residual = x
153
 
154
  for block in self.transformer_blocks:
155
- x = block(x, t, mask=mask, rope=rope)
156
 
157
  if self.long_skip_connection is not None:
158
- x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
159
 
160
  x = self.norm_out(x, t)
161
  output = self.proj_out(x)
 
13
  from torch import nn
14
  import torch.nn.functional as F
15
 
16
+ from einops import repeat
17
+
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
20
  from model.modules import (
 
23
  ConvPositionEmbedding,
24
  DiTBlock,
25
  AdaLayerNormZero_Final,
26
+ precompute_freqs_cis, get_pos_embed_indices,
 
27
  )
28
 
29
 
30
  # Text embedding
31
 
 
32
  class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
 
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
 
42
  else:
43
  self.extra_modeling = False
44
 
45
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
46
+ batch, text_len = text.shape[0], text.shape[1]
47
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
48
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
49
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
 
50
 
51
  if drop_text: # cfg for text
52
  text = torch.zeros_like(text)
53
 
54
+ text = self.text_embed(text) # b n -> b n d
55
 
56
  # possible extra modeling
57
  if self.extra_modeling:
 
69
 
70
  # noised input audio and context mixing embedding
71
 
 
72
  class InputEmbedding(nn.Module):
73
  def __init__(self, mel_dim, text_dim, out_dim):
74
  super().__init__()
75
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
76
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
77
 
78
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
79
  if drop_audio_cond: # cfg for cond audio
80
  cond = torch.zeros_like(cond)
81
 
82
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
83
  x = self.conv_pos_embed(x) + x
84
  return x
85
+
86
 
87
  # Transformer backbone using DiT blocks
88
 
 
89
  class DiT(nn.Module):
90
+ def __init__(self, *,
91
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
92
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
93
+ long_skip_connection = False,
 
 
 
 
 
 
 
 
 
 
94
  ):
95
  super().__init__()
96
 
97
  self.time_embed = TimestepEmbedding(dim)
98
  if text_dim is None:
99
  text_dim = mel_dim
100
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
101
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
102
 
103
  self.rotary_embed = RotaryEmbedding(dim_head)
104
 
105
  self.dim = dim
106
  self.depth = depth
107
+
108
  self.transformer_blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim = dim,
112
+ heads = heads,
113
+ dim_head = dim_head,
114
+ ff_mult = ff_mult,
115
+ dropout = dropout
116
+ )
117
+ for _ in range(depth)
118
+ ]
119
  )
120
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
121
+
122
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
123
  self.proj_out = nn.Linear(dim, mel_dim)
124
 
125
  def forward(
126
  self,
127
+ x: float['b n d'], # nosied input audio
128
+ cond: float['b n d'], # masked cond audio
129
+ text: int['b nt'], # text
130
+ time: float['b'] | float[''], # time step
131
  drop_audio_cond, # cfg for cond audio
132
+ drop_text, # cfg for text
133
+ mask: bool['b n'] | None = None,
134
  ):
135
  batch, seq_len = x.shape[0], x.shape[1]
136
  if time.ndim == 0:
137
+ time = repeat(time, ' -> b', b = batch)
138
+
139
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
140
  t = self.time_embed(time)
141
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
142
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
143
+
144
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
145
 
146
  if self.long_skip_connection is not None:
147
  residual = x
148
 
149
  for block in self.transformer_blocks:
150
+ x = block(x, t, mask = mask, rope = rope)
151
 
152
  if self.long_skip_connection is not None:
153
+ x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
154
 
155
  x = self.norm_out(x, t)
156
  output = self.proj_out(x)
model/backbones/mmdit.py CHANGED
@@ -12,6 +12,8 @@ from __future__ import annotations
12
  import torch
13
  from torch import nn
14
 
 
 
15
  from x_transformers.x_transformers import RotaryEmbedding
16
 
17
  from model.modules import (
@@ -19,14 +21,12 @@ from model.modules import (
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
  AdaLayerNormZero_Final,
22
- precompute_freqs_cis,
23
- get_pos_embed_indices,
24
  )
25
 
26
 
27
  # text embedding
28
 
29
-
30
  class TextEmbedding(nn.Module):
31
  def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
@@ -35,7 +35,7 @@ class TextEmbedding(nn.Module):
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
- def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
  text = text + 1
40
  if drop_text:
41
  text = torch.zeros_like(text)
@@ -54,37 +54,27 @@ class TextEmbedding(nn.Module):
54
 
55
  # noised input & masked cond audio embedding
56
 
57
-
58
  class AudioEmbedding(nn.Module):
59
  def __init__(self, in_dim, out_dim):
60
  super().__init__()
61
  self.linear = nn.Linear(2 * in_dim, out_dim)
62
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
63
 
64
- def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
65
  if drop_audio_cond:
66
  cond = torch.zeros_like(cond)
67
- x = torch.cat((x, cond), dim=-1)
68
  x = self.linear(x)
69
  x = self.conv_pos_embed(x) + x
70
  return x
71
-
72
 
73
  # Transformer backbone using MM-DiT blocks
74
 
75
-
76
  class MMDiT(nn.Module):
77
- def __init__(
78
- self,
79
- *,
80
- dim,
81
- depth=8,
82
- heads=8,
83
- dim_head=64,
84
- dropout=0.1,
85
- ff_mult=4,
86
- text_num_embeds=256,
87
- mel_dim=100,
88
  ):
89
  super().__init__()
90
 
@@ -96,16 +86,16 @@ class MMDiT(nn.Module):
96
 
97
  self.dim = dim
98
  self.depth = depth
99
-
100
  self.transformer_blocks = nn.ModuleList(
101
  [
102
  MMDiTBlock(
103
- dim=dim,
104
- heads=heads,
105
- dim_head=dim_head,
106
- dropout=dropout,
107
- ff_mult=ff_mult,
108
- context_pre_only=i == depth - 1,
109
  )
110
  for i in range(depth)
111
  ]
@@ -115,30 +105,30 @@ class MMDiT(nn.Module):
115
 
116
  def forward(
117
  self,
118
- x: float["b n d"], # nosied input audio # noqa: F722
119
- cond: float["b n d"], # masked cond audio # noqa: F722
120
- text: int["b nt"], # text # noqa: F722
121
- time: float["b"] | float[""], # time step # noqa: F821 F722
122
  drop_audio_cond, # cfg for cond audio
123
- drop_text, # cfg for text
124
- mask: bool["b n"] | None = None, # noqa: F722
125
  ):
126
  batch = x.shape[0]
127
  if time.ndim == 0:
128
- time = time.repeat(batch)
129
 
130
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
  t = self.time_embed(time)
132
- c = self.text_embed(text, drop_text=drop_text)
133
- x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
 
135
  seq_len = x.shape[1]
136
  text_len = text.shape[1]
137
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
138
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
139
-
140
  for block in self.transformer_blocks:
141
- c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
142
 
143
  x = self.norm_out(x, t)
144
  output = self.proj_out(x)
 
12
  import torch
13
  from torch import nn
14
 
15
+ from einops import repeat
16
+
17
  from x_transformers.x_transformers import RotaryEmbedding
18
 
19
  from model.modules import (
 
21
  ConvPositionEmbedding,
22
  MMDiTBlock,
23
  AdaLayerNormZero_Final,
24
+ precompute_freqs_cis, get_pos_embed_indices,
 
25
  )
26
 
27
 
28
  # text embedding
29
 
 
30
  class TextEmbedding(nn.Module):
31
  def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
 
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
+ def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
39
  text = text + 1
40
  if drop_text:
41
  text = torch.zeros_like(text)
 
54
 
55
  # noised input & masked cond audio embedding
56
 
 
57
  class AudioEmbedding(nn.Module):
58
  def __init__(self, in_dim, out_dim):
59
  super().__init__()
60
  self.linear = nn.Linear(2 * in_dim, out_dim)
61
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
62
 
63
+ def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
64
  if drop_audio_cond:
65
  cond = torch.zeros_like(cond)
66
+ x = torch.cat((x, cond), dim = -1)
67
  x = self.linear(x)
68
  x = self.conv_pos_embed(x) + x
69
  return x
70
+
71
 
72
  # Transformer backbone using MM-DiT blocks
73
 
 
74
  class MMDiT(nn.Module):
75
+ def __init__(self, *,
76
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
77
+ text_num_embeds = 256, mel_dim = 100,
 
 
 
 
 
 
 
 
78
  ):
79
  super().__init__()
80
 
 
86
 
87
  self.dim = dim
88
  self.depth = depth
89
+
90
  self.transformer_blocks = nn.ModuleList(
91
  [
92
  MMDiTBlock(
93
+ dim = dim,
94
+ heads = heads,
95
+ dim_head = dim_head,
96
+ dropout = dropout,
97
+ ff_mult = ff_mult,
98
+ context_pre_only = i == depth - 1,
99
  )
100
  for i in range(depth)
101
  ]
 
105
 
106
  def forward(
107
  self,
108
+ x: float['b n d'], # nosied input audio
109
+ cond: float['b n d'], # masked cond audio
110
+ text: int['b nt'], # text
111
+ time: float['b'] | float[''], # time step
112
  drop_audio_cond, # cfg for cond audio
113
+ drop_text, # cfg for text
114
+ mask: bool['b n'] | None = None,
115
  ):
116
  batch = x.shape[0]
117
  if time.ndim == 0:
118
+ time = repeat(time, ' -> b', b = batch)
119
 
120
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
121
  t = self.time_embed(time)
122
+ c = self.text_embed(text, drop_text = drop_text)
123
+ x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
124
 
125
  seq_len = x.shape[1]
126
  text_len = text.shape[1]
127
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
128
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
129
+
130
  for block in self.transformer_blocks:
131
+ c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
132
 
133
  x = self.norm_out(x, t)
134
  output = self.proj_out(x)
model/backbones/unett.py CHANGED
@@ -14,6 +14,8 @@ import torch
14
  from torch import nn
15
  import torch.nn.functional as F
16
 
 
 
17
  from x_transformers import RMSNorm
18
  from x_transformers.x_transformers import RotaryEmbedding
19
 
@@ -24,16 +26,14 @@ from model.modules import (
24
  Attention,
25
  AttnProcessor,
26
  FeedForward,
27
- precompute_freqs_cis,
28
- get_pos_embed_indices,
29
  )
30
 
31
 
32
  # Text embedding
33
 
34
-
35
  class TextEmbedding(nn.Module):
36
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
@@ -41,22 +41,20 @@ class TextEmbedding(nn.Module):
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
- self.text_blocks = nn.Sequential(
45
- *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
46
- )
47
  else:
48
  self.extra_modeling = False
49
 
50
- def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
51
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
52
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
- batch, text_len = text.shape[0], text.shape[1]
54
- text = F.pad(text, (0, seq_len - text_len), value=0)
55
 
56
  if drop_text: # cfg for text
57
  text = torch.zeros_like(text)
58
 
59
- text = self.text_embed(text) # b n -> b n d
60
 
61
  # possible extra modeling
62
  if self.extra_modeling:
@@ -74,40 +72,28 @@ class TextEmbedding(nn.Module):
74
 
75
  # noised input audio and context mixing embedding
76
 
77
-
78
  class InputEmbedding(nn.Module):
79
  def __init__(self, mel_dim, text_dim, out_dim):
80
  super().__init__()
81
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
82
- self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
83
 
84
- def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
85
  if drop_audio_cond: # cfg for cond audio
86
  cond = torch.zeros_like(cond)
87
 
88
- x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
89
  x = self.conv_pos_embed(x) + x
90
  return x
91
 
92
 
93
  # Flat UNet Transformer backbone
94
 
95
-
96
  class UNetT(nn.Module):
97
- def __init__(
98
- self,
99
- *,
100
- dim,
101
- depth=8,
102
- heads=8,
103
- dim_head=64,
104
- dropout=0.1,
105
- ff_mult=4,
106
- mel_dim=100,
107
- text_num_embeds=256,
108
- text_dim=None,
109
- conv_layers=0,
110
- skip_connect_type: Literal["add", "concat", "none"] = "concat",
111
  ):
112
  super().__init__()
113
  assert depth % 2 == 0, "UNet-Transformer's depth should be even."
@@ -115,7 +101,7 @@ class UNetT(nn.Module):
115
  self.time_embed = TimestepEmbedding(dim)
116
  if text_dim is None:
117
  text_dim = mel_dim
118
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
119
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
 
121
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -124,7 +110,7 @@ class UNetT(nn.Module):
124
 
125
  self.dim = dim
126
  self.skip_connect_type = skip_connect_type
127
- needs_skip_proj = skip_connect_type == "concat"
128
 
129
  self.depth = depth
130
  self.layers = nn.ModuleList([])
@@ -134,57 +120,53 @@ class UNetT(nn.Module):
134
 
135
  attn_norm = RMSNorm(dim)
136
  attn = Attention(
137
- processor=AttnProcessor(),
138
- dim=dim,
139
- heads=heads,
140
- dim_head=dim_head,
141
- dropout=dropout,
142
- )
143
 
144
  ff_norm = RMSNorm(dim)
145
- ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
146
-
147
- skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
148
-
149
- self.layers.append(
150
- nn.ModuleList(
151
- [
152
- skip_proj,
153
- attn_norm,
154
- attn,
155
- ff_norm,
156
- ff,
157
- ]
158
- )
159
- )
160
 
161
  self.norm_out = RMSNorm(dim)
162
  self.proj_out = nn.Linear(dim, mel_dim)
163
 
164
  def forward(
165
  self,
166
- x: float["b n d"], # nosied input audio # noqa: F722
167
- cond: float["b n d"], # masked cond audio # noqa: F722
168
- text: int["b nt"], # text # noqa: F722
169
- time: float["b"] | float[""], # time step # noqa: F821 F722
170
  drop_audio_cond, # cfg for cond audio
171
- drop_text, # cfg for text
172
- mask: bool["b n"] | None = None, # noqa: F722
173
  ):
174
  batch, seq_len = x.shape[0], x.shape[1]
175
  if time.ndim == 0:
176
- time = time.repeat(batch)
177
-
178
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179
  t = self.time_embed(time)
180
- text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
181
- x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182
 
183
  # postfix time t to input x, [b n d] -> [b n+1 d]
184
- x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
185
  if mask is not None:
186
  mask = F.pad(mask, (1, 0), value=1)
187
-
188
  rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
189
 
190
  # flat unet transformer
@@ -202,18 +184,18 @@ class UNetT(nn.Module):
202
 
203
  if is_later_half:
204
  skip = skips.pop()
205
- if skip_connect_type == "concat":
206
- x = torch.cat((x, skip), dim=-1)
207
  x = maybe_skip_proj(x)
208
- elif skip_connect_type == "add":
209
  x = x + skip
210
 
211
  # attention and feedforward blocks
212
- x = attn(attn_norm(x), rope=rope, mask=mask) + x
213
  x = ff(ff_norm(x)) + x
214
 
215
  assert len(skips) == 0
216
 
217
- x = self.norm_out(x)[:, 1:, :] # unpack t from x
218
 
219
  return self.proj_out(x)
 
14
  from torch import nn
15
  import torch.nn.functional as F
16
 
17
+ from einops import repeat, pack, unpack
18
+
19
  from x_transformers import RMSNorm
20
  from x_transformers.x_transformers import RotaryEmbedding
21
 
 
26
  Attention,
27
  AttnProcessor,
28
  FeedForward,
29
+ precompute_freqs_cis, get_pos_embed_indices,
 
30
  )
31
 
32
 
33
  # Text embedding
34
 
 
35
  class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
 
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
 
45
  else:
46
  self.extra_modeling = False
47
 
48
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
49
+ batch, text_len = text.shape[0], text.shape[1]
50
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
51
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
 
53
 
54
  if drop_text: # cfg for text
55
  text = torch.zeros_like(text)
56
 
57
+ text = self.text_embed(text) # b n -> b n d
58
 
59
  # possible extra modeling
60
  if self.extra_modeling:
 
72
 
73
  # noised input audio and context mixing embedding
74
 
 
75
  class InputEmbedding(nn.Module):
76
  def __init__(self, mel_dim, text_dim, out_dim):
77
  super().__init__()
78
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
80
 
81
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
82
  if drop_audio_cond: # cfg for cond audio
83
  cond = torch.zeros_like(cond)
84
 
85
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
86
  x = self.conv_pos_embed(x) + x
87
  return x
88
 
89
 
90
  # Flat UNet Transformer backbone
91
 
 
92
  class UNetT(nn.Module):
93
+ def __init__(self, *,
94
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
95
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
96
+ skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
 
 
 
 
 
 
 
 
 
 
97
  ):
98
  super().__init__()
99
  assert depth % 2 == 0, "UNet-Transformer's depth should be even."
 
101
  self.time_embed = TimestepEmbedding(dim)
102
  if text_dim is None:
103
  text_dim = mel_dim
104
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
105
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
106
 
107
  self.rotary_embed = RotaryEmbedding(dim_head)
 
110
 
111
  self.dim = dim
112
  self.skip_connect_type = skip_connect_type
113
+ needs_skip_proj = skip_connect_type == 'concat'
114
 
115
  self.depth = depth
116
  self.layers = nn.ModuleList([])
 
120
 
121
  attn_norm = RMSNorm(dim)
122
  attn = Attention(
123
+ processor = AttnProcessor(),
124
+ dim = dim,
125
+ heads = heads,
126
+ dim_head = dim_head,
127
+ dropout = dropout,
128
+ )
129
 
130
  ff_norm = RMSNorm(dim)
131
+ ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
132
+
133
+ skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
134
+
135
+ self.layers.append(nn.ModuleList([
136
+ skip_proj,
137
+ attn_norm,
138
+ attn,
139
+ ff_norm,
140
+ ff,
141
+ ]))
 
 
 
 
142
 
143
  self.norm_out = RMSNorm(dim)
144
  self.proj_out = nn.Linear(dim, mel_dim)
145
 
146
  def forward(
147
  self,
148
+ x: float['b n d'], # nosied input audio
149
+ cond: float['b n d'], # masked cond audio
150
+ text: int['b nt'], # text
151
+ time: float['b'] | float[''], # time step
152
  drop_audio_cond, # cfg for cond audio
153
+ drop_text, # cfg for text
154
+ mask: bool['b n'] | None = None,
155
  ):
156
  batch, seq_len = x.shape[0], x.shape[1]
157
  if time.ndim == 0:
158
+ time = repeat(time, ' -> b', b = batch)
159
+
160
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
161
  t = self.time_embed(time)
162
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
163
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
164
 
165
  # postfix time t to input x, [b n d] -> [b n+1 d]
166
+ x, ps = pack((t, x), 'b * d')
167
  if mask is not None:
168
  mask = F.pad(mask, (1, 0), value=1)
169
+
170
  rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
171
 
172
  # flat unet transformer
 
184
 
185
  if is_later_half:
186
  skip = skips.pop()
187
+ if skip_connect_type == 'concat':
188
+ x = torch.cat((x, skip), dim = -1)
189
  x = maybe_skip_proj(x)
190
+ elif skip_connect_type == 'add':
191
  x = x + skip
192
 
193
  # attention and feedforward blocks
194
+ x = attn(attn_norm(x), rope = rope, mask = mask) + x
195
  x = ff(ff_norm(x)) + x
196
 
197
  assert len(skips) == 0
198
 
199
+ _, x = unpack(self.norm_out(x), ps, 'b * d')
200
 
201
  return self.proj_out(x)
model/cfm.py CHANGED
@@ -18,34 +18,34 @@ from torch.nn.utils.rnn import pad_sequence
18
 
19
  from torchdiffeq import odeint
20
 
 
 
21
  from model.modules import MelSpec
 
22
  from model.utils import (
23
- default,
24
- exists,
25
- list_str_to_idx,
26
- list_str_to_tensor,
27
- lens_to_mask,
28
- mask_from_frac_lengths,
29
- )
30
 
31
 
32
  class CFM(nn.Module):
33
  def __init__(
34
  self,
35
  transformer: nn.Module,
36
- sigma=0.0,
37
  odeint_kwargs: dict = dict(
38
  # atol = 1e-5,
39
  # rtol = 1e-5,
40
- method="euler" # 'midpoint'
41
  ),
42
- audio_drop_prob=0.3,
43
- cond_drop_prob=0.2,
44
- num_channels=None,
45
  mel_spec_module: nn.Module | None = None,
46
  mel_spec_kwargs: dict = dict(),
47
- frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
48
- vocab_char_map: dict[str:int] | None = None,
49
  ):
50
  super().__init__()
51
 
@@ -81,37 +81,34 @@ class CFM(nn.Module):
81
  @torch.no_grad()
82
  def sample(
83
  self,
84
- cond: float["b n d"] | float["b nw"], # noqa: F722
85
- text: int["b nt"] | list[str], # noqa: F722
86
- duration: int | int["b"], # noqa: F821
87
  *,
88
- lens: int["b"] | None = None, # noqa: F821
89
- steps=32,
90
- cfg_strength=1.0,
91
- sway_sampling_coef=None,
92
  seed: int | None = None,
93
- max_duration=4096,
94
- vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
95
- no_ref_audio=False,
96
- duplicate_test=False,
97
- t_inter=0.1,
98
- edit_mask=None,
99
  ):
100
  self.eval()
101
 
102
- if next(self.parameters()).dtype == torch.float16:
103
- cond = cond.half()
104
-
105
  # raw wave
106
 
107
  if cond.ndim == 2:
108
  cond = self.mel_spec(cond)
109
- cond = cond.permute(0, 2, 1)
110
  assert cond.shape[-1] == self.num_channels
111
 
112
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
113
  if not exists(lens):
114
- lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
115
 
116
  # text
117
 
@@ -123,8 +120,8 @@ class CFM(nn.Module):
123
  assert text.shape[0] == batch
124
 
125
  if exists(text):
126
- text_lens = (text != -1).sum(dim=-1)
127
- lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
128
 
129
  # duration
130
 
@@ -133,22 +130,20 @@ class CFM(nn.Module):
133
  cond_mask = cond_mask & edit_mask
134
 
135
  if isinstance(duration, int):
136
- duration = torch.full((batch,), duration, device=device, dtype=torch.long)
137
 
138
- duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
139
- duration = duration.clamp(max=max_duration)
140
  max_duration = duration.amax()
141
-
142
  # duplicate test corner for inner time step oberservation
143
  if duplicate_test:
144
- test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
145
-
146
- cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
147
- cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
148
- cond_mask = cond_mask.unsqueeze(-1)
149
- step_cond = torch.where(
150
- cond_mask, cond, torch.zeros_like(cond)
151
- ) # allow direct control (cut cond audio) with lens passed in
152
 
153
  if batch > 1:
154
  mask = lens_to_mask(duration)
@@ -166,15 +161,11 @@ class CFM(nn.Module):
166
  # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
167
 
168
  # predict flow
169
- pred = self.transformer(
170
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
171
- )
172
  if cfg_strength < 1e-5:
173
  return pred
174
-
175
- null_pred = self.transformer(
176
- x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
177
- )
178
  return pred + (pred - null_pred) * cfg_strength
179
 
180
  # noise input
@@ -184,8 +175,8 @@ class CFM(nn.Module):
184
  for dur in duration:
185
  if exists(seed):
186
  torch.manual_seed(seed)
187
- y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
188
- y0 = pad_sequence(y0, padding_value=0, batch_first=True)
189
 
190
  t_start = 0
191
 
@@ -195,37 +186,37 @@ class CFM(nn.Module):
195
  y0 = (1 - t_start) * y0 + t_start * test_cond
196
  steps = int(steps * (1 - t_start))
197
 
198
- t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
199
  if sway_sampling_coef is not None:
200
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
201
 
202
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
203
-
204
  sampled = trajectory[-1]
205
  out = sampled
206
  out = torch.where(cond_mask, cond, out)
207
 
208
  if exists(vocoder):
209
- out = out.permute(0, 2, 1)
210
  out = vocoder(out)
211
 
212
  return out, trajectory
213
 
214
  def forward(
215
  self,
216
- inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
217
- text: int["b nt"] | list[str], # noqa: F722
218
  *,
219
- lens: int["b"] | None = None, # noqa: F821
220
  noise_scheduler: str | None = None,
221
  ):
222
  # handle raw wave
223
  if inp.ndim == 2:
224
  inp = self.mel_spec(inp)
225
- inp = inp.permute(0, 2, 1)
226
  assert inp.shape[-1] == self.num_channels
227
 
228
- batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
229
 
230
  # handle text as string
231
  if isinstance(text, list):
@@ -237,12 +228,12 @@ class CFM(nn.Module):
237
 
238
  # lens and mask
239
  if not exists(lens):
240
- lens = torch.full((batch,), seq_len, device=device)
241
-
242
- mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
243
 
244
  # get a random span to mask out for training conditionally
245
- frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
246
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
247
 
248
  if exists(mask):
@@ -255,16 +246,19 @@ class CFM(nn.Module):
255
  x0 = torch.randn_like(x1)
256
 
257
  # time step
258
- time = torch.rand((batch,), dtype=dtype, device=self.device)
259
  # TODO. noise_scheduler
260
 
261
  # sample xt (φ_t(x) in the paper)
262
- t = time.unsqueeze(-1).unsqueeze(-1)
263
  φ = (1 - t) * x0 + t * x1
264
  flow = x1 - x0
265
 
266
  # only predict what is within the random mask span for infilling
267
- cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
 
 
 
268
 
269
  # transformer and cfg training with a drop rate
270
  drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
@@ -273,15 +267,13 @@ class CFM(nn.Module):
273
  drop_text = True
274
  else:
275
  drop_text = False
276
-
277
  # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
278
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
279
- pred = self.transformer(
280
- x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
281
- )
282
 
283
  # flow matching loss
284
- loss = F.mse_loss(pred, flow, reduction="none")
285
  loss = loss[rand_span_mask]
286
 
287
  return loss.mean(), cond, pred
 
18
 
19
  from torchdiffeq import odeint
20
 
21
+ from einops import rearrange
22
+
23
  from model.modules import MelSpec
24
+
25
  from model.utils import (
26
+ default, exists,
27
+ list_str_to_idx, list_str_to_tensor,
28
+ lens_to_mask, mask_from_frac_lengths,
29
+ )
 
 
 
30
 
31
 
32
  class CFM(nn.Module):
33
  def __init__(
34
  self,
35
  transformer: nn.Module,
36
+ sigma = 0.,
37
  odeint_kwargs: dict = dict(
38
  # atol = 1e-5,
39
  # rtol = 1e-5,
40
+ method = 'euler' # 'midpoint'
41
  ),
42
+ audio_drop_prob = 0.3,
43
+ cond_drop_prob = 0.2,
44
+ num_channels = None,
45
  mel_spec_module: nn.Module | None = None,
46
  mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.),
48
+ vocab_char_map: dict[str: int] | None = None
49
  ):
50
  super().__init__()
51
 
 
81
  @torch.no_grad()
82
  def sample(
83
  self,
84
+ cond: float['b n d'] | float['b nw'],
85
+ text: int['b nt'] | list[str],
86
+ duration: int | int['b'],
87
  *,
88
+ lens: int['b'] | None = None,
89
+ steps = 32,
90
+ cfg_strength = 1.,
91
+ sway_sampling_coef = None,
92
  seed: int | None = None,
93
+ max_duration = 4096,
94
+ vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
95
+ no_ref_audio = False,
96
+ duplicate_test = False,
97
+ t_inter = 0.1,
98
+ edit_mask = None,
99
  ):
100
  self.eval()
101
 
 
 
 
102
  # raw wave
103
 
104
  if cond.ndim == 2:
105
  cond = self.mel_spec(cond)
106
+ cond = rearrange(cond, 'b d n -> b n d')
107
  assert cond.shape[-1] == self.num_channels
108
 
109
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
110
  if not exists(lens):
111
+ lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
112
 
113
  # text
114
 
 
120
  assert text.shape[0] == batch
121
 
122
  if exists(text):
123
+ text_lens = (text != -1).sum(dim = -1)
124
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
125
 
126
  # duration
127
 
 
130
  cond_mask = cond_mask & edit_mask
131
 
132
  if isinstance(duration, int):
133
+ duration = torch.full((batch,), duration, device = device, dtype = torch.long)
134
 
135
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
136
+ duration = duration.clamp(max = max_duration)
137
  max_duration = duration.amax()
138
+
139
  # duplicate test corner for inner time step oberservation
140
  if duplicate_test:
141
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
142
+
143
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
144
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
145
+ cond_mask = rearrange(cond_mask, '... -> ... 1')
146
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
 
 
147
 
148
  if batch > 1:
149
  mask = lens_to_mask(duration)
 
161
  # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
162
 
163
  # predict flow
164
+ pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
 
 
165
  if cfg_strength < 1e-5:
166
  return pred
167
+
168
+ null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
 
 
169
  return pred + (pred - null_pred) * cfg_strength
170
 
171
  # noise input
 
175
  for dur in duration:
176
  if exists(seed):
177
  torch.manual_seed(seed)
178
+ y0.append(torch.randn(dur, self.num_channels, device = self.device))
179
+ y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
180
 
181
  t_start = 0
182
 
 
186
  y0 = (1 - t_start) * y0 + t_start * test_cond
187
  steps = int(steps * (1 - t_start))
188
 
189
+ t = torch.linspace(t_start, 1, steps, device = self.device)
190
  if sway_sampling_coef is not None:
191
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
192
 
193
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
194
+
195
  sampled = trajectory[-1]
196
  out = sampled
197
  out = torch.where(cond_mask, cond, out)
198
 
199
  if exists(vocoder):
200
+ out = rearrange(out, 'b n d -> b d n')
201
  out = vocoder(out)
202
 
203
  return out, trajectory
204
 
205
  def forward(
206
  self,
207
+ inp: float['b n d'] | float['b nw'], # mel or raw wave
208
+ text: int['b nt'] | list[str],
209
  *,
210
+ lens: int['b'] | None = None,
211
  noise_scheduler: str | None = None,
212
  ):
213
  # handle raw wave
214
  if inp.ndim == 2:
215
  inp = self.mel_spec(inp)
216
+ inp = rearrange(inp, 'b d n -> b n d')
217
  assert inp.shape[-1] == self.num_channels
218
 
219
+ batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
220
 
221
  # handle text as string
222
  if isinstance(text, list):
 
228
 
229
  # lens and mask
230
  if not exists(lens):
231
+ lens = torch.full((batch,), seq_len, device = device)
232
+
233
+ mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
234
 
235
  # get a random span to mask out for training conditionally
236
+ frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
237
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
238
 
239
  if exists(mask):
 
246
  x0 = torch.randn_like(x1)
247
 
248
  # time step
249
+ time = torch.rand((batch,), dtype = dtype, device = self.device)
250
  # TODO. noise_scheduler
251
 
252
  # sample xt (φ_t(x) in the paper)
253
+ t = rearrange(time, 'b -> b 1 1')
254
  φ = (1 - t) * x0 + t * x1
255
  flow = x1 - x0
256
 
257
  # only predict what is within the random mask span for infilling
258
+ cond = torch.where(
259
+ rand_span_mask[..., None],
260
+ torch.zeros_like(x1), x1
261
+ )
262
 
263
  # transformer and cfg training with a drop rate
264
  drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
 
267
  drop_text = True
268
  else:
269
  drop_text = False
270
+
271
  # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
272
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
273
+ pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
 
 
274
 
275
  # flow matching loss
276
+ loss = F.mse_loss(pred, flow, reduction = 'none')
277
  loss = loss[rand_span_mask]
278
 
279
  return loss.mean(), cond, pred
model/dataset.py CHANGED
@@ -6,67 +6,65 @@ import torch
6
  import torch.nn.functional as F
7
  from torch.utils.data import Dataset, Sampler
8
  import torchaudio
9
- from datasets import load_from_disk
10
  from datasets import Dataset as Dataset_
11
- from torch import nn
 
12
 
13
  from model.modules import MelSpec
14
- from model.utils import default
15
 
16
 
17
  class HFDataset(Dataset):
18
  def __init__(
19
  self,
20
  hf_dataset: Dataset,
21
- target_sample_rate=24_000,
22
- n_mel_channels=100,
23
- hop_length=256,
24
  ):
25
  self.data = hf_dataset
26
  self.target_sample_rate = target_sample_rate
27
  self.hop_length = hop_length
28
- self.mel_spectrogram = MelSpec(
29
- target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
30
- )
31
-
32
  def get_frame_len(self, index):
33
  row = self.data[index]
34
- audio = row["audio"]["array"]
35
- sample_rate = row["audio"]["sampling_rate"]
36
  return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
37
 
38
  def __len__(self):
39
  return len(self.data)
40
-
41
  def __getitem__(self, index):
42
  row = self.data[index]
43
- audio = row["audio"]["array"]
44
 
45
  # logger.info(f"Audio shape: {audio.shape}")
46
 
47
- sample_rate = row["audio"]["sampling_rate"]
48
  duration = audio.shape[-1] / sample_rate
49
 
50
  if duration > 30 or duration < 0.3:
51
  return self.__getitem__((index + 1) % len(self.data))
52
-
53
  audio_tensor = torch.from_numpy(audio).float()
54
-
55
  if sample_rate != self.target_sample_rate:
56
  resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
57
  audio_tensor = resampler(audio_tensor)
58
-
59
- audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
60
-
61
  mel_spec = self.mel_spectrogram(audio_tensor)
62
-
63
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
64
-
65
- text = row["text"]
66
-
67
  return dict(
68
- mel_spec=mel_spec,
69
- text=text,
70
  )
71
 
72
 
@@ -74,39 +72,28 @@ class CustomDataset(Dataset):
74
  def __init__(
75
  self,
76
  custom_dataset: Dataset,
77
- durations=None,
78
- target_sample_rate=24_000,
79
- hop_length=256,
80
- n_mel_channels=100,
81
- preprocessed_mel=False,
82
- mel_spec_module: nn.Module | None = None,
83
  ):
84
  self.data = custom_dataset
85
  self.durations = durations
86
  self.target_sample_rate = target_sample_rate
87
  self.hop_length = hop_length
88
  self.preprocessed_mel = preprocessed_mel
89
-
90
  if not preprocessed_mel:
91
- self.mel_spectrogram = default(
92
- mel_spec_module,
93
- MelSpec(
94
- target_sample_rate=target_sample_rate,
95
- hop_length=hop_length,
96
- n_mel_channels=n_mel_channels,
97
- ),
98
- )
99
 
100
  def get_frame_len(self, index):
101
- if (
102
- self.durations is not None
103
- ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
104
  return self.durations[index] * self.target_sample_rate / self.hop_length
105
  return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
106
-
107
  def __len__(self):
108
  return len(self.data)
109
-
110
  def __getitem__(self, index):
111
  row = self.data[index]
112
  audio_path = row["audio_path"]
@@ -118,57 +105,48 @@ class CustomDataset(Dataset):
118
 
119
  else:
120
  audio, source_sample_rate = torchaudio.load(audio_path)
121
- if audio.shape[0] > 1:
122
- audio = torch.mean(audio, dim=0, keepdim=True)
123
 
124
  if duration > 30 or duration < 0.3:
125
  return self.__getitem__((index + 1) % len(self.data))
126
-
127
  if source_sample_rate != self.target_sample_rate:
128
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
129
  audio = resampler(audio)
130
-
131
  mel_spec = self.mel_spectrogram(audio)
132
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
133
-
134
  return dict(
135
- mel_spec=mel_spec,
136
- text=text,
137
  )
138
-
139
 
140
  # Dynamic Batch Sampler
141
 
142
-
143
  class DynamicBatchSampler(Sampler[list[int]]):
144
- """Extension of Sampler that will do the following:
145
- 1. Change the batch size (essentially number of sequences)
146
- in a batch to ensure that the total number of frames are less
147
- than a certain threshold.
148
- 2. Make sure the padding efficiency in the batch is high.
149
  """
150
 
151
- def __init__(
152
- self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
153
- ):
154
  self.sampler = sampler
155
  self.frames_threshold = frames_threshold
156
  self.max_samples = max_samples
157
 
158
  indices, batches = [], []
159
  data_source = self.sampler.data_source
160
-
161
- for idx in tqdm(
162
- self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
163
- ):
164
  indices.append((idx, data_source.get_frame_len(idx)))
165
- indices.sort(key=lambda elem: elem[1])
166
 
167
  batch = []
168
  batch_frames = 0
169
- for idx, frame_len in tqdm(
170
- indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
171
- ):
172
  if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
173
  batch.append(idx)
174
  batch_frames += frame_len
@@ -204,91 +182,76 @@ class DynamicBatchSampler(Sampler[list[int]]):
204
 
205
  # Load dataset
206
 
207
-
208
  def load_dataset(
209
- dataset_name: str,
210
- tokenizer: str = "pinyin",
211
- dataset_type: str = "CustomDataset",
212
- audio_type: str = "raw",
213
- mel_spec_module: nn.Module | None = None,
214
- mel_spec_kwargs: dict = dict(),
215
- ) -> CustomDataset | HFDataset:
216
- """
217
  dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
218
  - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
219
- """
220
-
221
  print("Loading dataset ...")
222
 
223
  if dataset_type == "CustomDataset":
224
  if audio_type == "raw":
225
  try:
226
  train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
227
- except: # noqa: E722
228
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
229
  preprocessed_mel = False
230
  elif audio_type == "mel":
231
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
232
  preprocessed_mel = True
233
- with open(f"data/{dataset_name}_{tokenizer}/duration.json", "r", encoding="utf-8") as f:
234
  data_dict = json.load(f)
235
  durations = data_dict["duration"]
236
- train_dataset = CustomDataset(
237
- train_dataset,
238
- durations=durations,
239
- preprocessed_mel=preprocessed_mel,
240
- mel_spec_module=mel_spec_module,
241
- **mel_spec_kwargs,
242
- )
243
-
244
  elif dataset_type == "CustomDatasetPath":
245
  try:
246
  train_dataset = load_from_disk(f"{dataset_name}/raw")
247
- except: # noqa: E722
248
  train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
249
-
250
- with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
251
  data_dict = json.load(f)
252
  durations = data_dict["duration"]
253
- train_dataset = CustomDataset(
254
- train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
255
- )
256
-
257
  elif dataset_type == "HFDataset":
258
- print(
259
- "Should manually modify the path of huggingface dataset to your need.\n"
260
- + "May also the corresponding script cuz different dataset may have different format."
261
- )
262
  pre, post = dataset_name.split("_")
263
- train_dataset = HFDataset(
264
- load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),
265
- )
266
 
267
  return train_dataset
268
 
269
 
270
  # collation
271
 
272
-
273
  def collate_fn(batch):
274
- mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
275
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
276
  max_mel_length = mel_lengths.amax()
277
 
278
  padded_mel_specs = []
279
  for spec in mel_specs: # TODO. maybe records mask for attention here
280
  padding = (0, max_mel_length - spec.size(-1))
281
- padded_spec = F.pad(spec, padding, value=0)
282
  padded_mel_specs.append(padded_spec)
283
-
284
  mel_specs = torch.stack(padded_mel_specs)
285
 
286
- text = [item["text"] for item in batch]
287
  text_lengths = torch.LongTensor([len(item) for item in text])
288
 
289
  return dict(
290
- mel=mel_specs,
291
- mel_lengths=mel_lengths,
292
- text=text,
293
- text_lengths=text_lengths,
294
  )
 
6
  import torch.nn.functional as F
7
  from torch.utils.data import Dataset, Sampler
8
  import torchaudio
9
+ from datasets import load_dataset, load_from_disk
10
  from datasets import Dataset as Dataset_
11
+
12
+ from einops import rearrange
13
 
14
  from model.modules import MelSpec
 
15
 
16
 
17
  class HFDataset(Dataset):
18
  def __init__(
19
  self,
20
  hf_dataset: Dataset,
21
+ target_sample_rate = 24_000,
22
+ n_mel_channels = 100,
23
+ hop_length = 256,
24
  ):
25
  self.data = hf_dataset
26
  self.target_sample_rate = target_sample_rate
27
  self.hop_length = hop_length
28
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
29
+
 
 
30
  def get_frame_len(self, index):
31
  row = self.data[index]
32
+ audio = row['audio']['array']
33
+ sample_rate = row['audio']['sampling_rate']
34
  return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
35
 
36
  def __len__(self):
37
  return len(self.data)
38
+
39
  def __getitem__(self, index):
40
  row = self.data[index]
41
+ audio = row['audio']['array']
42
 
43
  # logger.info(f"Audio shape: {audio.shape}")
44
 
45
+ sample_rate = row['audio']['sampling_rate']
46
  duration = audio.shape[-1] / sample_rate
47
 
48
  if duration > 30 or duration < 0.3:
49
  return self.__getitem__((index + 1) % len(self.data))
50
+
51
  audio_tensor = torch.from_numpy(audio).float()
52
+
53
  if sample_rate != self.target_sample_rate:
54
  resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
55
  audio_tensor = resampler(audio_tensor)
56
+
57
+ audio_tensor = rearrange(audio_tensor, 't -> 1 t')
58
+
59
  mel_spec = self.mel_spectrogram(audio_tensor)
60
+
61
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
62
+
63
+ text = row['text']
64
+
65
  return dict(
66
+ mel_spec = mel_spec,
67
+ text = text,
68
  )
69
 
70
 
 
72
  def __init__(
73
  self,
74
  custom_dataset: Dataset,
75
+ durations = None,
76
+ target_sample_rate = 24_000,
77
+ hop_length = 256,
78
+ n_mel_channels = 100,
79
+ preprocessed_mel = False,
 
80
  ):
81
  self.data = custom_dataset
82
  self.durations = durations
83
  self.target_sample_rate = target_sample_rate
84
  self.hop_length = hop_length
85
  self.preprocessed_mel = preprocessed_mel
 
86
  if not preprocessed_mel:
87
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
 
 
 
 
 
 
 
88
 
89
  def get_frame_len(self, index):
90
+ if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
 
 
91
  return self.durations[index] * self.target_sample_rate / self.hop_length
92
  return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
93
+
94
  def __len__(self):
95
  return len(self.data)
96
+
97
  def __getitem__(self, index):
98
  row = self.data[index]
99
  audio_path = row["audio_path"]
 
105
 
106
  else:
107
  audio, source_sample_rate = torchaudio.load(audio_path)
 
 
108
 
109
  if duration > 30 or duration < 0.3:
110
  return self.__getitem__((index + 1) % len(self.data))
111
+
112
  if source_sample_rate != self.target_sample_rate:
113
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
114
  audio = resampler(audio)
115
+
116
  mel_spec = self.mel_spectrogram(audio)
117
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
118
+
119
  return dict(
120
+ mel_spec = mel_spec,
121
+ text = text,
122
  )
123
+
124
 
125
  # Dynamic Batch Sampler
126
 
 
127
  class DynamicBatchSampler(Sampler[list[int]]):
128
+ """ Extension of Sampler that will do the following:
129
+ 1. Change the batch size (essentially number of sequences)
130
+ in a batch to ensure that the total number of frames are less
131
+ than a certain threshold.
132
+ 2. Make sure the padding efficiency in the batch is high.
133
  """
134
 
135
+ def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
 
 
136
  self.sampler = sampler
137
  self.frames_threshold = frames_threshold
138
  self.max_samples = max_samples
139
 
140
  indices, batches = [], []
141
  data_source = self.sampler.data_source
142
+
143
+ for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
 
 
144
  indices.append((idx, data_source.get_frame_len(idx)))
145
+ indices.sort(key=lambda elem : elem[1])
146
 
147
  batch = []
148
  batch_frames = 0
149
+ for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
 
 
150
  if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
151
  batch.append(idx)
152
  batch_frames += frame_len
 
182
 
183
  # Load dataset
184
 
 
185
  def load_dataset(
186
+ dataset_name: str,
187
+ tokenizer: str = "pinyin",
188
+ dataset_type: str = "CustomDataset",
189
+ audio_type: str = "raw",
190
+ mel_spec_kwargs: dict = dict()
191
+ ) -> CustomDataset | HFDataset:
192
+ '''
 
193
  dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
194
  - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
195
+ '''
196
+
197
  print("Loading dataset ...")
198
 
199
  if dataset_type == "CustomDataset":
200
  if audio_type == "raw":
201
  try:
202
  train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
203
+ except:
204
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
205
  preprocessed_mel = False
206
  elif audio_type == "mel":
207
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
208
  preprocessed_mel = True
209
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
210
  data_dict = json.load(f)
211
  durations = data_dict["duration"]
212
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
213
+
 
 
 
 
 
 
214
  elif dataset_type == "CustomDatasetPath":
215
  try:
216
  train_dataset = load_from_disk(f"{dataset_name}/raw")
217
+ except:
218
  train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
219
+
220
+ with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
221
  data_dict = json.load(f)
222
  durations = data_dict["duration"]
223
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
224
+
 
 
225
  elif dataset_type == "HFDataset":
226
+ print("Should manually modify the path of huggingface dataset to your need.\n" +
227
+ "May also the corresponding script cuz different dataset may have different format.")
 
 
228
  pre, post = dataset_name.split("_")
229
+ train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
 
 
230
 
231
  return train_dataset
232
 
233
 
234
  # collation
235
 
 
236
  def collate_fn(batch):
237
+ mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
238
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
239
  max_mel_length = mel_lengths.amax()
240
 
241
  padded_mel_specs = []
242
  for spec in mel_specs: # TODO. maybe records mask for attention here
243
  padding = (0, max_mel_length - spec.size(-1))
244
+ padded_spec = F.pad(spec, padding, value = 0)
245
  padded_mel_specs.append(padded_spec)
246
+
247
  mel_specs = torch.stack(padded_mel_specs)
248
 
249
+ text = [item['text'] for item in batch]
250
  text_lengths = torch.LongTensor([len(item) for item in text])
251
 
252
  return dict(
253
+ mel = mel_specs,
254
+ mel_lengths = mel_lengths,
255
+ text = text,
256
+ text_lengths = text_lengths,
257
  )
model/ecapa_tdnn.py CHANGED
@@ -9,14 +9,13 @@ import torch.nn as nn
9
  import torch.nn.functional as F
10
 
11
 
12
- """ Res2Conv1d + BatchNorm1d + ReLU
13
- """
14
-
15
 
16
  class Res2Conv1dReluBn(nn.Module):
17
- """
18
  in_channels == out_channels == channels
19
- """
20
 
21
  def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
22
  super().__init__()
@@ -52,9 +51,8 @@ class Res2Conv1dReluBn(nn.Module):
52
  return out
53
 
54
 
55
- """ Conv1d + BatchNorm1d + ReLU
56
- """
57
-
58
 
59
  class Conv1dReluBn(nn.Module):
60
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
@@ -66,9 +64,8 @@ class Conv1dReluBn(nn.Module):
66
  return self.bn(F.relu(self.conv(x)))
67
 
68
 
69
- """ The SE connection of 1D case.
70
- """
71
-
72
 
73
  class SE_Connect(nn.Module):
74
  def __init__(self, channels, se_bottleneck_dim=128):
@@ -85,8 +82,8 @@ class SE_Connect(nn.Module):
85
  return out
86
 
87
 
88
- """ SE-Res2Block of the ECAPA-TDNN architecture.
89
- """
90
 
91
  # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
92
  # return nn.Sequential(
@@ -96,7 +93,6 @@ class SE_Connect(nn.Module):
96
  # SE_Connect(channels)
97
  # )
98
 
99
-
100
  class SE_Res2Block(nn.Module):
101
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
102
  super().__init__()
@@ -126,9 +122,8 @@ class SE_Res2Block(nn.Module):
126
  return x + residual
127
 
128
 
129
- """ Attentive weighted mean and standard deviation pooling.
130
- """
131
-
132
 
133
  class AttentiveStatsPool(nn.Module):
134
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
@@ -143,6 +138,7 @@ class AttentiveStatsPool(nn.Module):
143
  self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
144
 
145
  def forward(self, x):
 
146
  if self.global_context_att:
147
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
148
  context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
@@ -155,52 +151,38 @@ class AttentiveStatsPool(nn.Module):
155
  # alpha = F.relu(self.linear1(x_in))
156
  alpha = torch.softmax(self.linear2(alpha), dim=2)
157
  mean = torch.sum(alpha * x, dim=2)
158
- residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
159
  std = torch.sqrt(residuals.clamp(min=1e-9))
160
  return torch.cat([mean, std], dim=1)
161
 
162
 
163
  class ECAPA_TDNN(nn.Module):
164
- def __init__(
165
- self,
166
- feat_dim=80,
167
- channels=512,
168
- emb_dim=192,
169
- global_context_att=False,
170
- feat_type="wavlm_large",
171
- sr=16000,
172
- feature_selection="hidden_states",
173
- update_extract=False,
174
- config_path=None,
175
- ):
176
  super().__init__()
177
 
178
  self.feat_type = feat_type
179
  self.feature_selection = feature_selection
180
  self.update_extract = update_extract
181
  self.sr = sr
182
-
183
- torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
184
  try:
185
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
186
- self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
187
- except: # noqa: E722
188
- self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
189
 
190
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
191
- self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
192
- ):
193
  self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
194
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
195
- self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
196
- ):
197
  self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
198
 
199
  self.feat_num = self.get_feat_num()
200
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
201
 
202
- if feat_type != "fbank" and feat_type != "mfcc":
203
- freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
204
  for name, param in self.feature_extract.named_parameters():
205
  for freeze_val in freeze_list:
206
  if freeze_val in name:
@@ -216,46 +198,18 @@ class ECAPA_TDNN(nn.Module):
216
  self.channels = [channels] * 4 + [1536]
217
 
218
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
219
- self.layer2 = SE_Res2Block(
220
- self.channels[0],
221
- self.channels[1],
222
- kernel_size=3,
223
- stride=1,
224
- padding=2,
225
- dilation=2,
226
- scale=8,
227
- se_bottleneck_dim=128,
228
- )
229
- self.layer3 = SE_Res2Block(
230
- self.channels[1],
231
- self.channels[2],
232
- kernel_size=3,
233
- stride=1,
234
- padding=3,
235
- dilation=3,
236
- scale=8,
237
- se_bottleneck_dim=128,
238
- )
239
- self.layer4 = SE_Res2Block(
240
- self.channels[2],
241
- self.channels[3],
242
- kernel_size=3,
243
- stride=1,
244
- padding=4,
245
- dilation=4,
246
- scale=8,
247
- se_bottleneck_dim=128,
248
- )
249
 
250
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
251
  cat_channels = channels * 3
252
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
253
- self.pooling = AttentiveStatsPool(
254
- self.channels[-1], attention_channels=128, global_context_att=global_context_att
255
- )
256
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
257
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
258
 
 
259
  def get_feat_num(self):
260
  self.feature_extract.eval()
261
  wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
@@ -272,12 +226,12 @@ class ECAPA_TDNN(nn.Module):
272
  x = self.feature_extract([sample for sample in x])
273
  else:
274
  with torch.no_grad():
275
- if self.feat_type == "fbank" or self.feat_type == "mfcc":
276
  x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
277
  else:
278
  x = self.feature_extract([sample for sample in x])
279
 
280
- if self.feat_type == "fbank":
281
  x = x.log()
282
 
283
  if self.feat_type != "fbank" and self.feat_type != "mfcc":
@@ -309,22 +263,6 @@ class ECAPA_TDNN(nn.Module):
309
  return out
310
 
311
 
312
- def ECAPA_TDNN_SMALL(
313
- feat_dim,
314
- emb_dim=256,
315
- feat_type="wavlm_large",
316
- sr=16000,
317
- feature_selection="hidden_states",
318
- update_extract=False,
319
- config_path=None,
320
- ):
321
- return ECAPA_TDNN(
322
- feat_dim=feat_dim,
323
- channels=512,
324
- emb_dim=emb_dim,
325
- feat_type=feat_type,
326
- sr=sr,
327
- feature_selection=feature_selection,
328
- update_extract=update_extract,
329
- config_path=config_path,
330
- )
 
9
  import torch.nn.functional as F
10
 
11
 
12
+ ''' Res2Conv1d + BatchNorm1d + ReLU
13
+ '''
 
14
 
15
  class Res2Conv1dReluBn(nn.Module):
16
+ '''
17
  in_channels == out_channels == channels
18
+ '''
19
 
20
  def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
21
  super().__init__()
 
51
  return out
52
 
53
 
54
+ ''' Conv1d + BatchNorm1d + ReLU
55
+ '''
 
56
 
57
  class Conv1dReluBn(nn.Module):
58
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
 
64
  return self.bn(F.relu(self.conv(x)))
65
 
66
 
67
+ ''' The SE connection of 1D case.
68
+ '''
 
69
 
70
  class SE_Connect(nn.Module):
71
  def __init__(self, channels, se_bottleneck_dim=128):
 
82
  return out
83
 
84
 
85
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
86
+ '''
87
 
88
  # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
89
  # return nn.Sequential(
 
93
  # SE_Connect(channels)
94
  # )
95
 
 
96
  class SE_Res2Block(nn.Module):
97
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
98
  super().__init__()
 
122
  return x + residual
123
 
124
 
125
+ ''' Attentive weighted mean and standard deviation pooling.
126
+ '''
 
127
 
128
  class AttentiveStatsPool(nn.Module):
129
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
 
138
  self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
139
 
140
  def forward(self, x):
141
+
142
  if self.global_context_att:
143
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
144
  context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
 
151
  # alpha = F.relu(self.linear1(x_in))
152
  alpha = torch.softmax(self.linear2(alpha), dim=2)
153
  mean = torch.sum(alpha * x, dim=2)
154
+ residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
155
  std = torch.sqrt(residuals.clamp(min=1e-9))
156
  return torch.cat([mean, std], dim=1)
157
 
158
 
159
  class ECAPA_TDNN(nn.Module):
160
+ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
161
+ feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
 
 
 
 
 
 
 
 
 
 
162
  super().__init__()
163
 
164
  self.feat_type = feat_type
165
  self.feature_selection = feature_selection
166
  self.update_extract = update_extract
167
  self.sr = sr
168
+
169
+ torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
170
  try:
171
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
172
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
173
+ except:
174
+ self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
175
 
176
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
 
 
177
  self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
178
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
 
 
179
  self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
180
 
181
  self.feat_num = self.get_feat_num()
182
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
183
 
184
+ if feat_type != 'fbank' and feat_type != 'mfcc':
185
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
186
  for name, param in self.feature_extract.named_parameters():
187
  for freeze_val in freeze_list:
188
  if freeze_val in name:
 
198
  self.channels = [channels] * 4 + [1536]
199
 
200
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
201
+ self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
202
+ self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
203
+ self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
206
  cat_channels = channels * 3
207
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
208
+ self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
 
 
209
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
210
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
211
 
212
+
213
  def get_feat_num(self):
214
  self.feature_extract.eval()
215
  wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
 
226
  x = self.feature_extract([sample for sample in x])
227
  else:
228
  with torch.no_grad():
229
+ if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
230
  x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
231
  else:
232
  x = self.feature_extract([sample for sample in x])
233
 
234
+ if self.feat_type == 'fbank':
235
  x = x.log()
236
 
237
  if self.feat_type != "fbank" and self.feat_type != "mfcc":
 
263
  return out
264
 
265
 
266
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
267
+ return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
268
+ feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/modules.py CHANGED
@@ -16,45 +16,45 @@ from torch import nn
16
  import torch.nn.functional as F
17
  import torchaudio
18
 
 
19
  from x_transformers.x_transformers import apply_rotary_pos_emb
20
 
21
 
22
  # raw wav to mel spec
23
 
24
-
25
  class MelSpec(nn.Module):
26
  def __init__(
27
  self,
28
- filter_length=1024,
29
- hop_length=256,
30
- win_length=1024,
31
- n_mel_channels=100,
32
- target_sample_rate=24_000,
33
- normalize=False,
34
- power=1,
35
- norm=None,
36
- center=True,
37
  ):
38
  super().__init__()
39
  self.n_mel_channels = n_mel_channels
40
 
41
  self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
- sample_rate=target_sample_rate,
43
- n_fft=filter_length,
44
- win_length=win_length,
45
- hop_length=hop_length,
46
- n_mels=n_mel_channels,
47
- power=power,
48
- center=center,
49
- normalized=normalize,
50
- norm=norm,
51
  )
52
 
53
- self.register_buffer("dummy", torch.tensor(0), persistent=False)
54
 
55
  def forward(self, inp):
56
  if len(inp.shape) == 3:
57
- inp = inp.squeeze(1) # 'b 1 nw -> b nw'
58
 
59
  assert len(inp.shape) == 2
60
 
@@ -62,13 +62,12 @@ class MelSpec(nn.Module):
62
  self.to(inp.device)
63
 
64
  mel = self.mel_stft(inp)
65
- mel = mel.clamp(min=1e-5).log()
66
  return mel
67
-
68
 
69
  # sinusoidal position embedding
70
 
71
-
72
  class SinusPositionEmbedding(nn.Module):
73
  def __init__(self, dim):
74
  super().__init__()
@@ -86,37 +85,35 @@ class SinusPositionEmbedding(nn.Module):
86
 
87
  # convolutional position embedding
88
 
89
-
90
  class ConvPositionEmbedding(nn.Module):
91
- def __init__(self, dim, kernel_size=31, groups=16):
92
  super().__init__()
93
  assert kernel_size % 2 != 0
94
  self.conv1d = nn.Sequential(
95
- nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
96
  nn.Mish(),
97
- nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
98
  nn.Mish(),
99
  )
100
 
101
- def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
102
  if mask is not None:
103
  mask = mask[..., None]
104
- x = x.masked_fill(~mask, 0.0)
105
 
106
- x = x.permute(0, 2, 1)
107
  x = self.conv1d(x)
108
- out = x.permute(0, 2, 1)
109
 
110
  if mask is not None:
111
- out = out.masked_fill(~mask, 0.0)
112
 
113
  return out
114
 
115
 
116
  # rotary positional embedding related
117
 
118
-
119
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
120
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
121
  # has some connection to NTK literature
122
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -129,14 +126,12 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
129
  freqs_sin = torch.sin(freqs) # imaginary part
130
  return torch.cat([freqs_cos, freqs_sin], dim=-1)
131
 
132
-
133
- def get_pos_embed_indices(start, length, max_pos, scale=1.0):
134
  # length = length if isinstance(length, int) else length.max()
135
  scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
136
- pos = (
137
- start.unsqueeze(1)
138
- + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
139
- )
140
  # avoid extra long error.
141
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
142
  return pos
@@ -144,7 +139,6 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
144
 
145
  # Global Response Normalization layer (Instance Normalization ?)
146
 
147
-
148
  class GRN(nn.Module):
149
  def __init__(self, dim):
150
  super().__init__()
@@ -160,7 +154,6 @@ class GRN(nn.Module):
160
  # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
161
  # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
162
 
163
-
164
  class ConvNeXtV2Block(nn.Module):
165
  def __init__(
166
  self,
@@ -170,9 +163,7 @@ class ConvNeXtV2Block(nn.Module):
170
  ):
171
  super().__init__()
172
  padding = (dilation * (7 - 1)) // 2
173
- self.dwconv = nn.Conv1d(
174
- dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
175
- ) # depthwise conv
176
  self.norm = nn.LayerNorm(dim, eps=1e-6)
177
  self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
178
  self.act = nn.GELU()
@@ -195,7 +186,6 @@ class ConvNeXtV2Block(nn.Module):
195
  # AdaLayerNormZero
196
  # return with modulated x for attn input, and params for later mlp modulation
197
 
198
-
199
  class AdaLayerNormZero(nn.Module):
200
  def __init__(self, dim):
201
  super().__init__()
@@ -205,7 +195,7 @@ class AdaLayerNormZero(nn.Module):
205
 
206
  self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
207
 
208
- def forward(self, x, emb=None):
209
  emb = self.linear(self.silu(emb))
210
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
211
 
@@ -216,7 +206,6 @@ class AdaLayerNormZero(nn.Module):
216
  # AdaLayerNormZero for final layer
217
  # return only with modulated x for attn input, cuz no more mlp modulation
218
 
219
-
220
  class AdaLayerNormZero_Final(nn.Module):
221
  def __init__(self, dim):
222
  super().__init__()
@@ -236,16 +225,22 @@ class AdaLayerNormZero_Final(nn.Module):
236
 
237
  # FeedForward
238
 
239
-
240
  class FeedForward(nn.Module):
241
- def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
242
  super().__init__()
243
  inner_dim = int(dim * mult)
244
  dim_out = dim_out if dim_out is not None else dim
245
 
246
  activation = nn.GELU(approximate=approximate)
247
- project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
248
- self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
 
 
 
 
 
 
249
 
250
  def forward(self, x):
251
  return self.ff(x)
@@ -254,7 +249,6 @@ class FeedForward(nn.Module):
254
  # Attention with possible joint part
255
  # modified from diffusers/src/diffusers/models/attention_processor.py
256
 
257
-
258
  class Attention(nn.Module):
259
  def __init__(
260
  self,
@@ -263,8 +257,8 @@ class Attention(nn.Module):
263
  heads: int = 8,
264
  dim_head: int = 64,
265
  dropout: float = 0.0,
266
- context_dim: Optional[int] = None, # if not None -> joint attention
267
- context_pre_only=None,
268
  ):
269
  super().__init__()
270
 
@@ -300,21 +294,20 @@ class Attention(nn.Module):
300
 
301
  def forward(
302
  self,
303
- x: float["b n d"], # noised input x # noqa: F722
304
- c: float["b n d"] = None, # context c # noqa: F722
305
- mask: bool["b n"] | None = None, # noqa: F722
306
- rope=None, # rotary position embedding for x
307
- c_rope=None, # rotary position embedding for c
308
  ) -> torch.Tensor:
309
  if c is not None:
310
- return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
311
  else:
312
- return self.processor(self, x, mask=mask, rope=rope)
313
 
314
 
315
  # Attention processor
316
 
317
-
318
  class AttnProcessor:
319
  def __init__(self):
320
  pass
@@ -322,10 +315,11 @@ class AttnProcessor:
322
  def __call__(
323
  self,
324
  attn: Attention,
325
- x: float["b n d"], # noised input x # noqa: F722
326
- mask: bool["b n"] | None = None, # noqa: F722
327
- rope=None, # rotary position embedding
328
  ) -> torch.FloatTensor:
 
329
  batch_size = x.shape[0]
330
 
331
  # `sample` projections.
@@ -336,7 +330,7 @@ class AttnProcessor:
336
  # apply rotary position embedding
337
  if rope is not None:
338
  freqs, xpos_scale = rope
339
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
340
 
341
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
342
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
@@ -351,7 +345,7 @@ class AttnProcessor:
351
  # mask. e.g. inference got a batch with different target durations, mask out the padding
352
  if mask is not None:
353
  attn_mask = mask
354
- attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
355
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
356
  else:
357
  attn_mask = None
@@ -366,16 +360,15 @@ class AttnProcessor:
366
  x = attn.to_out[1](x)
367
 
368
  if mask is not None:
369
- mask = mask.unsqueeze(-1)
370
- x = x.masked_fill(~mask, 0.0)
371
 
372
  return x
373
-
374
 
375
  # Joint Attention processor for MM-DiT
376
  # modified from diffusers/src/diffusers/models/attention_processor.py
377
 
378
-
379
  class JointAttnProcessor:
380
  def __init__(self):
381
  pass
@@ -383,11 +376,11 @@ class JointAttnProcessor:
383
  def __call__(
384
  self,
385
  attn: Attention,
386
- x: float["b n d"], # noised input x # noqa: F722
387
- c: float["b nt d"] = None, # context c, here text # noqa: F722
388
- mask: bool["b n"] | None = None, # noqa: F722
389
- rope=None, # rotary position embedding for x
390
- c_rope=None, # rotary position embedding for c
391
  ) -> torch.FloatTensor:
392
  residual = x
393
 
@@ -406,12 +399,12 @@ class JointAttnProcessor:
406
  # apply rope for context and noised input independently
407
  if rope is not None:
408
  freqs, xpos_scale = rope
409
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
410
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
411
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
412
  if c_rope is not None:
413
  freqs, xpos_scale = c_rope
414
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
415
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
416
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
417
 
@@ -428,8 +421,8 @@ class JointAttnProcessor:
428
 
429
  # mask. e.g. inference got a batch with different target durations, mask out the padding
430
  if mask is not None:
431
- attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
432
- attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
433
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
434
  else:
435
  attn_mask = None
@@ -440,8 +433,8 @@ class JointAttnProcessor:
440
 
441
  # Split the attention outputs.
442
  x, c = (
443
- x[:, : residual.shape[1]],
444
- x[:, residual.shape[1] :],
445
  )
446
 
447
  # linear proj
@@ -452,8 +445,8 @@ class JointAttnProcessor:
452
  c = attn.to_out_c(c)
453
 
454
  if mask is not None:
455
- mask = mask.unsqueeze(-1)
456
- x = x.masked_fill(~mask, 0.0)
457
  # c = c.masked_fill(~mask, 0.) # no mask for c (text)
458
 
459
  return x, c
@@ -461,24 +454,24 @@ class JointAttnProcessor:
461
 
462
  # DiT Block
463
 
464
-
465
  class DiTBlock(nn.Module):
466
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
467
- super().__init__()
468
 
 
 
 
469
  self.attn_norm = AdaLayerNormZero(dim)
470
  self.attn = Attention(
471
- processor=AttnProcessor(),
472
- dim=dim,
473
- heads=heads,
474
- dim_head=dim_head,
475
- dropout=dropout,
476
- )
477
-
478
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
479
- self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
480
 
481
- def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
482
  # pre-norm & modulation for attention input
483
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
484
 
@@ -487,7 +480,7 @@ class DiTBlock(nn.Module):
487
 
488
  # process attention output for input x
489
  x = x + gate_msa.unsqueeze(1) * attn_output
490
-
491
  norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
492
  ff_output = self.ff(norm)
493
  x = x + gate_mlp.unsqueeze(1) * ff_output
@@ -497,9 +490,8 @@ class DiTBlock(nn.Module):
497
 
498
  # MMDiT Block https://arxiv.org/abs/2403.03206
499
 
500
-
501
  class MMDiTBlock(nn.Module):
502
- r"""
503
  modified from diffusers/src/diffusers/models/attention.py
504
 
505
  notes.
@@ -508,33 +500,33 @@ class MMDiTBlock(nn.Module):
508
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
509
  """
510
 
511
- def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
512
  super().__init__()
513
 
514
  self.context_pre_only = context_pre_only
515
-
516
  self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
517
  self.attn_norm_x = AdaLayerNormZero(dim)
518
  self.attn = Attention(
519
- processor=JointAttnProcessor(),
520
- dim=dim,
521
- heads=heads,
522
- dim_head=dim_head,
523
- dropout=dropout,
524
- context_dim=dim,
525
- context_pre_only=context_pre_only,
526
- )
527
 
528
  if not context_pre_only:
529
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
530
- self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
531
  else:
532
  self.ff_norm_c = None
533
  self.ff_c = None
534
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
535
- self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
536
 
537
- def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
538
  # pre-norm & modulation for attention input
539
  if self.context_pre_only:
540
  norm_c = self.attn_norm_c(c, t)
@@ -548,7 +540,7 @@ class MMDiTBlock(nn.Module):
548
  # process attention output for context c
549
  if self.context_pre_only:
550
  c = None
551
- else: # if not last layer
552
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
553
 
554
  norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
@@ -557,7 +549,7 @@ class MMDiTBlock(nn.Module):
557
 
558
  # process attention output for input x
559
  x = x + x_gate_msa.unsqueeze(1) * x_attn_output
560
-
561
  norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
562
  x_ff_output = self.ff_x(norm_x)
563
  x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
@@ -567,15 +559,17 @@ class MMDiTBlock(nn.Module):
567
 
568
  # time step conditioning embedding
569
 
570
-
571
  class TimestepEmbedding(nn.Module):
572
  def __init__(self, dim, freq_embed_dim=256):
573
  super().__init__()
574
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
575
- self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 
 
 
 
576
 
577
- def forward(self, timestep: float["b"]): # noqa: F821
578
  time_hidden = self.time_embed(timestep)
579
- time_hidden = time_hidden.to(timestep.dtype)
580
  time = self.time_mlp(time_hidden) # b d
581
  return time
 
16
  import torch.nn.functional as F
17
  import torchaudio
18
 
19
+ from einops import rearrange
20
  from x_transformers.x_transformers import apply_rotary_pos_emb
21
 
22
 
23
  # raw wav to mel spec
24
 
 
25
  class MelSpec(nn.Module):
26
  def __init__(
27
  self,
28
+ filter_length = 1024,
29
+ hop_length = 256,
30
+ win_length = 1024,
31
+ n_mel_channels = 100,
32
+ target_sample_rate = 24_000,
33
+ normalize = False,
34
+ power = 1,
35
+ norm = None,
36
+ center = True,
37
  ):
38
  super().__init__()
39
  self.n_mel_channels = n_mel_channels
40
 
41
  self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
+ sample_rate = target_sample_rate,
43
+ n_fft = filter_length,
44
+ win_length = win_length,
45
+ hop_length = hop_length,
46
+ n_mels = n_mel_channels,
47
+ power = power,
48
+ center = center,
49
+ normalized = normalize,
50
+ norm = norm,
51
  )
52
 
53
+ self.register_buffer('dummy', torch.tensor(0), persistent = False)
54
 
55
  def forward(self, inp):
56
  if len(inp.shape) == 3:
57
+ inp = rearrange(inp, 'b 1 nw -> b nw')
58
 
59
  assert len(inp.shape) == 2
60
 
 
62
  self.to(inp.device)
63
 
64
  mel = self.mel_stft(inp)
65
+ mel = mel.clamp(min = 1e-5).log()
66
  return mel
67
+
68
 
69
  # sinusoidal position embedding
70
 
 
71
  class SinusPositionEmbedding(nn.Module):
72
  def __init__(self, dim):
73
  super().__init__()
 
85
 
86
  # convolutional position embedding
87
 
 
88
  class ConvPositionEmbedding(nn.Module):
89
+ def __init__(self, dim, kernel_size = 31, groups = 16):
90
  super().__init__()
91
  assert kernel_size % 2 != 0
92
  self.conv1d = nn.Sequential(
93
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
94
  nn.Mish(),
95
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
96
  nn.Mish(),
97
  )
98
 
99
+ def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
100
  if mask is not None:
101
  mask = mask[..., None]
102
+ x = x.masked_fill(~mask, 0.)
103
 
104
+ x = rearrange(x, 'b n d -> b d n')
105
  x = self.conv1d(x)
106
+ out = rearrange(x, 'b d n -> b n d')
107
 
108
  if mask is not None:
109
+ out = out.masked_fill(~mask, 0.)
110
 
111
  return out
112
 
113
 
114
  # rotary positional embedding related
115
 
116
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
 
117
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
118
  # has some connection to NTK literature
119
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
 
126
  freqs_sin = torch.sin(freqs) # imaginary part
127
  return torch.cat([freqs_cos, freqs_sin], dim=-1)
128
 
129
+ def get_pos_embed_indices(start, length, max_pos, scale=1.):
 
130
  # length = length if isinstance(length, int) else length.max()
131
  scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
132
+ pos = start.unsqueeze(1) + (
133
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
134
+ scale.unsqueeze(1)).long()
 
135
  # avoid extra long error.
136
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
137
  return pos
 
139
 
140
  # Global Response Normalization layer (Instance Normalization ?)
141
 
 
142
  class GRN(nn.Module):
143
  def __init__(self, dim):
144
  super().__init__()
 
154
  # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
155
  # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
156
 
 
157
  class ConvNeXtV2Block(nn.Module):
158
  def __init__(
159
  self,
 
163
  ):
164
  super().__init__()
165
  padding = (dilation * (7 - 1)) // 2
166
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
 
 
167
  self.norm = nn.LayerNorm(dim, eps=1e-6)
168
  self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
169
  self.act = nn.GELU()
 
186
  # AdaLayerNormZero
187
  # return with modulated x for attn input, and params for later mlp modulation
188
 
 
189
  class AdaLayerNormZero(nn.Module):
190
  def __init__(self, dim):
191
  super().__init__()
 
195
 
196
  self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
197
 
198
+ def forward(self, x, emb = None):
199
  emb = self.linear(self.silu(emb))
200
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
201
 
 
206
  # AdaLayerNormZero for final layer
207
  # return only with modulated x for attn input, cuz no more mlp modulation
208
 
 
209
  class AdaLayerNormZero_Final(nn.Module):
210
  def __init__(self, dim):
211
  super().__init__()
 
225
 
226
  # FeedForward
227
 
 
228
  class FeedForward(nn.Module):
229
+ def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
230
  super().__init__()
231
  inner_dim = int(dim * mult)
232
  dim_out = dim_out if dim_out is not None else dim
233
 
234
  activation = nn.GELU(approximate=approximate)
235
+ project_in = nn.Sequential(
236
+ nn.Linear(dim, inner_dim),
237
+ activation
238
+ )
239
+ self.ff = nn.Sequential(
240
+ project_in,
241
+ nn.Dropout(dropout),
242
+ nn.Linear(inner_dim, dim_out)
243
+ )
244
 
245
  def forward(self, x):
246
  return self.ff(x)
 
249
  # Attention with possible joint part
250
  # modified from diffusers/src/diffusers/models/attention_processor.py
251
 
 
252
  class Attention(nn.Module):
253
  def __init__(
254
  self,
 
257
  heads: int = 8,
258
  dim_head: int = 64,
259
  dropout: float = 0.0,
260
+ context_dim: Optional[int] = None, # if not None -> joint attention
261
+ context_pre_only = None,
262
  ):
263
  super().__init__()
264
 
 
294
 
295
  def forward(
296
  self,
297
+ x: float['b n d'], # noised input x
298
+ c: float['b n d'] = None, # context c
299
+ mask: bool['b n'] | None = None,
300
+ rope = None, # rotary position embedding for x
301
+ c_rope = None, # rotary position embedding for c
302
  ) -> torch.Tensor:
303
  if c is not None:
304
+ return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
305
  else:
306
+ return self.processor(self, x, mask = mask, rope = rope)
307
 
308
 
309
  # Attention processor
310
 
 
311
  class AttnProcessor:
312
  def __init__(self):
313
  pass
 
315
  def __call__(
316
  self,
317
  attn: Attention,
318
+ x: float['b n d'], # noised input x
319
+ mask: bool['b n'] | None = None,
320
+ rope = None, # rotary position embedding
321
  ) -> torch.FloatTensor:
322
+
323
  batch_size = x.shape[0]
324
 
325
  # `sample` projections.
 
330
  # apply rotary position embedding
331
  if rope is not None:
332
  freqs, xpos_scale = rope
333
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
334
 
335
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
336
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
 
345
  # mask. e.g. inference got a batch with different target durations, mask out the padding
346
  if mask is not None:
347
  attn_mask = mask
348
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
349
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
350
  else:
351
  attn_mask = None
 
360
  x = attn.to_out[1](x)
361
 
362
  if mask is not None:
363
+ mask = rearrange(mask, 'b n -> b n 1')
364
+ x = x.masked_fill(~mask, 0.)
365
 
366
  return x
367
+
368
 
369
  # Joint Attention processor for MM-DiT
370
  # modified from diffusers/src/diffusers/models/attention_processor.py
371
 
 
372
  class JointAttnProcessor:
373
  def __init__(self):
374
  pass
 
376
  def __call__(
377
  self,
378
  attn: Attention,
379
+ x: float['b n d'], # noised input x
380
+ c: float['b nt d'] = None, # context c, here text
381
+ mask: bool['b n'] | None = None,
382
+ rope = None, # rotary position embedding for x
383
+ c_rope = None, # rotary position embedding for c
384
  ) -> torch.FloatTensor:
385
  residual = x
386
 
 
399
  # apply rope for context and noised input independently
400
  if rope is not None:
401
  freqs, xpos_scale = rope
402
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
403
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
404
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
405
  if c_rope is not None:
406
  freqs, xpos_scale = c_rope
407
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
408
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
409
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
410
 
 
421
 
422
  # mask. e.g. inference got a batch with different target durations, mask out the padding
423
  if mask is not None:
424
+ attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
425
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
426
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
427
  else:
428
  attn_mask = None
 
433
 
434
  # Split the attention outputs.
435
  x, c = (
436
+ x[:, :residual.shape[1]],
437
+ x[:, residual.shape[1]:],
438
  )
439
 
440
  # linear proj
 
445
  c = attn.to_out_c(c)
446
 
447
  if mask is not None:
448
+ mask = rearrange(mask, 'b n -> b n 1')
449
+ x = x.masked_fill(~mask, 0.)
450
  # c = c.masked_fill(~mask, 0.) # no mask for c (text)
451
 
452
  return x, c
 
454
 
455
  # DiT Block
456
 
 
457
  class DiTBlock(nn.Module):
 
 
458
 
459
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
460
+ super().__init__()
461
+
462
  self.attn_norm = AdaLayerNormZero(dim)
463
  self.attn = Attention(
464
+ processor = AttnProcessor(),
465
+ dim = dim,
466
+ heads = heads,
467
+ dim_head = dim_head,
468
+ dropout = dropout,
469
+ )
470
+
471
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
472
+ self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
473
 
474
+ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
475
  # pre-norm & modulation for attention input
476
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
477
 
 
480
 
481
  # process attention output for input x
482
  x = x + gate_msa.unsqueeze(1) * attn_output
483
+
484
  norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
485
  ff_output = self.ff(norm)
486
  x = x + gate_mlp.unsqueeze(1) * ff_output
 
490
 
491
  # MMDiT Block https://arxiv.org/abs/2403.03206
492
 
 
493
  class MMDiTBlock(nn.Module):
494
+ r"""
495
  modified from diffusers/src/diffusers/models/attention.py
496
 
497
  notes.
 
500
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
501
  """
502
 
503
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
504
  super().__init__()
505
 
506
  self.context_pre_only = context_pre_only
507
+
508
  self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
509
  self.attn_norm_x = AdaLayerNormZero(dim)
510
  self.attn = Attention(
511
+ processor = JointAttnProcessor(),
512
+ dim = dim,
513
+ heads = heads,
514
+ dim_head = dim_head,
515
+ dropout = dropout,
516
+ context_dim = dim,
517
+ context_pre_only = context_pre_only,
518
+ )
519
 
520
  if not context_pre_only:
521
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
522
+ self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
523
  else:
524
  self.ff_norm_c = None
525
  self.ff_c = None
526
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
527
+ self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
528
 
529
+ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
530
  # pre-norm & modulation for attention input
531
  if self.context_pre_only:
532
  norm_c = self.attn_norm_c(c, t)
 
540
  # process attention output for context c
541
  if self.context_pre_only:
542
  c = None
543
+ else: # if not last layer
544
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
545
 
546
  norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
 
549
 
550
  # process attention output for input x
551
  x = x + x_gate_msa.unsqueeze(1) * x_attn_output
552
+
553
  norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
554
  x_ff_output = self.ff_x(norm_x)
555
  x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
 
559
 
560
  # time step conditioning embedding
561
 
 
562
  class TimestepEmbedding(nn.Module):
563
  def __init__(self, dim, freq_embed_dim=256):
564
  super().__init__()
565
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
566
+ self.time_mlp = nn.Sequential(
567
+ nn.Linear(freq_embed_dim, dim),
568
+ nn.SiLU(),
569
+ nn.Linear(dim, dim)
570
+ )
571
 
572
+ def forward(self, timestep: float['b']):
573
  time_hidden = self.time_embed(timestep)
 
574
  time = self.time_mlp(time_hidden) # b d
575
  return time
model/trainer.py CHANGED
@@ -10,6 +10,8 @@ from torch.optim import AdamW
10
  from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
 
 
 
13
  from accelerate import Accelerator
14
  from accelerate.utils import DistributedDataParallelKwargs
15
 
@@ -22,69 +24,66 @@ from model.dataset import DynamicBatchSampler, collate_fn
22
 
23
  # trainer
24
 
25
-
26
  class Trainer:
27
  def __init__(
28
  self,
29
  model: CFM,
30
  epochs,
31
  learning_rate,
32
- num_warmup_updates=20000,
33
- save_per_updates=1000,
34
- checkpoint_path=None,
35
- batch_size=32,
36
  batch_size_type: str = "sample",
37
- max_samples=32,
38
- grad_accumulation_steps=1,
39
- max_grad_norm=1.0,
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
42
- wandb_project="test_e2-tts",
43
- wandb_run_name="test_run",
44
  wandb_resume_id: str = None,
45
- last_per_steps=None,
46
  accelerate_kwargs: dict = dict(),
47
- ema_kwargs: dict = dict(),
48
- bnb_optimizer: bool = False,
49
  ):
50
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
51
-
52
- logger = "wandb" if wandb.api.api_key else None
53
- print(f"Using logger: {logger}")
54
 
55
  self.accelerator = Accelerator(
56
- log_with=logger,
57
- kwargs_handlers=[ddp_kwargs],
58
- gradient_accumulation_steps=grad_accumulation_steps,
59
- **accelerate_kwargs,
60
  )
61
-
62
- if logger == "wandb":
63
- if exists(wandb_resume_id):
64
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
65
- else:
66
- init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
67
- self.accelerator.init_trackers(
68
- project_name=wandb_project,
69
- init_kwargs=init_kwargs,
70
- config={
71
- "epochs": epochs,
72
  "learning_rate": learning_rate,
73
- "num_warmup_updates": num_warmup_updates,
74
  "batch_size": batch_size,
75
  "batch_size_type": batch_size_type,
76
  "max_samples": max_samples,
77
  "grad_accumulation_steps": grad_accumulation_steps,
78
  "max_grad_norm": max_grad_norm,
79
  "gpus": self.accelerator.num_processes,
80
- "noise_scheduler": noise_scheduler,
81
- },
82
  )
83
 
84
  self.model = model
85
 
86
  if self.is_main:
87
- self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
 
 
 
 
88
 
89
  self.ema_model.to(self.accelerator.device)
90
 
@@ -92,7 +91,7 @@ class Trainer:
92
  self.num_warmup_updates = num_warmup_updates
93
  self.save_per_updates = save_per_updates
94
  self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
95
- self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
96
 
97
  self.batch_size = batch_size
98
  self.batch_size_type = batch_size_type
@@ -104,13 +103,10 @@ class Trainer:
104
 
105
  self.duration_predictor = duration_predictor
106
 
107
- if bnb_optimizer:
108
- import bitsandbytes as bnb
109
-
110
- self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
111
- else:
112
- self.optimizer = AdamW(model.parameters(), lr=learning_rate)
113
- self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
114
 
115
  @property
116
  def is_main(self):
@@ -120,112 +116,81 @@ class Trainer:
120
  self.accelerator.wait_for_everyone()
121
  if self.is_main:
122
  checkpoint = dict(
123
- model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
124
- optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
125
- ema_model_state_dict=self.ema_model.state_dict(),
126
- scheduler_state_dict=self.scheduler.state_dict(),
127
- step=step,
128
  )
129
  if not os.path.exists(self.checkpoint_path):
130
  os.makedirs(self.checkpoint_path)
131
- if last:
132
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
133
  print(f"Saved last checkpoint at step {step}")
134
  else:
135
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
136
 
137
  def load_checkpoint(self):
138
- if (
139
- not exists(self.checkpoint_path)
140
- or not os.path.exists(self.checkpoint_path)
141
- or not os.listdir(self.checkpoint_path)
142
- ):
143
  return 0
144
-
145
  self.accelerator.wait_for_everyone()
146
  if "model_last.pt" in os.listdir(self.checkpoint_path):
147
  latest_checkpoint = "model_last.pt"
148
  else:
149
- latest_checkpoint = sorted(
150
- [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
151
- key=lambda x: int("".join(filter(str.isdigit, x))),
152
- )[-1]
153
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
154
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
155
 
156
  if self.is_main:
157
- self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
158
 
159
- if "step" in checkpoint:
160
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
161
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
162
  if self.scheduler:
163
- self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
164
- step = checkpoint["step"]
165
  else:
166
- checkpoint["model_state_dict"] = {
167
- k.replace("ema_model.", ""): v
168
- for k, v in checkpoint["ema_model_state_dict"].items()
169
- if k not in ["initted", "step"]
170
- }
171
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
172
  step = 0
173
 
174
- del checkpoint
175
- gc.collect()
176
  return step
177
 
178
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
 
179
  if exists(resumable_with_seed):
180
  generator = torch.Generator()
181
  generator.manual_seed(resumable_with_seed)
182
- else:
183
  generator = None
184
 
185
  if self.batch_size_type == "sample":
186
- train_dataloader = DataLoader(
187
- train_dataset,
188
- collate_fn=collate_fn,
189
- num_workers=num_workers,
190
- pin_memory=True,
191
- persistent_workers=True,
192
- batch_size=self.batch_size,
193
- shuffle=True,
194
- generator=generator,
195
- )
196
  elif self.batch_size_type == "frame":
197
  self.accelerator.even_batches = False
198
  sampler = SequentialSampler(train_dataset)
199
- batch_sampler = DynamicBatchSampler(
200
- sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
201
- )
202
- train_dataloader = DataLoader(
203
- train_dataset,
204
- collate_fn=collate_fn,
205
- num_workers=num_workers,
206
- pin_memory=True,
207
- persistent_workers=True,
208
- batch_sampler=batch_sampler,
209
- )
210
  else:
211
  raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
212
-
213
  # accelerator.prepare() dispatches batches to devices;
214
  # which means the length of dataloader calculated before, should consider the number of devices
215
- warmup_steps = (
216
- self.num_warmup_updates * self.accelerator.num_processes
217
- ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
218
- # otherwise by default with split_batches=False, warmup steps change with num_processes
219
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
220
  decay_steps = total_steps - warmup_steps
221
  warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
222
  decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
223
- self.scheduler = SequentialLR(
224
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
225
- )
226
- train_dataloader, self.scheduler = self.accelerator.prepare(
227
- train_dataloader, self.scheduler
228
- ) # actual steps = 1 gpu steps / gpus
229
  start_step = self.load_checkpoint()
230
  global_step = start_step
231
 
@@ -240,36 +205,23 @@ class Trainer:
240
  for epoch in range(skipped_epoch, self.epochs):
241
  self.model.train()
242
  if exists(resumable_with_seed) and epoch == skipped_epoch:
243
- progress_bar = tqdm(
244
- skipped_dataloader,
245
- desc=f"Epoch {epoch+1}/{self.epochs}",
246
- unit="step",
247
- disable=not self.accelerator.is_local_main_process,
248
- initial=skipped_batch,
249
- total=orig_epoch_step,
250
- )
251
  else:
252
- progress_bar = tqdm(
253
- train_dataloader,
254
- desc=f"Epoch {epoch+1}/{self.epochs}",
255
- unit="step",
256
- disable=not self.accelerator.is_local_main_process,
257
- )
258
 
259
  for batch in progress_bar:
260
  with self.accelerator.accumulate(self.model):
261
- text_inputs = batch["text"]
262
- mel_spec = batch["mel"].permute(0, 2, 1)
263
  mel_lengths = batch["mel_lengths"]
264
 
265
  # TODO. add duration predictor training
266
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
267
- dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
268
  self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
269
 
270
- loss, cond, pred = self.model(
271
- mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
272
- )
273
  self.accelerator.backward(loss)
274
 
275
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -286,15 +238,13 @@ class Trainer:
286
 
287
  if self.accelerator.is_local_main_process:
288
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
289
-
290
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
291
-
292
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
293
  self.save_checkpoint(global_step)
294
-
295
  if global_step % self.last_per_steps == 0:
296
  self.save_checkpoint(global_step, last=True)
297
-
298
- self.save_checkpoint(global_step, last=True)
299
-
300
  self.accelerator.end_training()
 
10
  from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
  from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
 
13
+ from einops import rearrange
14
+
15
  from accelerate import Accelerator
16
  from accelerate.utils import DistributedDataParallelKwargs
17
 
 
24
 
25
  # trainer
26
 
 
27
  class Trainer:
28
  def __init__(
29
  self,
30
  model: CFM,
31
  epochs,
32
  learning_rate,
33
+ num_warmup_updates = 20000,
34
+ save_per_updates = 1000,
35
+ checkpoint_path = None,
36
+ batch_size = 32,
37
  batch_size_type: str = "sample",
38
+ max_samples = 32,
39
+ grad_accumulation_steps = 1,
40
+ max_grad_norm = 1.0,
41
  noise_scheduler: str | None = None,
42
  duration_predictor: torch.nn.Module | None = None,
43
+ wandb_project = "test_e2-tts",
44
+ wandb_run_name = "test_run",
45
  wandb_resume_id: str = None,
46
+ last_per_steps = None,
47
  accelerate_kwargs: dict = dict(),
48
+ ema_kwargs: dict = dict()
 
49
  ):
50
+
51
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
 
 
52
 
53
  self.accelerator = Accelerator(
54
+ log_with = "wandb",
55
+ kwargs_handlers = [ddp_kwargs],
56
+ gradient_accumulation_steps = grad_accumulation_steps,
57
+ **accelerate_kwargs
58
  )
59
+
60
+ if exists(wandb_resume_id):
61
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
+ else:
63
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
+ self.accelerator.init_trackers(
65
+ project_name = wandb_project,
66
+ init_kwargs=init_kwargs,
67
+ config={"epochs": epochs,
 
 
68
  "learning_rate": learning_rate,
69
+ "num_warmup_updates": num_warmup_updates,
70
  "batch_size": batch_size,
71
  "batch_size_type": batch_size_type,
72
  "max_samples": max_samples,
73
  "grad_accumulation_steps": grad_accumulation_steps,
74
  "max_grad_norm": max_grad_norm,
75
  "gpus": self.accelerator.num_processes,
76
+ "noise_scheduler": noise_scheduler}
 
77
  )
78
 
79
  self.model = model
80
 
81
  if self.is_main:
82
+ self.ema_model = EMA(
83
+ model,
84
+ include_online_model = False,
85
+ **ema_kwargs
86
+ )
87
 
88
  self.ema_model.to(self.accelerator.device)
89
 
 
91
  self.num_warmup_updates = num_warmup_updates
92
  self.save_per_updates = save_per_updates
93
  self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
94
+ self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
95
 
96
  self.batch_size = batch_size
97
  self.batch_size_type = batch_size_type
 
103
 
104
  self.duration_predictor = duration_predictor
105
 
106
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
107
+ self.model, self.optimizer = self.accelerator.prepare(
108
+ self.model, self.optimizer
109
+ )
 
 
 
110
 
111
  @property
112
  def is_main(self):
 
116
  self.accelerator.wait_for_everyone()
117
  if self.is_main:
118
  checkpoint = dict(
119
+ model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
120
+ optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
121
+ ema_model_state_dict = self.ema_model.state_dict(),
122
+ scheduler_state_dict = self.scheduler.state_dict(),
123
+ step = step
124
  )
125
  if not os.path.exists(self.checkpoint_path):
126
  os.makedirs(self.checkpoint_path)
127
+ if last == True:
128
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
129
  print(f"Saved last checkpoint at step {step}")
130
  else:
131
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
132
 
133
  def load_checkpoint(self):
134
+ if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
 
 
 
 
135
  return 0
136
+
137
  self.accelerator.wait_for_everyone()
138
  if "model_last.pt" in os.listdir(self.checkpoint_path):
139
  latest_checkpoint = "model_last.pt"
140
  else:
141
+ latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
 
 
 
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
144
 
145
  if self.is_main:
146
+ self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
147
 
148
+ if 'step' in checkpoint:
149
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
150
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
151
  if self.scheduler:
152
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
153
+ step = checkpoint['step']
154
  else:
155
+ checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
156
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
157
  step = 0
158
 
159
+ del checkpoint; gc.collect()
 
160
  return step
161
 
162
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
163
+
164
  if exists(resumable_with_seed):
165
  generator = torch.Generator()
166
  generator.manual_seed(resumable_with_seed)
167
+ else:
168
  generator = None
169
 
170
  if self.batch_size_type == "sample":
171
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
172
+ batch_size=self.batch_size, shuffle=True, generator=generator)
 
 
 
 
 
 
 
 
173
  elif self.batch_size_type == "frame":
174
  self.accelerator.even_batches = False
175
  sampler = SequentialSampler(train_dataset)
176
+ batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
177
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
178
+ batch_sampler=batch_sampler)
 
 
 
 
 
 
 
 
179
  else:
180
  raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
181
+
182
  # accelerator.prepare() dispatches batches to devices;
183
  # which means the length of dataloader calculated before, should consider the number of devices
184
+ warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
185
+ # otherwise by default with split_batches=False, warmup steps change with num_processes
 
 
186
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
187
  decay_steps = total_steps - warmup_steps
188
  warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
189
  decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
190
+ self.scheduler = SequentialLR(self.optimizer,
191
+ schedulers=[warmup_scheduler, decay_scheduler],
192
+ milestones=[warmup_steps])
193
+ train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
 
 
194
  start_step = self.load_checkpoint()
195
  global_step = start_step
196
 
 
205
  for epoch in range(skipped_epoch, self.epochs):
206
  self.model.train()
207
  if exists(resumable_with_seed) and epoch == skipped_epoch:
208
+ progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
209
+ initial=skipped_batch, total=orig_epoch_step)
 
 
 
 
 
 
210
  else:
211
+ progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
 
 
 
 
 
212
 
213
  for batch in progress_bar:
214
  with self.accelerator.accumulate(self.model):
215
+ text_inputs = batch['text']
216
+ mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
217
  mel_lengths = batch["mel_lengths"]
218
 
219
  # TODO. add duration predictor training
220
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
221
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
222
  self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
223
 
224
+ loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
 
 
225
  self.accelerator.backward(loss)
226
 
227
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
 
238
 
239
  if self.accelerator.is_local_main_process:
240
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
241
+
242
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
243
+
244
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
245
  self.save_checkpoint(global_step)
246
+
247
  if global_step % self.last_per_steps == 0:
248
  self.save_checkpoint(global_step, last=True)
249
+
 
 
250
  self.accelerator.end_training()
model/utils.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import os
 
4
  import math
5
  import random
6
  import string
@@ -8,7 +9,6 @@ from tqdm import tqdm
8
  from collections import defaultdict
9
 
10
  import matplotlib
11
-
12
  matplotlib.use("Agg")
13
  import matplotlib.pylab as plt
14
 
@@ -17,6 +17,9 @@ import torch.nn.functional as F
17
  from torch.nn.utils.rnn import pad_sequence
18
  import torchaudio
19
 
 
 
 
20
  import jieba
21
  from pypinyin import lazy_pinyin, Style
22
 
@@ -26,102 +29,107 @@ from model.modules import MelSpec
26
 
27
  # seed everything
28
 
29
-
30
- def seed_everything(seed=0):
31
  random.seed(seed)
32
- os.environ["PYTHONHASHSEED"] = str(seed)
33
  torch.manual_seed(seed)
34
  torch.cuda.manual_seed(seed)
35
  torch.cuda.manual_seed_all(seed)
36
  torch.backends.cudnn.deterministic = True
37
  torch.backends.cudnn.benchmark = False
38
 
39
-
40
  # helpers
41
 
42
-
43
  def exists(v):
44
  return v is not None
45
 
46
-
47
  def default(v, d):
48
  return v if exists(v) else d
49
 
50
-
51
  # tensor helpers
52
 
 
 
 
 
53
 
54
- def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
55
  if not exists(length):
56
  length = t.amax()
57
 
58
- seq = torch.arange(length, device=t.device)
59
- return seq[None, :] < t[:, None]
60
-
61
-
62
- def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
63
- max_seq_len = seq_len.max().item()
64
- seq = torch.arange(max_seq_len, device=start.device).long()
65
- start_mask = seq[None, :] >= start[:, None]
66
- end_mask = seq[None, :] < end[:, None]
67
- return start_mask & end_mask
68
 
 
 
 
 
 
 
 
 
69
 
70
- def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
 
 
 
71
  lengths = (frac_lengths * seq_len).long()
72
  max_start = seq_len - lengths
73
 
74
  rand = torch.rand_like(frac_lengths)
75
- start = (max_start * rand).long().clamp(min=0)
76
  end = start + lengths
77
 
78
  return mask_from_start_end_indices(seq_len, start, end)
79
 
 
 
 
 
80
 
81
- def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
82
  if not exists(mask):
83
- return t.mean(dim=1)
84
 
85
- t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
86
- num = t.sum(dim=1)
87
- den = mask.float().sum(dim=1)
88
 
89
- return num / den.clamp(min=1.0)
90
 
91
 
92
  # simple utf-8 tokenizer, since paper went character based
93
- def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
94
- list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
95
- text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
 
 
 
96
  return text
97
 
98
-
99
  # char tokenizer, based on custom dataset's extracted .txt file
100
  def list_str_to_idx(
101
  text: list[str] | list[list[str]],
102
  vocab_char_map: dict[str, int], # {char: idx}
103
- padding_value=-1,
104
- ) -> int["b nt"]: # noqa: F722
105
  list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
106
- text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
107
  return text
108
 
109
 
110
  # Get tokenizer
111
 
112
-
113
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
114
- """
115
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
116
  - "char" for char-wise tokenizer, need .txt vocab_file
117
  - "byte" for utf-8 tokenizer
118
  - "custom" if you're directly passing in a path to the vocab.txt you want to use
119
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
120
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
121
- - if use "byte", set to 256 (unicode byte range)
122
- """
123
  if tokenizer in ["pinyin", "char"]:
124
- with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
125
  vocab_char_map = {}
126
  for i, char in enumerate(f):
127
  vocab_char_map[char[:-1]] = i
@@ -132,7 +140,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
132
  vocab_char_map = None
133
  vocab_size = 256
134
  elif tokenizer == "custom":
135
- with open(dataset_name, "r", encoding="utf-8") as f:
136
  vocab_char_map = {}
137
  for i, char in enumerate(f):
138
  vocab_char_map[char[:-1]] = i
@@ -143,19 +151,16 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
143
 
144
  # convert char to pinyin
145
 
146
-
147
- def convert_char_to_pinyin(text_list, polyphone=True):
148
  final_text_list = []
149
- god_knows_why_en_testset_contains_zh_quote = str.maketrans(
150
- {"“": '"', "”": '"', "‘": "'", "’": "'"}
151
- ) # in case librispeech (orig no-pc) test-clean
152
- custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
153
  for text in text_list:
154
  char_list = []
155
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
156
  text = text.translate(custom_trans)
157
  for seg in jieba.cut(text):
158
- seg_byte_len = len(bytes(seg, "UTF-8"))
159
  if seg_byte_len == len(seg): # if pure alphabets and symbols
160
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
161
  char_list.append(" ")
@@ -184,7 +189,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
184
  # save spectrogram
185
  def save_spectrogram(spectrogram, path):
186
  plt.figure(figsize=(12, 4))
187
- plt.imshow(spectrogram, origin="lower", aspect="auto")
188
  plt.colorbar()
189
  plt.savefig(path)
190
  plt.close()
@@ -192,15 +197,13 @@ def save_spectrogram(spectrogram, path):
192
 
193
  # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
194
  def get_seedtts_testset_metainfo(metalst):
195
- f = open(metalst)
196
- lines = f.readlines()
197
- f.close()
198
  metainfo = []
199
  for line in lines:
200
- if len(line.strip().split("|")) == 5:
201
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
202
- elif len(line.strip().split("|")) == 4:
203
- utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
204
  gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
205
  if not os.path.isabs(prompt_wav):
206
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
@@ -210,20 +213,18 @@ def get_seedtts_testset_metainfo(metalst):
210
 
211
  # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
212
  def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
213
- f = open(metalst)
214
- lines = f.readlines()
215
- f.close()
216
  metainfo = []
217
  for line in lines:
218
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
219
 
220
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
221
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
222
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
223
 
224
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
225
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
226
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
227
 
228
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
229
 
@@ -235,30 +236,21 @@ def padded_mel_batch(ref_mels):
235
  max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
236
  padded_ref_mels = []
237
  for mel in ref_mels:
238
- padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
239
  padded_ref_mels.append(padded_ref_mel)
240
  padded_ref_mels = torch.stack(padded_ref_mels)
241
- padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
242
  return padded_ref_mels
243
 
244
 
245
  # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
246
 
247
-
248
  def get_inference_prompt(
249
- metainfo,
250
- speed=1.0,
251
- tokenizer="pinyin",
252
- polyphone=True,
253
- target_sample_rate=24000,
254
- n_mel_channels=100,
255
- hop_length=256,
256
- target_rms=0.1,
257
- use_truth_duration=False,
258
- infer_batch_size=1,
259
- num_buckets=200,
260
- min_secs=3,
261
- max_secs=40,
262
  ):
263
  prompts_all = []
264
 
@@ -266,15 +258,13 @@ def get_inference_prompt(
266
  max_tokens = max_secs * target_sample_rate // hop_length
267
 
268
  batch_accum = [0] * num_buckets
269
- utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
270
- [[] for _ in range(num_buckets)] for _ in range(6)
271
- )
272
 
273
- mel_spectrogram = MelSpec(
274
- target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
275
- )
276
 
277
  for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
 
278
  # Audio
279
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
280
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
@@ -286,11 +276,11 @@ def get_inference_prompt(
286
  ref_audio = resampler(ref_audio)
287
 
288
  # Text
289
- if len(prompt_text[-1].encode("utf-8")) == 1:
290
  prompt_text = prompt_text + " "
291
  text = [prompt_text + gt_text]
292
  if tokenizer == "pinyin":
293
- text_list = convert_char_to_pinyin(text, polyphone=polyphone)
294
  else:
295
  text_list = text
296
 
@@ -306,19 +296,19 @@ def get_inference_prompt(
306
  # # test vocoder resynthesis
307
  # ref_audio = gt_audio
308
  else:
309
- ref_text_len = len(prompt_text.encode("utf-8"))
310
- gen_text_len = len(gt_text.encode("utf-8"))
 
311
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
312
 
313
  # to mel spectrogram
314
  ref_mel = mel_spectrogram(ref_audio)
315
- ref_mel = ref_mel.squeeze(0)
316
 
317
  # deal with batch
318
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
319
- assert (
320
- min_tokens <= total_mel_len <= max_tokens
321
- ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
322
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
323
 
324
  utts[bucket_i].append(utt)
@@ -332,39 +322,28 @@ def get_inference_prompt(
332
 
333
  if batch_accum[bucket_i] >= infer_batch_size:
334
  # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
335
- prompts_all.append(
336
- (
337
- utts[bucket_i],
338
- ref_rms_list[bucket_i],
339
- padded_mel_batch(ref_mels[bucket_i]),
340
- ref_mel_lens[bucket_i],
341
- total_mel_lens[bucket_i],
342
- final_text_list[bucket_i],
343
- )
344
- )
345
  batch_accum[bucket_i] = 0
346
- (
347
- utts[bucket_i],
348
- ref_rms_list[bucket_i],
349
- ref_mels[bucket_i],
350
- ref_mel_lens[bucket_i],
351
- total_mel_lens[bucket_i],
352
- final_text_list[bucket_i],
353
- ) = [], [], [], [], [], []
354
 
355
  # add residual
356
  for bucket_i, bucket_frames in enumerate(batch_accum):
357
  if bucket_frames > 0:
358
- prompts_all.append(
359
- (
360
- utts[bucket_i],
361
- ref_rms_list[bucket_i],
362
- padded_mel_batch(ref_mels[bucket_i]),
363
- ref_mel_lens[bucket_i],
364
- total_mel_lens[bucket_i],
365
- final_text_list[bucket_i],
366
- )
367
- )
368
  # not only leave easy work for last workers
369
  random.seed(666)
370
  random.shuffle(prompts_all)
@@ -375,7 +354,6 @@ def get_inference_prompt(
375
  # get wav_res_ref_text of seed-tts test metalst
376
  # https://github.com/BytedanceSpeech/seed-tts-eval
377
 
378
-
379
  def get_seed_tts_test(metalst, gen_wav_dir, gpus):
380
  f = open(metalst)
381
  lines = f.readlines()
@@ -383,14 +361,14 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
383
 
384
  test_set_ = []
385
  for line in tqdm(lines):
386
- if len(line.strip().split("|")) == 5:
387
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
388
- elif len(line.strip().split("|")) == 4:
389
- utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
390
 
391
- if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
392
  continue
393
- gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
394
  if not os.path.isabs(prompt_wav):
395
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
396
 
@@ -399,69 +377,65 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
399
  num_jobs = len(gpus)
400
  if num_jobs == 1:
401
  return [(gpus[0], test_set_)]
402
-
403
  wav_per_job = len(test_set_) // num_jobs + 1
404
  test_set = []
405
  for i in range(num_jobs):
406
- test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
407
 
408
  return test_set
409
 
410
 
411
  # get librispeech test-clean cross sentence test
412
 
413
-
414
- def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
415
  f = open(metalst)
416
  lines = f.readlines()
417
  f.close()
418
 
419
  test_set_ = []
420
  for line in tqdm(lines):
421
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
422
 
423
  if eval_ground_truth:
424
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
425
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
426
  else:
427
- if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
428
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
429
- gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
430
 
431
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
432
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
433
 
434
  test_set_.append((gen_wav, ref_wav, gen_txt))
435
 
436
  num_jobs = len(gpus)
437
  if num_jobs == 1:
438
  return [(gpus[0], test_set_)]
439
-
440
  wav_per_job = len(test_set_) // num_jobs + 1
441
  test_set = []
442
  for i in range(num_jobs):
443
- test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
444
 
445
  return test_set
446
 
447
 
448
  # load asr model
449
 
450
-
451
- def load_asr_model(lang, ckpt_dir=""):
452
  if lang == "zh":
453
  from funasr import AutoModel
454
-
455
  model = AutoModel(
456
- model=os.path.join(ckpt_dir, "paraformer-zh"),
457
- # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
458
  # punc_model = os.path.join(ckpt_dir, "ct-punc"),
459
- # spk_model = os.path.join(ckpt_dir, "cam++"),
460
  disable_update=True,
461
- ) # following seed-tts setting
462
  elif lang == "en":
463
  from faster_whisper import WhisperModel
464
-
465
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
466
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
467
  return model
@@ -469,50 +443,44 @@ def load_asr_model(lang, ckpt_dir=""):
469
 
470
  # WER Evaluation, the way Seed-TTS does
471
 
472
-
473
  def run_asr_wer(args):
474
  rank, lang, test_set, ckpt_dir = args
475
 
476
  if lang == "zh":
477
  import zhconv
478
-
479
  torch.cuda.set_device(rank)
480
  elif lang == "en":
481
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
482
  else:
483
- raise NotImplementedError(
484
- "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
485
- )
486
-
487
- asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
488
 
 
 
489
  from zhon.hanzi import punctuation
490
-
491
  punctuation_all = punctuation + string.punctuation
492
  wers = []
493
 
494
  from jiwer import compute_measures
495
-
496
  for gen_wav, prompt_wav, truth in tqdm(test_set):
497
  if lang == "zh":
498
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
499
  hypo = res[0]["text"]
500
- hypo = zhconv.convert(hypo, "zh-cn")
501
  elif lang == "en":
502
  segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
503
- hypo = ""
504
  for segment in segments:
505
- hypo = hypo + " " + segment.text
506
 
507
  # raw_truth = truth
508
  # raw_hypo = hypo
509
 
510
  for x in punctuation_all:
511
- truth = truth.replace(x, "")
512
- hypo = hypo.replace(x, "")
513
 
514
- truth = truth.replace(" ", " ")
515
- hypo = hypo.replace(" ", " ")
516
 
517
  if lang == "zh":
518
  truth = " ".join([x for x in truth])
@@ -536,22 +504,22 @@ def run_asr_wer(args):
536
 
537
  # SIM Evaluation
538
 
539
-
540
  def run_sim(args):
541
  rank, test_set, ckpt_dir = args
542
  device = f"cuda:{rank}"
543
 
544
- model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
545
  state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
546
- model.load_state_dict(state_dict["model"], strict=False)
547
 
548
- use_gpu = True if torch.cuda.is_available() else False
549
  if use_gpu:
550
  model = model.cuda(device)
551
  model.eval()
552
 
553
  sim_list = []
554
  for wav1, wav2, truth in tqdm(test_set):
 
555
  wav1, sr1 = torchaudio.load(wav1)
556
  wav2, sr2 = torchaudio.load(wav2)
557
 
@@ -566,21 +534,20 @@ def run_sim(args):
566
  with torch.no_grad():
567
  emb1 = model(wav1)
568
  emb2 = model(wav2)
569
-
570
  sim = F.cosine_similarity(emb1, emb2)[0].item()
571
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
572
  sim_list.append(sim)
573
-
574
  return sim_list
575
 
576
 
577
  # filter func for dirty data with many repetitions
578
 
579
-
580
- def repetition_found(text, length=2, tolerance=10):
581
  pattern_count = defaultdict(int)
582
  for i in range(len(text) - length + 1):
583
- pattern = text[i : i + length]
584
  pattern_count[pattern] += 1
585
  for pattern, count in pattern_count.items():
586
  if count > tolerance:
@@ -590,31 +557,24 @@ def repetition_found(text, length=2, tolerance=10):
590
 
591
  # load model checkpoint for inference
592
 
593
-
594
- def load_checkpoint(model, ckpt_path, device, use_ema=True):
595
- if device == "cuda":
596
- model = model.half()
597
 
598
  ckpt_type = ckpt_path.split(".")[-1]
599
  if ckpt_type == "safetensors":
600
  from safetensors.torch import load_file
601
-
602
- checkpoint = load_file(ckpt_path)
603
  else:
604
- checkpoint = torch.load(ckpt_path, weights_only=True)
605
 
606
- if use_ema:
 
607
  if ckpt_type == "safetensors":
608
- checkpoint = {"ema_model_state_dict": checkpoint}
609
- checkpoint["model_state_dict"] = {
610
- k.replace("ema_model.", ""): v
611
- for k, v in checkpoint["ema_model_state_dict"].items()
612
- if k not in ["initted", "step"]
613
- }
614
- model.load_state_dict(checkpoint["model_state_dict"])
615
  else:
616
- if ckpt_type == "safetensors":
617
- checkpoint = {"model_state_dict": checkpoint}
618
- model.load_state_dict(checkpoint["model_state_dict"])
619
-
620
- return model.to(device)
 
1
  from __future__ import annotations
2
 
3
  import os
4
+ import re
5
  import math
6
  import random
7
  import string
 
9
  from collections import defaultdict
10
 
11
  import matplotlib
 
12
  matplotlib.use("Agg")
13
  import matplotlib.pylab as plt
14
 
 
17
  from torch.nn.utils.rnn import pad_sequence
18
  import torchaudio
19
 
20
+ import einx
21
+ from einops import rearrange, reduce
22
+
23
  import jieba
24
  from pypinyin import lazy_pinyin, Style
25
 
 
29
 
30
  # seed everything
31
 
32
+ def seed_everything(seed = 0):
 
33
  random.seed(seed)
34
+ os.environ['PYTHONHASHSEED'] = str(seed)
35
  torch.manual_seed(seed)
36
  torch.cuda.manual_seed(seed)
37
  torch.cuda.manual_seed_all(seed)
38
  torch.backends.cudnn.deterministic = True
39
  torch.backends.cudnn.benchmark = False
40
 
 
41
  # helpers
42
 
 
43
  def exists(v):
44
  return v is not None
45
 
 
46
  def default(v, d):
47
  return v if exists(v) else d
48
 
 
49
  # tensor helpers
50
 
51
+ def lens_to_mask(
52
+ t: int['b'],
53
+ length: int | None = None
54
+ ) -> bool['b n']:
55
 
 
56
  if not exists(length):
57
  length = t.amax()
58
 
59
+ seq = torch.arange(length, device = t.device)
60
+ return einx.less('n, b -> b n', seq, t)
 
 
 
 
 
 
 
 
61
 
62
+ def mask_from_start_end_indices(
63
+ seq_len: int['b'],
64
+ start: int['b'],
65
+ end: int['b']
66
+ ):
67
+ max_seq_len = seq_len.max().item()
68
+ seq = torch.arange(max_seq_len, device = start.device).long()
69
+ return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
70
 
71
+ def mask_from_frac_lengths(
72
+ seq_len: int['b'],
73
+ frac_lengths: float['b']
74
+ ):
75
  lengths = (frac_lengths * seq_len).long()
76
  max_start = seq_len - lengths
77
 
78
  rand = torch.rand_like(frac_lengths)
79
+ start = (max_start * rand).long().clamp(min = 0)
80
  end = start + lengths
81
 
82
  return mask_from_start_end_indices(seq_len, start, end)
83
 
84
+ def maybe_masked_mean(
85
+ t: float['b n d'],
86
+ mask: bool['b n'] = None
87
+ ) -> float['b d']:
88
 
 
89
  if not exists(mask):
90
+ return t.mean(dim = 1)
91
 
92
+ t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
93
+ num = reduce(t, 'b n d -> b d', 'sum')
94
+ den = reduce(mask.float(), 'b n -> b', 'sum')
95
 
96
+ return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
97
 
98
 
99
  # simple utf-8 tokenizer, since paper went character based
100
+ def list_str_to_tensor(
101
+ text: list[str],
102
+ padding_value = -1
103
+ ) -> int['b nt']:
104
+ list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
105
+ text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
106
  return text
107
 
 
108
  # char tokenizer, based on custom dataset's extracted .txt file
109
  def list_str_to_idx(
110
  text: list[str] | list[list[str]],
111
  vocab_char_map: dict[str, int], # {char: idx}
112
+ padding_value = -1
113
+ ) -> int['b nt']:
114
  list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
115
+ text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
116
  return text
117
 
118
 
119
  # Get tokenizer
120
 
 
121
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
122
+ '''
123
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
124
  - "char" for char-wise tokenizer, need .txt vocab_file
125
  - "byte" for utf-8 tokenizer
126
  - "custom" if you're directly passing in a path to the vocab.txt you want to use
127
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
128
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
129
+ - if use "byte", set to 256 (unicode byte range)
130
+ '''
131
  if tokenizer in ["pinyin", "char"]:
132
+ with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
133
  vocab_char_map = {}
134
  for i, char in enumerate(f):
135
  vocab_char_map[char[:-1]] = i
 
140
  vocab_char_map = None
141
  vocab_size = 256
142
  elif tokenizer == "custom":
143
+ with open (dataset_name, "r", encoding="utf-8") as f:
144
  vocab_char_map = {}
145
  for i, char in enumerate(f):
146
  vocab_char_map[char[:-1]] = i
 
151
 
152
  # convert char to pinyin
153
 
154
+ def convert_char_to_pinyin(text_list, polyphone = True):
 
155
  final_text_list = []
156
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
157
+ custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
 
 
158
  for text in text_list:
159
  char_list = []
160
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
161
  text = text.translate(custom_trans)
162
  for seg in jieba.cut(text):
163
+ seg_byte_len = len(bytes(seg, 'UTF-8'))
164
  if seg_byte_len == len(seg): # if pure alphabets and symbols
165
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
166
  char_list.append(" ")
 
189
  # save spectrogram
190
  def save_spectrogram(spectrogram, path):
191
  plt.figure(figsize=(12, 4))
192
+ plt.imshow(spectrogram, origin='lower', aspect='auto')
193
  plt.colorbar()
194
  plt.savefig(path)
195
  plt.close()
 
197
 
198
  # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
199
  def get_seedtts_testset_metainfo(metalst):
200
+ f = open(metalst); lines = f.readlines(); f.close()
 
 
201
  metainfo = []
202
  for line in lines:
203
+ if len(line.strip().split('|')) == 5:
204
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
205
+ elif len(line.strip().split('|')) == 4:
206
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
207
  gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
208
  if not os.path.isabs(prompt_wav):
209
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
 
213
 
214
  # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
215
  def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
216
+ f = open(metalst); lines = f.readlines(); f.close()
 
 
217
  metainfo = []
218
  for line in lines:
219
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
220
 
221
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
222
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
223
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
224
 
225
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
226
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
227
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
228
 
229
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
230
 
 
236
  max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
237
  padded_ref_mels = []
238
  for mel in ref_mels:
239
+ padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
240
  padded_ref_mels.append(padded_ref_mel)
241
  padded_ref_mels = torch.stack(padded_ref_mels)
242
+ padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
243
  return padded_ref_mels
244
 
245
 
246
  # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
247
 
 
248
  def get_inference_prompt(
249
+ metainfo,
250
+ speed = 1., tokenizer = "pinyin", polyphone = True,
251
+ target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
252
+ use_truth_duration = False,
253
+ infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
 
 
 
 
 
 
 
 
254
  ):
255
  prompts_all = []
256
 
 
258
  max_tokens = max_secs * target_sample_rate // hop_length
259
 
260
  batch_accum = [0] * num_buckets
261
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
262
+ ([[] for _ in range(num_buckets)] for _ in range(6))
 
263
 
264
+ mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
 
 
265
 
266
  for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
267
+
268
  # Audio
269
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
270
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
 
276
  ref_audio = resampler(ref_audio)
277
 
278
  # Text
279
+ if len(prompt_text[-1].encode('utf-8')) == 1:
280
  prompt_text = prompt_text + " "
281
  text = [prompt_text + gt_text]
282
  if tokenizer == "pinyin":
283
+ text_list = convert_char_to_pinyin(text, polyphone = polyphone)
284
  else:
285
  text_list = text
286
 
 
296
  # # test vocoder resynthesis
297
  # ref_audio = gt_audio
298
  else:
299
+ zh_pause_punc = r"。,、;:?!"
300
+ ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
301
+ gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
302
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
303
 
304
  # to mel spectrogram
305
  ref_mel = mel_spectrogram(ref_audio)
306
+ ref_mel = rearrange(ref_mel, '1 d n -> d n')
307
 
308
  # deal with batch
309
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
310
+ assert min_tokens <= total_mel_len <= max_tokens, \
311
+ f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
 
312
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
313
 
314
  utts[bucket_i].append(utt)
 
322
 
323
  if batch_accum[bucket_i] >= infer_batch_size:
324
  # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
325
+ prompts_all.append((
326
+ utts[bucket_i],
327
+ ref_rms_list[bucket_i],
328
+ padded_mel_batch(ref_mels[bucket_i]),
329
+ ref_mel_lens[bucket_i],
330
+ total_mel_lens[bucket_i],
331
+ final_text_list[bucket_i]
332
+ ))
 
 
333
  batch_accum[bucket_i] = 0
334
+ utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
 
 
 
 
 
 
 
335
 
336
  # add residual
337
  for bucket_i, bucket_frames in enumerate(batch_accum):
338
  if bucket_frames > 0:
339
+ prompts_all.append((
340
+ utts[bucket_i],
341
+ ref_rms_list[bucket_i],
342
+ padded_mel_batch(ref_mels[bucket_i]),
343
+ ref_mel_lens[bucket_i],
344
+ total_mel_lens[bucket_i],
345
+ final_text_list[bucket_i]
346
+ ))
 
 
347
  # not only leave easy work for last workers
348
  random.seed(666)
349
  random.shuffle(prompts_all)
 
354
  # get wav_res_ref_text of seed-tts test metalst
355
  # https://github.com/BytedanceSpeech/seed-tts-eval
356
 
 
357
  def get_seed_tts_test(metalst, gen_wav_dir, gpus):
358
  f = open(metalst)
359
  lines = f.readlines()
 
361
 
362
  test_set_ = []
363
  for line in tqdm(lines):
364
+ if len(line.strip().split('|')) == 5:
365
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
366
+ elif len(line.strip().split('|')) == 4:
367
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
368
 
369
+ if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
370
  continue
371
+ gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
372
  if not os.path.isabs(prompt_wav):
373
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
374
 
 
377
  num_jobs = len(gpus)
378
  if num_jobs == 1:
379
  return [(gpus[0], test_set_)]
380
+
381
  wav_per_job = len(test_set_) // num_jobs + 1
382
  test_set = []
383
  for i in range(num_jobs):
384
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
385
 
386
  return test_set
387
 
388
 
389
  # get librispeech test-clean cross sentence test
390
 
391
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
 
392
  f = open(metalst)
393
  lines = f.readlines()
394
  f.close()
395
 
396
  test_set_ = []
397
  for line in tqdm(lines):
398
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
399
 
400
  if eval_ground_truth:
401
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
402
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
403
  else:
404
+ if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
405
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
406
+ gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
407
 
408
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
409
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
410
 
411
  test_set_.append((gen_wav, ref_wav, gen_txt))
412
 
413
  num_jobs = len(gpus)
414
  if num_jobs == 1:
415
  return [(gpus[0], test_set_)]
416
+
417
  wav_per_job = len(test_set_) // num_jobs + 1
418
  test_set = []
419
  for i in range(num_jobs):
420
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
421
 
422
  return test_set
423
 
424
 
425
  # load asr model
426
 
427
+ def load_asr_model(lang, ckpt_dir = ""):
 
428
  if lang == "zh":
429
  from funasr import AutoModel
 
430
  model = AutoModel(
431
+ model = os.path.join(ckpt_dir, "paraformer-zh"),
432
+ # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
433
  # punc_model = os.path.join(ckpt_dir, "ct-punc"),
434
+ # spk_model = os.path.join(ckpt_dir, "cam++"),
435
  disable_update=True,
436
+ ) # following seed-tts setting
437
  elif lang == "en":
438
  from faster_whisper import WhisperModel
 
439
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
440
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
441
  return model
 
443
 
444
  # WER Evaluation, the way Seed-TTS does
445
 
 
446
  def run_asr_wer(args):
447
  rank, lang, test_set, ckpt_dir = args
448
 
449
  if lang == "zh":
450
  import zhconv
 
451
  torch.cuda.set_device(rank)
452
  elif lang == "en":
453
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
454
  else:
455
+ raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
 
 
 
 
456
 
457
+ asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
458
+
459
  from zhon.hanzi import punctuation
 
460
  punctuation_all = punctuation + string.punctuation
461
  wers = []
462
 
463
  from jiwer import compute_measures
 
464
  for gen_wav, prompt_wav, truth in tqdm(test_set):
465
  if lang == "zh":
466
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
467
  hypo = res[0]["text"]
468
+ hypo = zhconv.convert(hypo, 'zh-cn')
469
  elif lang == "en":
470
  segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
471
+ hypo = ''
472
  for segment in segments:
473
+ hypo = hypo + ' ' + segment.text
474
 
475
  # raw_truth = truth
476
  # raw_hypo = hypo
477
 
478
  for x in punctuation_all:
479
+ truth = truth.replace(x, '')
480
+ hypo = hypo.replace(x, '')
481
 
482
+ truth = truth.replace(' ', ' ')
483
+ hypo = hypo.replace(' ', ' ')
484
 
485
  if lang == "zh":
486
  truth = " ".join([x for x in truth])
 
504
 
505
  # SIM Evaluation
506
 
 
507
  def run_sim(args):
508
  rank, test_set, ckpt_dir = args
509
  device = f"cuda:{rank}"
510
 
511
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
512
  state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
513
+ model.load_state_dict(state_dict['model'], strict=False)
514
 
515
+ use_gpu=True if torch.cuda.is_available() else False
516
  if use_gpu:
517
  model = model.cuda(device)
518
  model.eval()
519
 
520
  sim_list = []
521
  for wav1, wav2, truth in tqdm(test_set):
522
+
523
  wav1, sr1 = torchaudio.load(wav1)
524
  wav2, sr2 = torchaudio.load(wav2)
525
 
 
534
  with torch.no_grad():
535
  emb1 = model(wav1)
536
  emb2 = model(wav2)
537
+
538
  sim = F.cosine_similarity(emb1, emb2)[0].item()
539
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
540
  sim_list.append(sim)
541
+
542
  return sim_list
543
 
544
 
545
  # filter func for dirty data with many repetitions
546
 
547
+ def repetition_found(text, length = 2, tolerance = 10):
 
548
  pattern_count = defaultdict(int)
549
  for i in range(len(text) - length + 1):
550
+ pattern = text[i:i + length]
551
  pattern_count[pattern] += 1
552
  for pattern, count in pattern_count.items():
553
  if count > tolerance:
 
557
 
558
  # load model checkpoint for inference
559
 
560
+ def load_checkpoint(model, ckpt_path, device, use_ema = True):
561
+ from ema_pytorch import EMA
 
 
562
 
563
  ckpt_type = ckpt_path.split(".")[-1]
564
  if ckpt_type == "safetensors":
565
  from safetensors.torch import load_file
566
+ checkpoint = load_file(ckpt_path, device=device)
 
567
  else:
568
+ checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
569
 
570
+ if use_ema == True:
571
+ ema_model = EMA(model, include_online_model = False).to(device)
572
  if ckpt_type == "safetensors":
573
+ ema_model.load_state_dict(checkpoint)
574
+ else:
575
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
576
+ ema_model.copy_params_from_ema_to_model()
 
 
 
577
  else:
578
+ model.load_state_dict(checkpoint['model_state_dict'])
579
+
580
+ return model
 
 
model/utils_infer.py DELETED
@@ -1,357 +0,0 @@
1
- # A unified script for inference process
2
- # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
3
-
4
- import re
5
- import tempfile
6
-
7
- import numpy as np
8
- import torch
9
- import torchaudio
10
- import tqdm
11
- from pydub import AudioSegment, silence
12
- from transformers import pipeline
13
- from vocos import Vocos
14
-
15
- from model import CFM
16
- from model.utils import (
17
- load_checkpoint,
18
- get_tokenizer,
19
- convert_char_to_pinyin,
20
- )
21
-
22
-
23
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
24
-
25
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
26
-
27
-
28
- # -----------------------------------------
29
-
30
- target_sample_rate = 24000
31
- n_mel_channels = 100
32
- hop_length = 256
33
- target_rms = 0.1
34
- cross_fade_duration = 0.15
35
- ode_method = "euler"
36
- nfe_step = 32 # 16, 32
37
- cfg_strength = 2.0
38
- sway_sampling_coef = -1.0
39
- speed = 1.0
40
- fix_duration = None
41
-
42
- # -----------------------------------------
43
-
44
-
45
- # chunk text into smaller pieces
46
-
47
-
48
- def chunk_text(text, max_chars=135):
49
- """
50
- Splits the input text into chunks, each with a maximum number of characters.
51
-
52
- Args:
53
- text (str): The text to be split.
54
- max_chars (int): The maximum number of characters per chunk.
55
-
56
- Returns:
57
- List[str]: A list of text chunks.
58
- """
59
- chunks = []
60
- current_chunk = ""
61
- # Split the text into sentences based on punctuation followed by whitespace
62
- sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
63
-
64
- for sentence in sentences:
65
- if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
66
- current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
67
- else:
68
- if current_chunk:
69
- chunks.append(current_chunk.strip())
70
- current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
71
-
72
- if current_chunk:
73
- chunks.append(current_chunk.strip())
74
-
75
- return chunks
76
-
77
-
78
- # load vocoder
79
- def load_vocoder(is_local=False, local_path="", device=device):
80
- if is_local:
81
- print(f"Load vocos from local path {local_path}")
82
- vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
83
- state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location=device)
84
- vocos.load_state_dict(state_dict)
85
- vocos.eval()
86
- else:
87
- print("Download Vocos from huggingface charactr/vocos-mel-24khz")
88
- vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
89
- return vocos
90
-
91
-
92
- # load asr pipeline
93
-
94
- asr_pipe = None
95
-
96
-
97
- def initialize_asr_pipeline(device=device):
98
- global asr_pipe
99
- asr_pipe = pipeline(
100
- "automatic-speech-recognition",
101
- model="openai/whisper-large-v3-turbo",
102
- torch_dtype=torch.float16,
103
- device=device,
104
- )
105
-
106
-
107
- # load model for inference
108
-
109
-
110
- def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
111
- if vocab_file == "":
112
- vocab_file = "Emilia_ZH_EN"
113
- tokenizer = "pinyin"
114
- else:
115
- tokenizer = "custom"
116
-
117
- print("\nvocab : ", vocab_file)
118
- print("tokenizer : ", tokenizer)
119
- print("model : ", ckpt_path, "\n")
120
-
121
- vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
122
- model = CFM(
123
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
124
- mel_spec_kwargs=dict(
125
- target_sample_rate=target_sample_rate,
126
- n_mel_channels=n_mel_channels,
127
- hop_length=hop_length,
128
- ),
129
- odeint_kwargs=dict(
130
- method=ode_method,
131
- ),
132
- vocab_char_map=vocab_char_map,
133
- ).to(device)
134
-
135
- model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
136
-
137
- return model
138
-
139
-
140
- # preprocess reference audio and text
141
-
142
-
143
- def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
144
- show_info("Converting audio...")
145
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
146
- aseg = AudioSegment.from_file(ref_audio_orig)
147
-
148
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
149
- non_silent_wave = AudioSegment.silent(duration=0)
150
- for non_silent_seg in non_silent_segs:
151
- non_silent_wave += non_silent_seg
152
- aseg = non_silent_wave
153
-
154
- audio_duration = len(aseg)
155
- if audio_duration > 15000:
156
- show_info("Audio is over 15s, clipping to only first 15s.")
157
- aseg = aseg[:15000]
158
- aseg.export(f.name, format="wav")
159
- ref_audio = f.name
160
-
161
- if not ref_text.strip():
162
- global asr_pipe
163
- if asr_pipe is None:
164
- initialize_asr_pipeline(device=device)
165
- show_info("No reference text provided, transcribing reference audio...")
166
- ref_text = asr_pipe(
167
- ref_audio,
168
- chunk_length_s=30,
169
- batch_size=128,
170
- generate_kwargs={"task": "transcribe"},
171
- return_timestamps=False,
172
- )["text"].strip()
173
- show_info("Finished transcription")
174
- else:
175
- show_info("Using custom reference text...")
176
-
177
- # Add the functionality to ensure it ends with ". "
178
- if not ref_text.endswith(". ") and not ref_text.endswith("。"):
179
- if ref_text.endswith("."):
180
- ref_text += " "
181
- else:
182
- ref_text += ". "
183
-
184
- return ref_audio, ref_text
185
-
186
-
187
- # infer process: chunk text -> infer batches [i.e. infer_batch_process()]
188
-
189
-
190
- def infer_process(
191
- ref_audio,
192
- ref_text,
193
- gen_text,
194
- model_obj,
195
- show_info=print,
196
- progress=tqdm,
197
- target_rms=target_rms,
198
- cross_fade_duration=cross_fade_duration,
199
- nfe_step=nfe_step,
200
- cfg_strength=cfg_strength,
201
- sway_sampling_coef=sway_sampling_coef,
202
- speed=speed,
203
- fix_duration=fix_duration,
204
- device=device,
205
- ):
206
- # Split the input text into batches
207
- audio, sr = torchaudio.load(ref_audio)
208
- max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
209
- gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
210
- for i, gen_text in enumerate(gen_text_batches):
211
- print(f"gen_text {i}", gen_text)
212
-
213
- show_info(f"Generating audio in {len(gen_text_batches)} batches...")
214
- return infer_batch_process(
215
- (audio, sr),
216
- ref_text,
217
- gen_text_batches,
218
- model_obj,
219
- progress=progress,
220
- target_rms=target_rms,
221
- cross_fade_duration=cross_fade_duration,
222
- nfe_step=nfe_step,
223
- cfg_strength=cfg_strength,
224
- sway_sampling_coef=sway_sampling_coef,
225
- speed=speed,
226
- fix_duration=fix_duration,
227
- device=device,
228
- )
229
-
230
-
231
- # infer batches
232
-
233
-
234
- def infer_batch_process(
235
- ref_audio,
236
- ref_text,
237
- gen_text_batches,
238
- model_obj,
239
- progress=tqdm,
240
- target_rms=0.1,
241
- cross_fade_duration=0.15,
242
- nfe_step=32,
243
- cfg_strength=2.0,
244
- sway_sampling_coef=-1,
245
- speed=1,
246
- fix_duration=None,
247
- device=None,
248
- ):
249
- audio, sr = ref_audio
250
- if audio.shape[0] > 1:
251
- audio = torch.mean(audio, dim=0, keepdim=True)
252
-
253
- rms = torch.sqrt(torch.mean(torch.square(audio)))
254
- if rms < target_rms:
255
- audio = audio * target_rms / rms
256
- if sr != target_sample_rate:
257
- resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
258
- audio = resampler(audio)
259
- audio = audio.to(device)
260
-
261
- generated_waves = []
262
- spectrograms = []
263
-
264
- if len(ref_text[-1].encode("utf-8")) == 1:
265
- ref_text = ref_text + " "
266
- for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
267
- # Prepare the text
268
- text_list = [ref_text + gen_text]
269
- final_text_list = convert_char_to_pinyin(text_list)
270
-
271
- ref_audio_len = audio.shape[-1] // hop_length
272
- if fix_duration is not None:
273
- duration = int(fix_duration * target_sample_rate / hop_length)
274
- else:
275
- # Calculate duration
276
- ref_text_len = len(ref_text.encode("utf-8"))
277
- gen_text_len = len(gen_text.encode("utf-8"))
278
- duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
279
-
280
- # inference
281
- with torch.inference_mode():
282
- generated, _ = model_obj.sample(
283
- cond=audio,
284
- text=final_text_list,
285
- duration=duration,
286
- steps=nfe_step,
287
- cfg_strength=cfg_strength,
288
- sway_sampling_coef=sway_sampling_coef,
289
- )
290
-
291
- generated = generated.to(torch.float32)
292
- generated = generated[:, ref_audio_len:, :]
293
- generated_mel_spec = generated.permute(0, 2, 1)
294
- generated_wave = vocos.decode(generated_mel_spec.cpu())
295
- if rms < target_rms:
296
- generated_wave = generated_wave * rms / target_rms
297
-
298
- # wav -> numpy
299
- generated_wave = generated_wave.squeeze().cpu().numpy()
300
-
301
- generated_waves.append(generated_wave)
302
- spectrograms.append(generated_mel_spec[0].cpu().numpy())
303
-
304
- # Combine all generated waves with cross-fading
305
- if cross_fade_duration <= 0:
306
- # Simply concatenate
307
- final_wave = np.concatenate(generated_waves)
308
- else:
309
- final_wave = generated_waves[0]
310
- for i in range(1, len(generated_waves)):
311
- prev_wave = final_wave
312
- next_wave = generated_waves[i]
313
-
314
- # Calculate cross-fade samples, ensuring it does not exceed wave lengths
315
- cross_fade_samples = int(cross_fade_duration * target_sample_rate)
316
- cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
317
-
318
- if cross_fade_samples <= 0:
319
- # No overlap possible, concatenate
320
- final_wave = np.concatenate([prev_wave, next_wave])
321
- continue
322
-
323
- # Overlapping parts
324
- prev_overlap = prev_wave[-cross_fade_samples:]
325
- next_overlap = next_wave[:cross_fade_samples]
326
-
327
- # Fade out and fade in
328
- fade_out = np.linspace(1, 0, cross_fade_samples)
329
- fade_in = np.linspace(0, 1, cross_fade_samples)
330
-
331
- # Cross-faded overlap
332
- cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
333
-
334
- # Combine
335
- new_wave = np.concatenate(
336
- [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
337
- )
338
-
339
- final_wave = new_wave
340
-
341
- # Create a combined spectrogram
342
- combined_spectrogram = np.concatenate(spectrograms, axis=1)
343
-
344
- return final_wave, target_sample_rate, combined_spectrogram
345
-
346
-
347
- # remove silence from generated wav
348
-
349
-
350
- def remove_silence_for_generated_wav(filename):
351
- aseg = AudioSegment.from_file(filename)
352
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
353
- non_silent_wave = AudioSegment.silent(duration=0)
354
- for non_silent_seg in non_silent_segs:
355
- non_silent_wave += non_silent_seg
356
- aseg = non_silent_wave
357
- aseg.export(filename, format="wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
pyproject.toml DELETED
@@ -1,62 +0,0 @@
1
- [build-system]
2
- requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
3
- build-backend = "setuptools.build_meta"
4
-
5
- [project]
6
- name = "f5-tts"
7
- version = "0.2.1"
8
- description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
- readme = "README.md"
10
- license = {text = "MIT License"}
11
- classifiers = [
12
- "License :: OSI Approved :: MIT License",
13
- "Operating System :: OS Independent",
14
- "Programming Language :: Python :: 3",
15
- ]
16
- dependencies = [
17
- "accelerate>=0.33.0",
18
- "bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
19
- "cached_path",
20
- "click",
21
- "datasets",
22
- "ema_pytorch>=0.5.2",
23
- "gradio>=3.45.2",
24
- "hydra-core>=1.3.0",
25
- "jieba",
26
- "librosa",
27
- "matplotlib",
28
- "numpy<=1.26.4",
29
- "pydub",
30
- "pypinyin",
31
- "safetensors",
32
- "soundfile",
33
- "tomli",
34
- "torch>=2.0.0",
35
- "torchaudio>=2.0.0",
36
- "torchdiffeq",
37
- "tqdm>=4.65.0",
38
- "transformers",
39
- "transformers_stream_generator",
40
- "vocos",
41
- "wandb",
42
- "x_transformers>=1.31.14",
43
- ]
44
-
45
- [project.optional-dependencies]
46
- eval = [
47
- "faster_whisper==0.10.1",
48
- "funasr",
49
- "jiwer",
50
- "modelscope",
51
- "zhconv",
52
- "zhon",
53
- ]
54
-
55
- [project.urls]
56
- Homepage = "https://github.com/SWivid/F5-TTS"
57
-
58
- [project.scripts]
59
- "f5-tts_infer-cli" = "f5_tts.infer.infer_cli:main"
60
- "f5-tts_infer-gradio" = "f5_tts.infer.infer_gradio:main"
61
- "f5-tts_finetune-cli" = "f5_tts.train.finetune_cli:main"
62
- "f5-tts_finetune-gradio" = "f5_tts.train.finetune_gradio:main"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,10 +1,9 @@
1
- torch
2
- torchaudio
3
  accelerate>=0.33.0
4
- bitsandbytes>0.37.0
5
  cached_path
6
  click
7
  datasets
 
 
8
  ema_pytorch>=0.5.2
9
  gradio
10
  jieba
@@ -22,5 +21,3 @@ transformers
22
  vocos
23
  wandb
24
  x_transformers>=1.31.14
25
- f5_tts @ git+https://huggingface.co/spaces/mrfakename/E2-F5-TTS
26
- detoxify @ git+https://github.com/unitaryai/detoxify
 
 
 
1
  accelerate>=0.33.0
 
2
  cached_path
3
  click
4
  datasets
5
+ einops>=0.8.0
6
+ einx>=0.3.0
7
  ema_pytorch>=0.5.2
8
  gradio
9
  jieba
 
21
  vocos
22
  wandb
23
  x_transformers>=1.31.14
 
 
ruff.toml DELETED
@@ -1,10 +0,0 @@
1
- line-length = 120
2
- target-version = "py310"
3
-
4
- [lint]
5
- # Only ignore variables with names starting with "_".
6
- dummy-variable-rgx = "^_.*$"
7
-
8
- [lint.isort]
9
- force-single-line = true
10
- lines-after-imports = 2
 
 
 
 
 
 
 
 
 
 
 
scripts/count_max_epoch.py CHANGED
@@ -1,7 +1,6 @@
1
- """ADAPTIVE BATCH SIZE"""
2
-
3
- print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
4
- print(" -> least padding, gather wavs with accumulated frames in a batch\n")
5
 
6
  # data
7
  total_hours = 95282
 
1
+ '''ADAPTIVE BATCH SIZE'''
2
+ print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
3
+ print(' -> least padding, gather wavs with accumulated frames in a batch\n')
 
4
 
5
  # data
6
  total_hours = 95282
scripts/count_params_gflops.py CHANGED
@@ -1,15 +1,13 @@
1
- import sys
2
- import os
3
-
4
  sys.path.append(os.getcwd())
5
 
6
- from model import M2_TTS, DiT
7
 
8
  import torch
9
  import thop
10
 
11
 
12
- """ ~155M """
13
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
14
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
15
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
@@ -17,11 +15,11 @@ import thop
17
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
18
  # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
19
 
20
- """ ~335M """
21
  # FLOPs: 622.1 G, Params: 333.2 M
22
  # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
23
  # FLOPs: 363.4 G, Params: 335.8 M
24
- transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
25
 
26
 
27
  model = M2_TTS(transformer=transformer)
@@ -32,8 +30,6 @@ duration = 20
32
  frame_length = int(duration * target_sample_rate / hop_length)
33
  text_length = 150
34
 
35
- flops, params = thop.profile(
36
- model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
37
- )
38
  print(f"FLOPs: {flops / 1e9} G")
39
  print(f"Params: {params / 1e6} M")
 
1
+ import sys, os
 
 
2
  sys.path.append(os.getcwd())
3
 
4
+ from model import M2_TTS, UNetT, DiT, MMDiT
5
 
6
  import torch
7
  import thop
8
 
9
 
10
+ ''' ~155M '''
11
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
12
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
13
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
 
15
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
16
  # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
17
 
18
+ ''' ~335M '''
19
  # FLOPs: 622.1 G, Params: 333.2 M
20
  # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
21
  # FLOPs: 363.4 G, Params: 335.8 M
22
+ transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
23
 
24
 
25
  model = M2_TTS(transformer=transformer)
 
30
  frame_length = int(duration * target_sample_rate / hop_length)
31
  text_length = 150
32
 
33
+ flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
 
 
34
  print(f"FLOPs: {flops / 1e9} G")
35
  print(f"Params: {params / 1e6} M")
scripts/eval_infer_batch.py CHANGED
@@ -1,6 +1,4 @@
1
- import sys
2
- import os
3
-
4
  sys.path.append(os.getcwd())
5
 
6
  import time
@@ -11,14 +9,15 @@ import argparse
11
  import torch
12
  import torchaudio
13
  from accelerate import Accelerator
 
14
  from vocos import Vocos
15
 
16
  from model import CFM, UNetT, DiT
17
  from model.utils import (
18
  load_checkpoint,
19
- get_tokenizer,
20
- get_seedtts_testset_metainfo,
21
- get_librispeech_test_clean_metainfo,
22
  get_inference_prompt,
23
  )
24
 
@@ -40,16 +39,16 @@ tokenizer = "pinyin"
40
 
41
  parser = argparse.ArgumentParser(description="batch inference")
42
 
43
- parser.add_argument("-s", "--seed", default=None, type=int)
44
- parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
45
- parser.add_argument("-n", "--expname", required=True)
46
- parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
47
 
48
- parser.add_argument("-nfe", "--nfestep", default=32, type=int)
49
- parser.add_argument("-o", "--odemethod", default="euler")
50
- parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
51
 
52
- parser.add_argument("-t", "--testset", required=True)
53
 
54
  args = parser.parse_args()
55
 
@@ -68,26 +67,26 @@ testset = args.testset
68
 
69
 
70
  infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
71
- cfg_strength = 2.0
72
- speed = 1.0
73
  use_truth_duration = False
74
  no_ref_audio = False
75
 
76
 
77
  if exp_name == "F5TTS_Base":
78
  model_cls = DiT
79
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
 
81
  elif exp_name == "E2TTS_Base":
82
  model_cls = UNetT
83
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
84
 
85
 
86
  if testset == "ls_pc_test_clean":
87
  metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
88
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
89
  metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
90
-
91
  elif testset == "seedtts_test_zh":
92
  metalst = "data/seedtts_testset/zh/meta.lst"
93
  metainfo = get_seedtts_testset_metainfo(metalst)
@@ -98,16 +97,13 @@ elif testset == "seedtts_test_en":
98
 
99
 
100
  # path to save genereted wavs
101
- if seed is None:
102
- seed = random.randint(-10000, 10000)
103
- output_dir = (
104
- f"results/{exp_name}_{ckpt_step}/{testset}/"
105
- f"seed{seed}_{ode_method}_nfe{nfe_step}"
106
- f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
107
- f"_cfg{cfg_strength}_speed{speed}"
108
- f"{'_gt-dur' if use_truth_duration else ''}"
109
  f"{'_no-ref-audio' if no_ref_audio else ''}"
110
- )
111
 
112
 
113
  # -------------------------------------------------#
@@ -115,15 +111,15 @@ output_dir = (
115
  use_ema = True
116
 
117
  prompts_all = get_inference_prompt(
118
- metainfo,
119
- speed=speed,
120
- tokenizer=tokenizer,
121
- target_sample_rate=target_sample_rate,
122
- n_mel_channels=n_mel_channels,
123
- hop_length=hop_length,
124
- target_rms=target_rms,
125
- use_truth_duration=use_truth_duration,
126
- infer_batch_size=infer_batch_size,
127
  )
128
 
129
  # Vocoder model
@@ -142,19 +138,23 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
142
 
143
  # Model
144
  model = CFM(
145
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
146
- mel_spec_kwargs=dict(
147
- target_sample_rate=target_sample_rate,
148
- n_mel_channels=n_mel_channels,
149
- hop_length=hop_length,
150
  ),
151
- odeint_kwargs=dict(
152
- method=ode_method,
 
 
153
  ),
154
- vocab_char_map=vocab_char_map,
 
 
 
155
  ).to(device)
156
 
157
- model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
158
 
159
  if not os.path.exists(output_dir) and accelerator.is_main_process:
160
  os.makedirs(output_dir)
@@ -164,29 +164,30 @@ accelerator.wait_for_everyone()
164
  start = time.time()
165
 
166
  with accelerator.split_between_processes(prompts_all) as prompts:
 
167
  for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
168
  utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
169
  ref_mels = ref_mels.to(device)
170
- ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
171
- total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
172
-
173
  # Inference
174
  with torch.inference_mode():
175
  generated, _ = model.sample(
176
- cond=ref_mels,
177
- text=final_text_list,
178
- duration=total_mel_lens,
179
- lens=ref_mel_lens,
180
- steps=nfe_step,
181
- cfg_strength=cfg_strength,
182
- sway_sampling_coef=sway_sampling_coef,
183
- no_ref_audio=no_ref_audio,
184
- seed=seed,
185
  )
186
  # Final result
187
  for i, gen in enumerate(generated):
188
- gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
189
- gen_mel_spec = gen.permute(0, 2, 1)
190
  generated_wave = vocos.decode(gen_mel_spec.cpu())
191
  if ref_rms_list[i] < target_rms:
192
  generated_wave = generated_wave * ref_rms_list[i] / target_rms
 
1
+ import sys, os
 
 
2
  sys.path.append(os.getcwd())
3
 
4
  import time
 
9
  import torch
10
  import torchaudio
11
  from accelerate import Accelerator
12
+ from einops import rearrange
13
  from vocos import Vocos
14
 
15
  from model import CFM, UNetT, DiT
16
  from model.utils import (
17
  load_checkpoint,
18
+ get_tokenizer,
19
+ get_seedtts_testset_metainfo,
20
+ get_librispeech_test_clean_metainfo,
21
  get_inference_prompt,
22
  )
23
 
 
39
 
40
  parser = argparse.ArgumentParser(description="batch inference")
41
 
42
+ parser.add_argument('-s', '--seed', default=None, type=int)
43
+ parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
44
+ parser.add_argument('-n', '--expname', required=True)
45
+ parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
46
 
47
+ parser.add_argument('-nfe', '--nfestep', default=32, type=int)
48
+ parser.add_argument('-o', '--odemethod', default="euler")
49
+ parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
50
 
51
+ parser.add_argument('-t', '--testset', required=True)
52
 
53
  args = parser.parse_args()
54
 
 
67
 
68
 
69
  infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
70
+ cfg_strength = 2.
71
+ speed = 1.
72
  use_truth_duration = False
73
  no_ref_audio = False
74
 
75
 
76
  if exp_name == "F5TTS_Base":
77
  model_cls = DiT
78
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
79
 
80
  elif exp_name == "E2TTS_Base":
81
  model_cls = UNetT
82
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
83
 
84
 
85
  if testset == "ls_pc_test_clean":
86
  metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
87
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
88
  metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
89
+
90
  elif testset == "seedtts_test_zh":
91
  metalst = "data/seedtts_testset/zh/meta.lst"
92
  metainfo = get_seedtts_testset_metainfo(metalst)
 
97
 
98
 
99
  # path to save genereted wavs
100
+ if seed is None: seed = random.randint(-10000, 10000)
101
+ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
102
+ f"seed{seed}_{ode_method}_nfe{nfe_step}" \
103
+ f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
104
+ f"_cfg{cfg_strength}_speed{speed}" \
105
+ f"{'_gt-dur' if use_truth_duration else ''}" \
 
 
106
  f"{'_no-ref-audio' if no_ref_audio else ''}"
 
107
 
108
 
109
  # -------------------------------------------------#
 
111
  use_ema = True
112
 
113
  prompts_all = get_inference_prompt(
114
+ metainfo,
115
+ speed = speed,
116
+ tokenizer = tokenizer,
117
+ target_sample_rate = target_sample_rate,
118
+ n_mel_channels = n_mel_channels,
119
+ hop_length = hop_length,
120
+ target_rms = target_rms,
121
+ use_truth_duration = use_truth_duration,
122
+ infer_batch_size = infer_batch_size,
123
  )
124
 
125
  # Vocoder model
 
138
 
139
  # Model
140
  model = CFM(
141
+ transformer = model_cls(
142
+ **model_cfg,
143
+ text_num_embeds = vocab_size,
144
+ mel_dim = n_mel_channels
 
145
  ),
146
+ mel_spec_kwargs = dict(
147
+ target_sample_rate = target_sample_rate,
148
+ n_mel_channels = n_mel_channels,
149
+ hop_length = hop_length,
150
  ),
151
+ odeint_kwargs = dict(
152
+ method = ode_method,
153
+ ),
154
+ vocab_char_map = vocab_char_map,
155
  ).to(device)
156
 
157
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
158
 
159
  if not os.path.exists(output_dir) and accelerator.is_main_process:
160
  os.makedirs(output_dir)
 
164
  start = time.time()
165
 
166
  with accelerator.split_between_processes(prompts_all) as prompts:
167
+
168
  for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
169
  utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
170
  ref_mels = ref_mels.to(device)
171
+ ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
172
+ total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
173
+
174
  # Inference
175
  with torch.inference_mode():
176
  generated, _ = model.sample(
177
+ cond = ref_mels,
178
+ text = final_text_list,
179
+ duration = total_mel_lens,
180
+ lens = ref_mel_lens,
181
+ steps = nfe_step,
182
+ cfg_strength = cfg_strength,
183
+ sway_sampling_coef = sway_sampling_coef,
184
+ no_ref_audio = no_ref_audio,
185
+ seed = seed,
186
  )
187
  # Final result
188
  for i, gen in enumerate(generated):
189
+ gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
190
+ gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
191
  generated_wave = vocos.decode(gen_mel_spec.cpu())
192
  if ref_rms_list[i] < target_rms:
193
  generated_wave = generated_wave * ref_rms_list[i] / target_rms
scripts/eval_librispeech_test_clean.py CHANGED
@@ -1,8 +1,6 @@
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
3
- import sys
4
- import os
5
-
6
  sys.path.append(os.getcwd())
7
 
8
  import multiprocessing as mp
@@ -21,7 +19,7 @@ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
21
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
22
  gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
23
 
24
- gpus = [0, 1, 2, 3, 4, 5, 6, 7]
25
  test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
26
 
27
  ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
@@ -48,7 +46,7 @@ if eval_task == "wer":
48
  for wers_ in results:
49
  wers.extend(wers_)
50
 
51
- wer = round(np.mean(wers) * 100, 3)
52
  print(f"\nTotal {len(wers)} samples")
53
  print(f"WER : {wer}%")
54
 
@@ -64,6 +62,6 @@ if eval_task == "sim":
64
  for sim_ in results:
65
  sim_list.extend(sim_)
66
 
67
- sim = round(sum(sim_list) / len(sim_list), 3)
68
  print(f"\nTotal {len(sim_list)} samples")
69
  print(f"SIM : {sim}")
 
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
3
+ import sys, os
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import multiprocessing as mp
 
19
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
20
  gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
21
 
22
+ gpus = [0,1,2,3,4,5,6,7]
23
  test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
24
 
25
  ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
 
46
  for wers_ in results:
47
  wers.extend(wers_)
48
 
49
+ wer = round(np.mean(wers)*100, 3)
50
  print(f"\nTotal {len(wers)} samples")
51
  print(f"WER : {wer}%")
52
 
 
62
  for sim_ in results:
63
  sim_list.extend(sim_)
64
 
65
+ sim = round(sum(sim_list)/len(sim_list), 3)
66
  print(f"\nTotal {len(sim_list)} samples")
67
  print(f"SIM : {sim}")
scripts/eval_seedtts_testset.py CHANGED
@@ -1,8 +1,6 @@
1
  # Evaluate with Seed-TTS testset
2
 
3
- import sys
4
- import os
5
-
6
  sys.path.append(os.getcwd())
7
 
8
  import multiprocessing as mp
@@ -16,21 +14,21 @@ from model.utils import (
16
 
17
 
18
  eval_task = "wer" # sim | wer
19
- lang = "zh" # zh | en
20
  metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
21
  # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
22
- gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
23
 
24
 
25
  # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
26
- # zh 1.254 seems a result of 4 workers wer_seed_tts
27
- gpus = [0, 1, 2, 3, 4, 5, 6, 7]
28
  test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
29
 
30
  local = False
31
  if local: # use local custom checkpoint dir
32
  if lang == "zh":
33
- asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
34
  elif lang == "en":
35
  asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
36
  else:
@@ -50,7 +48,7 @@ if eval_task == "wer":
50
  for wers_ in results:
51
  wers.extend(wers_)
52
 
53
- wer = round(np.mean(wers) * 100, 3)
54
  print(f"\nTotal {len(wers)} samples")
55
  print(f"WER : {wer}%")
56
 
@@ -66,6 +64,6 @@ if eval_task == "sim":
66
  for sim_ in results:
67
  sim_list.extend(sim_)
68
 
69
- sim = round(sum(sim_list) / len(sim_list), 3)
70
  print(f"\nTotal {len(sim_list)} samples")
71
  print(f"SIM : {sim}")
 
1
  # Evaluate with Seed-TTS testset
2
 
3
+ import sys, os
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import multiprocessing as mp
 
14
 
15
 
16
  eval_task = "wer" # sim | wer
17
+ lang = "zh" # zh | en
18
  metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
19
  # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
20
+ gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
21
 
22
 
23
  # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
24
+ # zh 1.254 seems a result of 4 workers wer_seed_tts
25
+ gpus = [0,1,2,3,4,5,6,7]
26
  test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
27
 
28
  local = False
29
  if local: # use local custom checkpoint dir
30
  if lang == "zh":
31
+ asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
32
  elif lang == "en":
33
  asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
34
  else:
 
48
  for wers_ in results:
49
  wers.extend(wers_)
50
 
51
+ wer = round(np.mean(wers)*100, 3)
52
  print(f"\nTotal {len(wers)} samples")
53
  print(f"WER : {wer}%")
54
 
 
64
  for sim_ in results:
65
  sim_list.extend(sim_)
66
 
67
+ sim = round(sum(sim_list)/len(sim_list), 3)
68
  print(f"\nTotal {len(sim_list)} samples")
69
  print(f"SIM : {sim}")
scripts/prepare_csv_wavs.py CHANGED
@@ -1,6 +1,4 @@
1
- import sys
2
- import os
3
-
4
  sys.path.append(os.getcwd())
5
 
6
  from pathlib import Path
@@ -19,11 +17,10 @@ from model.utils import (
19
 
20
  PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
21
 
22
-
23
  def is_csv_wavs_format(input_dataset_dir):
24
  fpath = Path(input_dataset_dir)
25
  metadata = fpath / "metadata.csv"
26
- wavs = fpath / "wavs"
27
  return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
28
 
29
 
@@ -49,24 +46,22 @@ def prepare_csv_wavs_dir(input_dir):
49
 
50
  return sub_result, durations, vocab_set
51
 
52
-
53
  def get_audio_duration(audio_path):
54
  audio, sample_rate = torchaudio.load(audio_path)
55
  num_channels = audio.shape[0]
56
  return audio.shape[1] / (sample_rate * num_channels)
57
 
58
-
59
  def read_audio_text_pairs(csv_file_path):
60
  audio_text_pairs = []
61
 
62
  parent = Path(csv_file_path).parent
63
- with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
64
- reader = csv.reader(csvfile, delimiter="|")
65
  next(reader) # Skip the header row
66
  for row in reader:
67
  if len(row) >= 2:
68
  audio_file = row[0].strip() # First column: audio file path
69
- text = row[1].strip() # Second column: text
70
  audio_file_path = parent / audio_file
71
  audio_text_pairs.append((audio_file_path.as_posix(), text))
72
 
@@ -83,12 +78,12 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
83
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
84
  raw_arrow_path = out_dir / "raw.arrow"
85
  with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
86
- for line in tqdm(result, desc="Writing to raw.arrow ..."):
87
  writer.write(line)
88
 
89
  # dup a json separately saving duration in case for DynamicBatchSampler ease
90
  dur_json_path = out_dir / "duration.json"
91
- with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
92
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
93
 
94
  # vocab map, i.e. tokenizer
@@ -125,14 +120,13 @@ def cli():
125
  # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
126
  # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
127
  parser = argparse.ArgumentParser(description="Prepare and save dataset.")
128
- parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
129
- parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
130
- parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
131
 
132
  args = parser.parse_args()
133
 
134
  prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
135
 
136
-
137
  if __name__ == "__main__":
138
  cli()
 
1
+ import sys, os
 
 
2
  sys.path.append(os.getcwd())
3
 
4
  from pathlib import Path
 
17
 
18
  PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
19
 
 
20
  def is_csv_wavs_format(input_dataset_dir):
21
  fpath = Path(input_dataset_dir)
22
  metadata = fpath / "metadata.csv"
23
+ wavs = fpath / 'wavs'
24
  return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
25
 
26
 
 
46
 
47
  return sub_result, durations, vocab_set
48
 
 
49
  def get_audio_duration(audio_path):
50
  audio, sample_rate = torchaudio.load(audio_path)
51
  num_channels = audio.shape[0]
52
  return audio.shape[1] / (sample_rate * num_channels)
53
 
 
54
  def read_audio_text_pairs(csv_file_path):
55
  audio_text_pairs = []
56
 
57
  parent = Path(csv_file_path).parent
58
+ with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
59
+ reader = csv.reader(csvfile, delimiter='|')
60
  next(reader) # Skip the header row
61
  for row in reader:
62
  if len(row) >= 2:
63
  audio_file = row[0].strip() # First column: audio file path
64
+ text = row[1].strip() # Second column: text
65
  audio_file_path = parent / audio_file
66
  audio_text_pairs.append((audio_file_path.as_posix(), text))
67
 
 
78
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
79
  raw_arrow_path = out_dir / "raw.arrow"
80
  with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
81
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
82
  writer.write(line)
83
 
84
  # dup a json separately saving duration in case for DynamicBatchSampler ease
85
  dur_json_path = out_dir / "duration.json"
86
+ with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
87
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
88
 
89
  # vocab map, i.e. tokenizer
 
120
  # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
121
  # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
122
  parser = argparse.ArgumentParser(description="Prepare and save dataset.")
123
+ parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
124
+ parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
125
+ parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
126
 
127
  args = parser.parse_args()
128
 
129
  prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
130
 
 
131
  if __name__ == "__main__":
132
  cli()
scripts/prepare_emilia.py CHANGED
@@ -4,9 +4,7 @@
4
  # generate audio text map for Emilia ZH & EN
5
  # evaluate for vocab size
6
 
7
- import sys
8
- import os
9
-
10
  sys.path.append(os.getcwd())
11
 
12
  from pathlib import Path
@@ -14,6 +12,7 @@ import json
14
  from tqdm import tqdm
15
  from concurrent.futures import ProcessPoolExecutor
16
 
 
17
  from datasets.arrow_writer import ArrowWriter
18
 
19
  from model.utils import (
@@ -22,89 +21,13 @@ from model.utils import (
22
  )
23
 
24
 
25
- out_zh = {
26
- "ZH_B00041_S06226",
27
- "ZH_B00042_S09204",
28
- "ZH_B00065_S09430",
29
- "ZH_B00065_S09431",
30
- "ZH_B00066_S09327",
31
- "ZH_B00066_S09328",
32
- }
33
  zh_filters = ["い", "て"]
34
  # seems synthesized audios, or heavily code-switched
35
  out_en = {
36
- "EN_B00013_S00913",
37
- "EN_B00042_S00120",
38
- "EN_B00055_S04111",
39
- "EN_B00061_S00693",
40
- "EN_B00061_S01494",
41
- "EN_B00061_S03375",
42
- "EN_B00059_S00092",
43
- "EN_B00111_S04300",
44
- "EN_B00100_S03759",
45
- "EN_B00087_S03811",
46
- "EN_B00059_S00950",
47
- "EN_B00089_S00946",
48
- "EN_B00078_S05127",
49
- "EN_B00070_S04089",
50
- "EN_B00074_S09659",
51
- "EN_B00061_S06983",
52
- "EN_B00061_S07060",
53
- "EN_B00059_S08397",
54
- "EN_B00082_S06192",
55
- "EN_B00091_S01238",
56
- "EN_B00089_S07349",
57
- "EN_B00070_S04343",
58
- "EN_B00061_S02400",
59
- "EN_B00076_S01262",
60
- "EN_B00068_S06467",
61
- "EN_B00076_S02943",
62
- "EN_B00064_S05954",
63
- "EN_B00061_S05386",
64
- "EN_B00066_S06544",
65
- "EN_B00076_S06944",
66
- "EN_B00072_S08620",
67
- "EN_B00076_S07135",
68
- "EN_B00076_S09127",
69
- "EN_B00065_S00497",
70
- "EN_B00059_S06227",
71
- "EN_B00063_S02859",
72
- "EN_B00075_S01547",
73
- "EN_B00061_S08286",
74
- "EN_B00079_S02901",
75
- "EN_B00092_S03643",
76
- "EN_B00096_S08653",
77
- "EN_B00063_S04297",
78
- "EN_B00063_S04614",
79
- "EN_B00079_S04698",
80
- "EN_B00104_S01666",
81
- "EN_B00061_S09504",
82
- "EN_B00061_S09694",
83
- "EN_B00065_S05444",
84
- "EN_B00063_S06860",
85
- "EN_B00065_S05725",
86
- "EN_B00069_S07628",
87
- "EN_B00083_S03875",
88
- "EN_B00071_S07665",
89
- "EN_B00071_S07665",
90
- "EN_B00062_S04187",
91
- "EN_B00065_S09873",
92
- "EN_B00065_S09922",
93
- "EN_B00084_S02463",
94
- "EN_B00067_S05066",
95
- "EN_B00106_S08060",
96
- "EN_B00073_S06399",
97
- "EN_B00073_S09236",
98
- "EN_B00087_S00432",
99
- "EN_B00085_S05618",
100
- "EN_B00064_S01262",
101
- "EN_B00072_S01739",
102
- "EN_B00059_S03913",
103
- "EN_B00069_S04036",
104
- "EN_B00067_S05623",
105
- "EN_B00060_S05389",
106
- "EN_B00060_S07290",
107
- "EN_B00062_S08995",
108
  }
109
  en_filters = ["ا", "い", "て"]
110
 
@@ -120,24 +43,18 @@ def deal_with_audio_dir(audio_dir):
120
  for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
121
  obj = json.loads(line)
122
  text = obj["text"]
123
- if obj["language"] == "zh":
124
  if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
125
  bad_case_zh += 1
126
  continue
127
  else:
128
- text = text.translate(
129
- str.maketrans({",": ",", "!": "", "?": "?"})
130
- ) # not "" cuz much code-switched
131
- if obj["language"] == "en":
132
- if (
133
- obj["wav"].split("/")[1] in out_en
134
- or any(f in text for f in en_filters)
135
- or repetition_found(text, length=4)
136
- ):
137
  bad_case_en += 1
138
  continue
139
  if tokenizer == "pinyin":
140
- text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
141
  duration = obj["duration"]
142
  sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
143
  durations.append(duration)
@@ -179,11 +96,11 @@ def main():
179
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
180
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
181
  with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
182
- for line in tqdm(result, desc="Writing to raw.arrow ..."):
183
  writer.write(line)
184
 
185
  # dup a json separately saving duration in case for DynamicBatchSampler ease
186
- with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f:
187
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
188
 
189
  # vocab map, i.e. tokenizer
@@ -197,13 +114,12 @@ def main():
197
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
198
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
199
  print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
200
- if "ZH" in langs:
201
- print(f"Bad zh transcription case: {total_bad_case_zh}")
202
- if "EN" in langs:
203
- print(f"Bad en transcription case: {total_bad_case_en}\n")
204
 
205
 
206
  if __name__ == "__main__":
 
207
  max_workers = 32
208
 
209
  tokenizer = "pinyin" # "pinyin" | "char"
 
4
  # generate audio text map for Emilia ZH & EN
5
  # evaluate for vocab size
6
 
7
+ import sys, os
 
 
8
  sys.path.append(os.getcwd())
9
 
10
  from pathlib import Path
 
12
  from tqdm import tqdm
13
  from concurrent.futures import ProcessPoolExecutor
14
 
15
+ from datasets import Dataset
16
  from datasets.arrow_writer import ArrowWriter
17
 
18
  from model.utils import (
 
21
  )
22
 
23
 
24
+ out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
 
 
 
 
 
 
 
25
  zh_filters = ["い", "て"]
26
  # seems synthesized audios, or heavily code-switched
27
  out_en = {
28
+ "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
29
+
30
+ "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  }
32
  en_filters = ["ا", "い", "て"]
33
 
 
43
  for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
44
  obj = json.loads(line)
45
  text = obj["text"]
46
+ if obj['language'] == "zh":
47
  if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
48
  bad_case_zh += 1
49
  continue
50
  else:
51
+ text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
52
+ if obj['language'] == "en":
53
+ if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
 
 
 
 
 
 
54
  bad_case_en += 1
55
  continue
56
  if tokenizer == "pinyin":
57
+ text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
58
  duration = obj["duration"]
59
  sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
60
  durations.append(duration)
 
96
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
97
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
98
  with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
99
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
100
  writer.write(line)
101
 
102
  # dup a json separately saving duration in case for DynamicBatchSampler ease
103
+ with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
104
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
105
 
106
  # vocab map, i.e. tokenizer
 
114
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
115
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
116
  print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
117
+ if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
118
+ if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
 
 
119
 
120
 
121
  if __name__ == "__main__":
122
+
123
  max_workers = 32
124
 
125
  tokenizer = "pinyin" # "pinyin" | "char"
scripts/prepare_wenetspeech4tts.py CHANGED
@@ -1,9 +1,7 @@
1
  # generate audio text map for WenetSpeech4TTS
2
  # evaluate for vocab size
3
 
4
- import sys
5
- import os
6
-
7
  sys.path.append(os.getcwd())
8
 
9
  import json
@@ -25,7 +23,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
25
 
26
  audio_paths, texts, durations = [], [], []
27
  for text_file in tqdm(text_files):
28
- with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
29
  first_line = file.readline().split("\t")
30
  audio_nm = first_line[0]
31
  audio_path = os.path.join(audio_dir, audio_nm + ".wav")
@@ -34,7 +32,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
34
  audio_paths.append(audio_path)
35
 
36
  if tokenizer == "pinyin":
37
- texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
38
  elif tokenizer == "char":
39
  texts.append(text)
40
 
@@ -48,7 +46,7 @@ def main():
48
  assert tokenizer in ["pinyin", "char"]
49
 
50
  audio_path_list, text_list, duration_list = [], [], []
51
-
52
  executor = ProcessPoolExecutor(max_workers=max_workers)
53
  futures = []
54
  for dataset_path in dataset_paths:
@@ -70,10 +68,8 @@ def main():
70
  dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
71
  dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
72
 
73
- with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f:
74
- json.dump(
75
- {"duration": duration_list}, f, ensure_ascii=False
76
- ) # dup a json separately saving duration in case for DynamicBatchSampler ease
77
 
78
  print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
79
  text_vocab_set = set()
@@ -89,21 +85,22 @@ def main():
89
  f.write(vocab + "\n")
90
  print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
91
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
92
-
93
 
94
  if __name__ == "__main__":
 
95
  max_workers = 32
96
 
97
  tokenizer = "pinyin" # "pinyin" | "char"
98
  polyphone = True
99
  dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
100
 
101
- dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
102
  dataset_paths = [
103
  "<SOME_PATH>/WenetSpeech4TTS/Basic",
104
  "<SOME_PATH>/WenetSpeech4TTS/Standard",
105
  "<SOME_PATH>/WenetSpeech4TTS/Premium",
106
- ][-dataset_choice:]
107
  print(f"\nChoose Dataset: {dataset_name}\n")
108
 
109
  main()
@@ -112,8 +109,8 @@ if __name__ == "__main__":
112
  # WenetSpeech4TTS Basic Standard Premium
113
  # samples count 3932473 1941220 407494
114
  # pinyin vocab size 1349 1348 1344 (no polyphone)
115
- # - - 1459 (polyphone)
116
  # char vocab size 5264 5219 5042
117
-
118
  # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
119
  # please be careful if using pretrained model, make sure the vocab.txt is same
 
1
  # generate audio text map for WenetSpeech4TTS
2
  # evaluate for vocab size
3
 
4
+ import sys, os
 
 
5
  sys.path.append(os.getcwd())
6
 
7
  import json
 
23
 
24
  audio_paths, texts, durations = [], [], []
25
  for text_file in tqdm(text_files):
26
+ with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
27
  first_line = file.readline().split("\t")
28
  audio_nm = first_line[0]
29
  audio_path = os.path.join(audio_dir, audio_nm + ".wav")
 
32
  audio_paths.append(audio_path)
33
 
34
  if tokenizer == "pinyin":
35
+ texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
36
  elif tokenizer == "char":
37
  texts.append(text)
38
 
 
46
  assert tokenizer in ["pinyin", "char"]
47
 
48
  audio_path_list, text_list, duration_list = [], [], []
49
+
50
  executor = ProcessPoolExecutor(max_workers=max_workers)
51
  futures = []
52
  for dataset_path in dataset_paths:
 
68
  dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
69
  dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
70
 
71
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
72
+ json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
 
 
73
 
74
  print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
75
  text_vocab_set = set()
 
85
  f.write(vocab + "\n")
86
  print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
87
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
88
+
89
 
90
  if __name__ == "__main__":
91
+
92
  max_workers = 32
93
 
94
  tokenizer = "pinyin" # "pinyin" | "char"
95
  polyphone = True
96
  dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
97
 
98
+ dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
99
  dataset_paths = [
100
  "<SOME_PATH>/WenetSpeech4TTS/Basic",
101
  "<SOME_PATH>/WenetSpeech4TTS/Standard",
102
  "<SOME_PATH>/WenetSpeech4TTS/Premium",
103
+ ][-dataset_choice:]
104
  print(f"\nChoose Dataset: {dataset_name}\n")
105
 
106
  main()
 
109
  # WenetSpeech4TTS Basic Standard Premium
110
  # samples count 3932473 1941220 407494
111
  # pinyin vocab size 1349 1348 1344 (no polyphone)
112
+ # - - 1459 (polyphone)
113
  # char vocab size 5264 5219 5042
114
+
115
  # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
116
  # please be careful if using pretrained model, make sure the vocab.txt is same
speech_edit.py CHANGED
@@ -3,13 +3,14 @@ import os
3
  import torch
4
  import torch.nn.functional as F
5
  import torchaudio
 
6
  from vocos import Vocos
7
 
8
- from model import CFM, UNetT, DiT
9
  from model.utils import (
10
  load_checkpoint,
11
- get_tokenizer,
12
- convert_char_to_pinyin,
13
  save_spectrogram,
14
  )
15
 
@@ -35,20 +36,20 @@ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
35
  ckpt_step = 1200000
36
 
37
  nfe_step = 32 # 16, 32
38
- cfg_strength = 2.0
39
- ode_method = "euler" # euler | midpoint
40
- sway_sampling_coef = -1.0
41
- speed = 1.0
42
 
43
  if exp_name == "F5TTS_Base":
44
  model_cls = DiT
45
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
46
 
47
  elif exp_name == "E2TTS_Base":
48
  model_cls = UNetT
49
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
50
 
51
- ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
52
  output_dir = "tests"
53
 
54
  # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
@@ -62,14 +63,8 @@ output_dir = "tests"
62
  audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
63
  origin_text = "Some call me nature, others call me mother nature."
64
  target_text = "Some call me optimist, others call me realist."
65
- parts_to_edit = [
66
- [1.42, 2.44],
67
- [4.04, 4.9],
68
- ] # stard_ends of "nature" & "mother nature", in seconds
69
- fix_duration = [
70
- 1.2,
71
- 1,
72
- ] # fix duration for "optimist" & "realist", in seconds
73
 
74
  # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
75
  # origin_text = "对,这就是我,万人敬仰的太乙真人。"
@@ -92,7 +87,7 @@ if local:
92
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
93
  state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
94
  vocos.load_state_dict(state_dict)
95
-
96
  vocos.eval()
97
  else:
98
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
@@ -102,19 +97,23 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
102
 
103
  # Model
104
  model = CFM(
105
- transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
106
- mel_spec_kwargs=dict(
107
- target_sample_rate=target_sample_rate,
108
- n_mel_channels=n_mel_channels,
109
- hop_length=hop_length,
 
 
 
 
110
  ),
111
- odeint_kwargs=dict(
112
- method=ode_method,
113
  ),
114
- vocab_char_map=vocab_char_map,
115
  ).to(device)
116
 
117
- model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
118
 
119
  # Audio
120
  audio, sr = torchaudio.load(audio_to_edit)
@@ -134,18 +133,14 @@ for part in parts_to_edit:
134
  part_dur = end - start if fix_duration is None else fix_duration.pop(0)
135
  part_dur = part_dur * target_sample_rate
136
  start = start * target_sample_rate
137
- audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
138
- edit_mask = torch.cat(
139
- (
140
- edit_mask,
141
- torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
142
- torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
143
- ),
144
- dim=-1,
145
- )
146
  offset = end * target_sample_rate
147
  # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
148
- edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
149
  audio = audio.to(device)
150
  edit_mask = edit_mask.to(device)
151
 
@@ -165,25 +160,24 @@ duration = audio.shape[-1] // hop_length
165
  # Inference
166
  with torch.inference_mode():
167
  generated, trajectory = model.sample(
168
- cond=audio,
169
- text=final_text_list,
170
- duration=duration,
171
- steps=nfe_step,
172
- cfg_strength=cfg_strength,
173
- sway_sampling_coef=sway_sampling_coef,
174
- seed=seed,
175
- edit_mask=edit_mask,
176
  )
177
  print(f"Generated mel: {generated.shape}")
178
 
179
  # Final result
180
- generated = generated.to(torch.float32)
181
  generated = generated[:, ref_audio_len:, :]
182
- generated_mel_spec = generated.permute(0, 2, 1)
183
  generated_wave = vocos.decode(generated_mel_spec.cpu())
184
  if rms < target_rms:
185
  generated_wave = generated_wave * rms / target_rms
186
 
187
- save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
188
- torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
189
  print(f"Generated wav: {generated_wave.shape}")
 
3
  import torch
4
  import torch.nn.functional as F
5
  import torchaudio
6
+ from einops import rearrange
7
  from vocos import Vocos
8
 
9
+ from model import CFM, UNetT, DiT, MMDiT
10
  from model.utils import (
11
  load_checkpoint,
12
+ get_tokenizer,
13
+ convert_char_to_pinyin,
14
  save_spectrogram,
15
  )
16
 
 
36
  ckpt_step = 1200000
37
 
38
  nfe_step = 32 # 16, 32
39
+ cfg_strength = 2.
40
+ ode_method = 'euler' # euler | midpoint
41
+ sway_sampling_coef = -1.
42
+ speed = 1.
43
 
44
  if exp_name == "F5TTS_Base":
45
  model_cls = DiT
46
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
47
 
48
  elif exp_name == "E2TTS_Base":
49
  model_cls = UNetT
50
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
51
 
52
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
53
  output_dir = "tests"
54
 
55
  # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
 
63
  audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
64
  origin_text = "Some call me nature, others call me mother nature."
65
  target_text = "Some call me optimist, others call me realist."
66
+ parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # stard_ends of "nature" & "mother nature", in seconds
67
+ fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds
 
 
 
 
 
 
68
 
69
  # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
70
  # origin_text = "对,这就是我,万人敬仰的太乙真人。"
 
87
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
88
  state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
89
  vocos.load_state_dict(state_dict)
90
+
91
  vocos.eval()
92
  else:
93
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
97
 
98
  # Model
99
  model = CFM(
100
+ transformer = model_cls(
101
+ **model_cfg,
102
+ text_num_embeds = vocab_size,
103
+ mel_dim = n_mel_channels
104
+ ),
105
+ mel_spec_kwargs = dict(
106
+ target_sample_rate = target_sample_rate,
107
+ n_mel_channels = n_mel_channels,
108
+ hop_length = hop_length,
109
  ),
110
+ odeint_kwargs = dict(
111
+ method = ode_method,
112
  ),
113
+ vocab_char_map = vocab_char_map,
114
  ).to(device)
115
 
116
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
117
 
118
  # Audio
119
  audio, sr = torchaudio.load(audio_to_edit)
 
133
  part_dur = end - start if fix_duration is None else fix_duration.pop(0)
134
  part_dur = part_dur * target_sample_rate
135
  start = start * target_sample_rate
136
+ audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
137
+ edit_mask = torch.cat((edit_mask,
138
+ torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
139
+ torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
140
+ ), dim = -1)
 
 
 
 
141
  offset = end * target_sample_rate
142
  # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
143
+ edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
144
  audio = audio.to(device)
145
  edit_mask = edit_mask.to(device)
146
 
 
160
  # Inference
161
  with torch.inference_mode():
162
  generated, trajectory = model.sample(
163
+ cond = audio,
164
+ text = final_text_list,
165
+ duration = duration,
166
+ steps = nfe_step,
167
+ cfg_strength = cfg_strength,
168
+ sway_sampling_coef = sway_sampling_coef,
169
+ seed = seed,
170
+ edit_mask = edit_mask,
171
  )
172
  print(f"Generated mel: {generated.shape}")
173
 
174
  # Final result
 
175
  generated = generated[:, ref_audio_len:, :]
176
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
177
  generated_wave = vocos.decode(generated_mel_spec.cpu())
178
  if rms < target_rms:
179
  generated_wave = generated_wave * rms / target_rms
180
 
181
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
182
+ torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
183
  print(f"Generated wav: {generated_wave.shape}")
src/f5_tts/api.py DELETED
@@ -1,166 +0,0 @@
1
- import random
2
- import sys
3
- from importlib.resources import files
4
-
5
- import soundfile as sf
6
- import tqdm
7
- from cached_path import cached_path
8
-
9
- from f5_tts.infer.utils_infer import (
10
- hop_length,
11
- infer_process,
12
- load_model,
13
- load_vocoder,
14
- preprocess_ref_audio_text,
15
- remove_silence_for_generated_wav,
16
- save_spectrogram,
17
- transcribe,
18
- target_sample_rate,
19
- )
20
- from f5_tts.model import DiT, UNetT
21
- from f5_tts.model.utils import seed_everything
22
-
23
-
24
- class F5TTS:
25
- def __init__(
26
- self,
27
- model_type="F5-TTS",
28
- ckpt_file="",
29
- vocab_file="",
30
- ode_method="euler",
31
- use_ema=True,
32
- vocoder_name="vocos",
33
- local_path=None,
34
- device=None,
35
- hf_cache_dir=None,
36
- ):
37
- # Initialize parameters
38
- self.final_wave = None
39
- self.target_sample_rate = target_sample_rate
40
- self.hop_length = hop_length
41
- self.seed = -1
42
- self.mel_spec_type = vocoder_name
43
-
44
- # Set device
45
- if device is not None:
46
- self.device = device
47
- else:
48
- import torch
49
-
50
- self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
51
-
52
- # Load models
53
- self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
54
- self.load_ema_model(
55
- model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
56
- )
57
-
58
- def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
59
- self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
60
-
61
- def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
62
- if model_type == "F5-TTS":
63
- if not ckpt_file:
64
- if mel_spec_type == "vocos":
65
- ckpt_file = str(
66
- cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
67
- )
68
- elif mel_spec_type == "bigvgan":
69
- ckpt_file = str(
70
- cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
71
- )
72
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
73
- model_cls = DiT
74
- elif model_type == "E2-TTS":
75
- if not ckpt_file:
76
- ckpt_file = str(
77
- cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
78
- )
79
- model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
80
- model_cls = UNetT
81
- else:
82
- raise ValueError(f"Unknown model type: {model_type}")
83
-
84
- self.ema_model = load_model(
85
- model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
86
- )
87
-
88
- def transcribe(self, ref_audio, language=None):
89
- return transcribe(ref_audio, language)
90
-
91
- def export_wav(self, wav, file_wave, remove_silence=False):
92
- sf.write(file_wave, wav, self.target_sample_rate)
93
-
94
- if remove_silence:
95
- remove_silence_for_generated_wav(file_wave)
96
-
97
- def export_spectrogram(self, spect, file_spect):
98
- save_spectrogram(spect, file_spect)
99
-
100
- def infer(
101
- self,
102
- ref_file,
103
- ref_text,
104
- gen_text,
105
- show_info=print,
106
- progress=tqdm,
107
- target_rms=0.1,
108
- cross_fade_duration=0.15,
109
- sway_sampling_coef=-1,
110
- cfg_strength=2,
111
- nfe_step=32,
112
- speed=1.0,
113
- fix_duration=None,
114
- remove_silence=False,
115
- file_wave=None,
116
- file_spect=None,
117
- seed=-1,
118
- ):
119
- if seed == -1:
120
- seed = random.randint(0, sys.maxsize)
121
- seed_everything(seed)
122
- self.seed = seed
123
-
124
- ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
125
-
126
- wav, sr, spect = infer_process(
127
- ref_file,
128
- ref_text,
129
- gen_text,
130
- self.ema_model,
131
- self.vocoder,
132
- self.mel_spec_type,
133
- show_info=show_info,
134
- progress=progress,
135
- target_rms=target_rms,
136
- cross_fade_duration=cross_fade_duration,
137
- nfe_step=nfe_step,
138
- cfg_strength=cfg_strength,
139
- sway_sampling_coef=sway_sampling_coef,
140
- speed=speed,
141
- fix_duration=fix_duration,
142
- device=self.device,
143
- )
144
-
145
- if file_wave is not None:
146
- self.export_wav(wav, file_wave, remove_silence)
147
-
148
- if file_spect is not None:
149
- self.export_spectrogram(spect, file_spect)
150
-
151
- return wav, sr, spect
152
-
153
-
154
- if __name__ == "__main__":
155
- f5tts = F5TTS()
156
-
157
- wav, sr, spect = f5tts.infer(
158
- ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
159
- ref_text="some call me nature, others call me mother nature.",
160
- gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
161
- file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
162
- file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
163
- seed=-1, # random seed = -1
164
- )
165
-
166
- print("seed :", f5tts.seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/f5_tts/configs/E2TTS_Base_train.yaml DELETED
@@ -1,44 +0,0 @@
1
- hydra:
2
- run:
3
- dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
-
5
- datasets:
6
- name: Emilia_ZH_EN # dataset name
7
- batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
8
- batch_size_type: frame # "frame" or "sample"
9
- max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
10
- num_workers: 16
11
-
12
- optim:
13
- epochs: 15
14
- learning_rate: 7.5e-5
15
- num_warmup_updates: 20000 # warmup steps
16
- grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
17
- max_grad_norm: 1.0 # gradient clipping
18
- bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
19
-
20
- model:
21
- name: E2TTS_Base
22
- tokenizer: pinyin
23
- tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
24
- arch:
25
- dim: 1024
26
- depth: 24
27
- heads: 16
28
- ff_mult: 4
29
- mel_spec:
30
- target_sample_rate: 24000
31
- n_mel_channels: 100
32
- hop_length: 256
33
- win_length: 1024
34
- n_fft: 1024
35
- mel_spec_type: vocos # 'vocos' or 'bigvgan'
36
- vocoder:
37
- is_local: False # use local offline ckpt or not
38
- local_path: None # local vocoder path
39
-
40
- ckpts:
41
- logger: wandb # wandb | tensorboard | None
42
- save_per_updates: 50000 # save checkpoint per steps
43
- last_per_steps: 5000 # save last checkpoint per steps
44
- save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}