ddd committed on
Commit
b93970c
1 Parent(s): aee7e5a

Add application file

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. LICENSE +21 -0
  2. README.md +83 -12
  3. checkpoints/.gitkeep +0 -0
  4. configs/config_base.yaml +42 -0
  5. configs/singing/base.yaml +42 -0
  6. configs/singing/fs2.yaml +3 -0
  7. configs/tts/base.yaml +95 -0
  8. configs/tts/base_zh.yaml +3 -0
  9. configs/tts/fs2.yaml +80 -0
  10. configs/tts/hifigan.yaml +21 -0
  11. configs/tts/lj/base_mel2wav.yaml +3 -0
  12. configs/tts/lj/base_text2mel.yaml +13 -0
  13. configs/tts/lj/fs2.yaml +3 -0
  14. configs/tts/lj/hifigan.yaml +3 -0
  15. configs/tts/lj/pwg.yaml +3 -0
  16. configs/tts/pwg.yaml +110 -0
  17. data/processed/ljspeech/dict.txt +77 -0
  18. data/processed/ljspeech/metadata_phone.csv +0 -0
  19. data/processed/ljspeech/mfa_dict.txt +0 -0
  20. data/processed/ljspeech/phone_set.json +1 -0
  21. data_gen/singing/binarize.py +398 -0
  22. data_gen/tts/base_binarizer.py +224 -0
  23. data_gen/tts/bin/binarize.py +20 -0
  24. data_gen/tts/binarizer_zh.py +59 -0
  25. data_gen/tts/data_gen_utils.py +347 -0
  26. data_gen/tts/txt_processors/base_text_processor.py +8 -0
  27. data_gen/tts/txt_processors/en.py +78 -0
  28. data_gen/tts/txt_processors/zh.py +41 -0
  29. data_gen/tts/txt_processors/zh_g2pM.py +72 -0
  30. docs/README-SVS-opencpop-cascade.md +111 -0
  31. docs/README-SVS-opencpop-e2e.md +106 -0
  32. docs/README-SVS-popcs.md +63 -0
  33. docs/README-SVS.md +44 -0
  34. docs/README-TTS.md +63 -0
  35. docs/README-zh.md +212 -0
  36. inference/svs/base_svs_infer.py +265 -0
  37. inference/svs/ds_cascade.py +54 -0
  38. inference/svs/ds_e2e.py +67 -0
  39. inference/svs/gradio/gradio_settings.yaml +19 -0
  40. inference/svs/gradio/infer.py +91 -0
  41. inference/svs/opencpop/cpop_pinyin2ph.txt +418 -0
  42. inference/svs/opencpop/map.py +8 -0
  43. modules/__init__.py +0 -0
  44. modules/commons/common_layers.py +668 -0
  45. modules/commons/espnet_positional_embedding.py +113 -0
  46. modules/commons/ssim.py +391 -0
  47. modules/diffsinger_midi/fs2.py +118 -0
  48. modules/fastspeech/fs2.py +255 -0
  49. modules/fastspeech/pe.py +149 -0
  50. modules/fastspeech/tts_modules.py +357 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 Jinglin Liu
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,83 @@
- ---
- title: DiffSinger
- emoji: 👁
- colorFrom: green
- colorTo: gray
- sdk: gradio
- sdk_version: 3.1.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
+ [![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
+ [![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
+ | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
+
+ This repository is the official PyTorch implementation of our AAAI-2022 [paper](https://arxiv.org/abs/2105.02446), in which we propose DiffSinger (for singing voice synthesis) and DiffSpeech (for text-to-speech).
+
+ <table style="width:100%">
+   <tr>
+     <th>DiffSinger/DiffSpeech at training</th>
+     <th>DiffSinger/DiffSpeech at inference</th>
+   </tr>
+   <tr>
+     <td><img src="resources/model_a.png" alt="Training" height="300"></td>
+     <td><img src="resources/model_b.png" alt="Inference" height="300"></td>
+   </tr>
+ </table>
+
+ :tada: :tada: :tada: **Updates**:
+ - Mar. 2, 2022: [MIDI-new-version](docs/README-SVS-opencpop-e2e.md): a substantial improvement :sparkles:
+ - Mar. 1, 2022: [NeuralSVB](https://github.com/MoonInTheRiver/NeuralSVB), for singing voice beautifying, has been released :sparkles: :sparkles: :sparkles:
+ - Feb. 13, 2022: [NATSpeech](https://github.com/NATSpeech/NATSpeech), the improved code framework, which contains the implementations of DiffSpeech and our NeurIPS-2021 work [PortaSpeech](https://openreview.net/forum?id=xmJsuh8xlq), has been released :sparkles: :sparkles: :sparkles:
+ - Jan. 29, 2022: support [MIDI-old-version](docs/README-SVS-opencpop-cascade.md) SVS. :construction: :pick: :hammer_and_wrench:
+ - Jan. 13, 2022: support SVS; release the PopCS dataset.
+ - Dec. 19, 2021: support TTS. [HuggingFace🤗 Demo](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
+
+ :rocket: **News**:
+ - Feb. 24, 2022: Our new work, NeuralSVB, was accepted by ACL-2022 [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2202.13277). [Demo Page](https://neuralsvb.github.io).
+ - Dec. 01, 2021: DiffSinger was accepted by AAAI-2022.
+ - Sep. 29, 2021: Our recent work `PortaSpeech: Portable and High-Quality Generative Text-to-Speech` was accepted by NeurIPS-2021 [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2109.15166).
+ - May 06, 2021: We submitted DiffSinger to arXiv [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446).
+
+ ## Environments
+ ```sh
+ conda create -n your_env_name python=3.8
+ source activate your_env_name
+ pip install -r requirements_2080.txt   # GPU 2080Ti, CUDA 10.2
+ # or
+ pip install -r requirements_3090.txt   # GPU 3090, CUDA 11.4
+ ```
+
+ ## Documents
+ - [Run DiffSpeech (TTS version)](docs/README-TTS.md).
+ - [Run DiffSinger (SVS version)](docs/README-SVS.md).
+
+ ## Tensorboard
+ ```sh
+ tensorboard --logdir_spec exp_name
+ ```
+ <table style="width:100%">
+   <tr>
+     <td><img src="resources/tfb.png" alt="Tensorboard" height="250"></td>
+   </tr>
+ </table>
+
+ ## Audio Demos
+ Old audio samples can be found on our [demo page](https://diffsinger.github.io/). Audio samples generated by this repository are listed here:
+
+ ### TTS audio samples
+ Speech samples (test set of LJSpeech) can be found in [resources/demos_1213](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/demos_1213).
+
+ ### SVS audio samples
+ Singing samples (test set of PopCS) can be found in [resources/demos_0112](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/demos_0112).
+
+ ## Citation
+     @article{liu2021diffsinger,
+       title={Diffsinger: Singing voice synthesis via shallow diffusion mechanism},
+       author={Liu, Jinglin and Li, Chengxi and Ren, Yi and Chen, Feiyang and Liu, Peng and Zhao, Zhou},
+       journal={arXiv preprint arXiv:2105.02446},
+       volume={2},
+       year={2021}}
+
+ ## Acknowledgements
+ Our codes are based on the following repos:
+ * [denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch)
+ * [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
+ * [ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN)
+ * [HifiGAN](https://github.com/jik876/hifi-gan)
+ * [espnet](https://github.com/espnet/espnet)
+ * [DiffWave](https://github.com/lmnt-com/diffwave)
+
+ Also thanks to [Keon Lee](https://github.com/keonlee9420/DiffSinger) for a fast implementation of our work.
checkpoints/.gitkeep ADDED
File without changes
configs/config_base.yaml ADDED
@@ -0,0 +1,42 @@
+ # task
+ binary_data_dir: ''
+ work_dir: '' # experiment directory.
+ infer: false # infer
+ seed: 1234
+ debug: false
+ save_codes:
+   - configs
+   - modules
+   - tasks
+   - utils
+   - usr
+
+ #############
+ # dataset
+ #############
+ ds_workers: 1
+ test_num: 100
+ valid_num: 100
+ endless_ds: false
+ sort_by_len: true
+
+ #########
+ # train and eval
+ #########
+ load_ckpt: ''
+ save_ckpt: true
+ save_best: false
+ num_ckpt_keep: 3
+ clip_grad_norm: 0
+ accumulate_grad_batches: 1
+ log_interval: 100
+ num_sanity_val_steps: 5 # steps of validation at the beginning
+ check_val_every_n_epoch: 10
+ val_check_interval: 2000
+ max_epochs: 1000
+ max_updates: 160000
+ max_tokens: 31250
+ max_sentences: 100000
+ max_eval_tokens: -1
+ max_eval_sentences: -1
+ test_input_dir: ''
configs/singing/base.yaml ADDED
@@ -0,0 +1,42 @@
+ base_config:
+   - configs/tts/base.yaml
+   - configs/tts/base_zh.yaml
+
+
+ datasets: []
+ test_prefixes: []
+ test_num: 0
+ valid_num: 0
+
+ pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
+ binarizer_cls: data_gen.singing.binarize.SingingBinarizer
+ pre_align_args:
+   use_tone: false # for ZH
+   forced_align: mfa
+   use_sox: true
+ hop_size: 128 # Hop size.
+ fft_size: 512 # FFT size.
+ win_size: 512 # Window size.
+ max_frames: 8000
+ fmin: 50 # Minimum freq in mel basis calculation.
+ fmax: 11025 # Maximum frequency in mel basis calculation.
+ pitch_type: frame
+
+ hidden_size: 256
+ mel_loss: "ssim:0.5|l1:0.5"
+ lambda_f0: 0.0
+ lambda_uv: 0.0
+ lambda_energy: 0.0
+ lambda_ph_dur: 0.0
+ lambda_sent_dur: 0.0
+ lambda_word_dur: 0.0
+ predictor_grad: 0.0
+ use_spk_embed: true
+ use_spk_id: false
+
+ max_tokens: 20000
+ max_updates: 400000
+ num_spk: 100
+ save_f0: true
+ use_gt_dur: true
+ use_gt_f0: true
configs/singing/fs2.yaml ADDED
@@ -0,0 +1,3 @@
+ base_config:
+   - configs/tts/fs2.yaml
+   - configs/singing/base.yaml
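
The `base_config` lists above are resolved by the repository's hparams loader (`utils/hparams.py`, not included in this commit). A minimal sketch of the assumed layering semantics — base files are merged first, with the child file's own keys taking precedence — is given below; the real loader may differ in precedence order and in how nested mappings are merged.

```python
# Minimal sketch of base_config resolution (assumption: the actual loader in
# utils/hparams.py may handle precedence, deep merging and CLI overrides differently).
import yaml

def load_config(path, visited=()):
    assert path not in visited, f"circular base_config chain at {path}"
    cfg = yaml.safe_load(open(path)) or {}
    bases = cfg.pop('base_config', [])
    bases = [bases] if isinstance(bases, str) else bases
    merged = {}
    for base in bases:          # later bases override earlier ones (assumption)
        merged.update(load_config(base, visited + (path,)))
    merged.update(cfg)          # the child file overrides all of its bases
    return merged               # shallow merge: nested dicts are replaced wholesale

# e.g. hparams = load_config('configs/singing/fs2.yaml')
```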
configs/tts/base.yaml ADDED
@@ -0,0 +1,95 @@
+ # task
+ base_config: configs/config_base.yaml
+ task_cls: ''
+ #############
+ # dataset
+ #############
+ raw_data_dir: ''
+ processed_data_dir: ''
+ binary_data_dir: ''
+ dict_dir: ''
+ pre_align_cls: ''
+ binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
+ pre_align_args:
+   use_tone: true # for ZH
+   forced_align: mfa
+   use_sox: false
+   txt_processor: en
+   allow_no_txt: false
+   denoise: false
+ binarization_args:
+   shuffle: false
+   with_txt: true
+   with_wav: false
+   with_align: true
+   with_spk_embed: true
+   with_f0: true
+   with_f0cwt: true
+
+ loud_norm: false
+ endless_ds: true
+ reset_phone_dict: true
+
+ test_num: 100
+ valid_num: 100
+ max_frames: 1550
+ max_input_tokens: 1550
+ audio_num_mel_bins: 80
+ audio_sample_rate: 22050
+ hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
+ win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
+ fmin: 80 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ fmax: 7600 # To be increased/reduced depending on data.
+ fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter
+ min_level_db: -100
+ num_spk: 1
+ mel_vmin: -6
+ mel_vmax: 1.5
+ ds_workers: 4
+
+ #########
+ # model
+ #########
+ dropout: 0.1
+ enc_layers: 4
+ dec_layers: 4
+ hidden_size: 384
+ num_heads: 2
+ prenet_dropout: 0.5
+ prenet_hidden_size: 256
+ stop_token_weight: 5.0
+ enc_ffn_kernel_size: 9
+ dec_ffn_kernel_size: 9
+ ffn_act: gelu
+ ffn_padding: 'SAME'
+
+
+ ###########
+ # optimization
+ ###########
+ lr: 2.0
+ warmup_updates: 8000
+ optimizer_adam_beta1: 0.9
+ optimizer_adam_beta2: 0.98
+ weight_decay: 0
+ clip_grad_norm: 1
+
+
+ ###########
+ # train and eval
+ ###########
+ max_tokens: 30000
+ max_sentences: 100000
+ max_eval_sentences: 1
+ max_eval_tokens: 60000
+ train_set_name: 'train'
+ valid_set_name: 'valid'
+ test_set_name: 'test'
+ vocoder: pwg
+ vocoder_ckpt: ''
+ profile_infer: false
+ out_wav_norm: false
+ save_gt: false
+ save_f0: false
+ gen_dir_name: ''
+ use_denoise: false
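
For quick reference, the frame timings actually implied by the values above (the inline comments quote 275 and 1100 samples, i.e. 12.5 ms and 50 ms, while the configured values are the nearby powers of two) work out as follows.

```python
# Frame timings implied by audio_sample_rate / hop_size / win_size above.
sr, hop_size, win_size = 22050, 256, 1024
print(f"hop:    {1000 * hop_size / sr:.2f} ms")      # ~11.61 ms per mel frame
print(f"window: {1000 * win_size / sr:.2f} ms")      # ~46.44 ms analysis window
print(f"frames per second: {sr / hop_size:.1f}")     # ~86.1
```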
configs/tts/base_zh.yaml ADDED
@@ -0,0 +1,3 @@
+ pre_align_args:
+   txt_processor: zh_g2pM
+ binarizer_cls: data_gen.tts.binarizer_zh.ZhBinarizer
configs/tts/fs2.yaml ADDED
@@ -0,0 +1,80 @@
+ base_config: configs/tts/base.yaml
+ task_cls: tasks.tts.fs2.FastSpeech2Task
+
+ # model
+ hidden_size: 256
+ dropout: 0.1
+ encoder_type: fft # fft|tacotron|tacotron2|conformer
+ encoder_K: 8 # for tacotron encoder
+ decoder_type: fft # fft|rnn|conv|conformer
+ use_pos_embed: true
+
+ # duration
+ predictor_hidden: -1
+ predictor_kernel: 5
+ predictor_layers: 2
+ dur_predictor_kernel: 3
+ dur_predictor_layers: 2
+ predictor_dropout: 0.5
+
+ # pitch and energy
+ use_pitch_embed: true
+ pitch_type: ph # frame|ph|cwt
+ use_uv: true
+ cwt_hidden_size: 128
+ cwt_layers: 2
+ cwt_loss: l1
+ cwt_add_f0_loss: false
+ cwt_std_scale: 0.8
+
+ pitch_ar: false
+ #pitch_embed_type: 0q
+ pitch_loss: 'l1' # l1|l2|ssim
+ pitch_norm: log
+ use_energy_embed: false
+
+ # reference encoder and speaker embedding
+ use_spk_id: false
+ use_split_spk_id: false
+ use_spk_embed: false
+ use_var_enc: false
+ lambda_commit: 0.25
+ ref_norm_layer: bn
+ pitch_enc_hidden_stride_kernel:
+   - 0,2,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
+   - 0,2,5
+   - 0,2,5
+ dur_enc_hidden_stride_kernel:
+   - 0,2,3 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
+   - 0,2,3
+   - 0,1,3
+
+
+ # mel
+ mel_loss: l1:0.5|ssim:0.5 # l1|l2|gdl|ssim or l1:0.5|ssim:0.5
+
+ # loss lambda
+ lambda_f0: 1.0
+ lambda_uv: 1.0
+ lambda_energy: 0.1
+ lambda_ph_dur: 1.0
+ lambda_sent_dur: 1.0
+ lambda_word_dur: 1.0
+ predictor_grad: 0.1
+
+ # train and eval
+ pretrain_fs_ckpt: ''
+ warmup_updates: 2000
+ max_tokens: 32000
+ max_sentences: 100000
+ max_eval_sentences: 1
+ max_updates: 120000
+ num_valid_plots: 5
+ num_test_samples: 0
+ test_ids: []
+ use_gt_dur: false
+ use_gt_f0: false
+
+ # exp
+ dur_loss: mse # huber|mol
+ norm_type: gn
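
The `mel_loss` value above is a composite-loss spec of the form `name:weight|name:weight`. The task code that consumes it is not part of this commit; a plausible minimal parser for that syntax (weights defaulting to 1.0 when omitted, as in `mel_loss: l1`) is sketched here.

```python
def parse_loss_spec(spec: str) -> dict:
    """Parse specs like 'l1:0.5|ssim:0.5' or 'l1' into {loss_name: weight}."""
    losses = {}
    for term in spec.split('|'):
        name, _, weight = term.partition(':')
        losses[name] = float(weight) if weight else 1.0
    return losses

assert parse_loss_spec('l1:0.5|ssim:0.5') == {'l1': 0.5, 'ssim': 0.5}
assert parse_loss_spec('l1') == {'l1': 1.0}
```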
configs/tts/hifigan.yaml ADDED
@@ -0,0 +1,21 @@
+ base_config: configs/tts/pwg.yaml
+ task_cls: tasks.vocoder.hifigan.HifiGanTask
+ resblock: "1"
+ adam_b1: 0.8
+ adam_b2: 0.99
+ upsample_rates: [ 8,8,2,2 ]
+ upsample_kernel_sizes: [ 16,16,4,4 ]
+ upsample_initial_channel: 128
+ resblock_kernel_sizes: [ 3,7,11 ]
+ resblock_dilation_sizes: [ [ 1,3,5 ], [ 1,3,5 ], [ 1,3,5 ] ]
+
+ lambda_mel: 45.0
+
+ max_samples: 8192
+ max_sentences: 16
+
+ generator_params:
+   lr: 0.0002 # Generator's learning rate.
+   aux_context_window: 0 # Context window size for auxiliary feature.
+ discriminator_optimizer_params:
+   lr: 0.0002 # Discriminator's learning rate.
configs/tts/lj/base_mel2wav.yaml ADDED
@@ -0,0 +1,3 @@
+ raw_data_dir: 'data/raw/LJSpeech-1.1'
+ processed_data_dir: 'data/processed/ljspeech'
+ binary_data_dir: 'data/binary/ljspeech_wav'
configs/tts/lj/base_text2mel.yaml ADDED
@@ -0,0 +1,13 @@
+ raw_data_dir: 'data/raw/LJSpeech-1.1'
+ processed_data_dir: 'data/processed/ljspeech'
+ binary_data_dir: 'data/binary/ljspeech'
+ pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign
+
+ pitch_type: cwt
+ mel_loss: l1
+ num_test_samples: 20
+ test_ids: [ 68, 70, 74, 87, 110, 172, 190, 215, 231, 294,
+             316, 324, 402, 422, 485, 500, 505, 508, 509, 519 ]
+ use_energy_embed: false
+ test_num: 523
+ valid_num: 348
configs/tts/lj/fs2.yaml ADDED
@@ -0,0 +1,3 @@
+ base_config:
+   - configs/tts/fs2.yaml
+   - configs/tts/lj/base_text2mel.yaml
configs/tts/lj/hifigan.yaml ADDED
@@ -0,0 +1,3 @@
+ base_config:
+   - configs/tts/hifigan.yaml
+   - configs/tts/lj/base_mel2wav.yaml
configs/tts/lj/pwg.yaml ADDED
@@ -0,0 +1,3 @@
+ base_config:
+   - configs/tts/pwg.yaml
+   - configs/tts/lj/base_mel2wav.yaml
configs/tts/pwg.yaml ADDED
@@ -0,0 +1,110 @@
+ base_config: configs/tts/base.yaml
+ task_cls: tasks.vocoder.pwg.PwgTask
+
+ binarization_args:
+   with_wav: true
+   with_spk_embed: false
+   with_align: false
+ test_input_dir: ''
+
+ ###########
+ # train and eval
+ ###########
+ max_samples: 25600
+ max_sentences: 5
+ max_eval_sentences: 1
+ max_updates: 1000000
+ val_check_interval: 2000
+
+
+ ###########################################################
+ #                FEATURE EXTRACTION SETTING               #
+ ###########################################################
+ sampling_rate: 22050 # Sampling rate.
+ fft_size: 1024 # FFT size.
+ hop_size: 256 # Hop size.
+ win_length: null # Window length. If set to null, it will be the same as fft_size.
+ window: "hann" # Window function.
+ num_mels: 80 # Number of mel basis.
+ fmin: 80 # Minimum freq in mel basis calculation.
+ fmax: 7600 # Maximum frequency in mel basis calculation.
+ format: "hdf5" # Feature file format. "npy" or "hdf5" is supported.
+
+ ###########################################################
+ #         GENERATOR NETWORK ARCHITECTURE SETTING          #
+ ###########################################################
+ generator_params:
+   in_channels: 1 # Number of input channels.
+   out_channels: 1 # Number of output channels.
+   kernel_size: 3 # Kernel size of dilated convolution.
+   layers: 30 # Number of residual block layers.
+   stacks: 3 # Number of stacks, i.e., dilation cycles.
+   residual_channels: 64 # Number of channels in residual conv.
+   gate_channels: 128 # Number of channels in gated conv.
+   skip_channels: 64 # Number of channels in skip conv.
+   aux_channels: 80 # Number of channels for auxiliary feature conv. Must be the same as num_mels.
+   aux_context_window: 2 # Context window size for auxiliary feature. If set to 2, previous 2 and future 2 frames will be considered.
+   dropout: 0.0 # Dropout rate. 0.0 means no dropout applied.
+   use_weight_norm: true # Whether to use weight norm. If set to true, it will be applied to all of the conv layers.
+   upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture.
+   upsample_params: # Upsampling network parameters.
+     upsample_scales: [4, 4, 4, 4] # Upsampling scales. Product of these must be the same as hop size.
+   use_pitch_embed: false
+
+ ###########################################################
+ #       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
+ ###########################################################
+ discriminator_params:
+   in_channels: 1 # Number of input channels.
+   out_channels: 1 # Number of output channels.
+   kernel_size: 3 # Kernel size of conv layers.
+   layers: 10 # Number of conv layers.
+   conv_channels: 64 # Number of conv channels.
+   bias: true # Whether to use bias parameter in conv.
+   use_weight_norm: true # Whether to use weight norm. If set to true, it will be applied to all of the conv layers.
+   nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv.
+   nonlinear_activation_params: # Nonlinear function parameters.
+     negative_slope: 0.2 # Alpha in LeakyReLU.
+
+ ###########################################################
+ #                   STFT LOSS SETTING                     #
+ ###########################################################
+ stft_loss_params:
+   fft_sizes: [1024, 2048, 512] # List of FFT sizes for STFT-based loss.
+   hop_sizes: [120, 240, 50] # List of hop sizes for STFT-based loss.
+   win_lengths: [600, 1200, 240] # List of window lengths for STFT-based loss.
+   window: "hann_window" # Window function for STFT-based loss.
+ use_mel_loss: false
+
+ ###########################################################
+ #               ADVERSARIAL LOSS SETTING                  #
+ ###########################################################
+ lambda_adv: 4.0 # Loss balancing coefficient.
+
+ ###########################################################
+ #             OPTIMIZER & SCHEDULER SETTING               #
+ ###########################################################
+ generator_optimizer_params:
+   lr: 0.0001 # Generator's learning rate.
+   eps: 1.0e-6 # Generator's epsilon.
+   weight_decay: 0.0 # Generator's weight decay coefficient.
+ generator_scheduler_params:
+   step_size: 200000 # Generator's scheduler step size.
+   gamma: 0.5 # Generator's scheduler gamma. At each step size, lr will be multiplied by this parameter.
+ generator_grad_norm: 10 # Generator's gradient norm.
+ discriminator_optimizer_params:
+   lr: 0.00005 # Discriminator's learning rate.
+   eps: 1.0e-6 # Discriminator's epsilon.
+   weight_decay: 0.0 # Discriminator's weight decay coefficient.
+ discriminator_scheduler_params:
+   step_size: 200000 # Discriminator's scheduler step size.
+   gamma: 0.5 # Discriminator's scheduler gamma. At each step size, lr will be multiplied by this parameter.
+ discriminator_grad_norm: 1 # Discriminator's gradient norm.
+ disc_start_steps: 40000 # Number of steps to start to train discriminator.
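
As the comment on `upsample_scales` states, the vocoder's total upsampling factor must equal the hop size, so that each mel frame expands to exactly one hop of waveform samples. A quick sanity check for this config and for the HiFi-GAN override above:

```python
import math

hop_size = 256                              # hop_size configured above
pwg_upsample_scales = [4, 4, 4, 4]          # configs/tts/pwg.yaml
hifigan_upsample_rates = [8, 8, 2, 2]       # configs/tts/hifigan.yaml
assert math.prod(pwg_upsample_scales) == hop_size
assert math.prod(hifigan_upsample_rates) == hop_size
```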
data/processed/ljspeech/dict.txt ADDED
@@ -0,0 +1,77 @@
+ ! !
+ , ,
+ . .
+ ; ;
+ <BOS> <BOS>
+ <EOS> <EOS>
+ ? ?
+ AA0 AA0
+ AA1 AA1
+ AA2 AA2
+ AE0 AE0
+ AE1 AE1
+ AE2 AE2
+ AH0 AH0
+ AH1 AH1
+ AH2 AH2
+ AO0 AO0
+ AO1 AO1
+ AO2 AO2
+ AW0 AW0
+ AW1 AW1
+ AW2 AW2
+ AY0 AY0
+ AY1 AY1
+ AY2 AY2
+ B B
+ CH CH
+ D D
+ DH DH
+ EH0 EH0
+ EH1 EH1
+ EH2 EH2
+ ER0 ER0
+ ER1 ER1
+ ER2 ER2
+ EY0 EY0
+ EY1 EY1
+ EY2 EY2
+ F F
+ G G
+ HH HH
+ IH0 IH0
+ IH1 IH1
+ IH2 IH2
+ IY0 IY0
+ IY1 IY1
+ IY2 IY2
+ JH JH
+ K K
+ L L
+ M M
+ N N
+ NG NG
+ OW0 OW0
+ OW1 OW1
+ OW2 OW2
+ OY0 OY0
+ OY1 OY1
+ OY2 OY2
+ P P
+ R R
+ S S
+ SH SH
+ T T
+ TH TH
+ UH0 UH0
+ UH1 UH1
+ UH2 UH2
+ UW0 UW0
+ UW1 UW1
+ UW2 UW2
+ V V
+ W W
+ Y Y
+ Z Z
+ ZH ZH
+ | |
data/processed/ljspeech/metadata_phone.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/processed/ljspeech/mfa_dict.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/processed/ljspeech/phone_set.json ADDED
@@ -0,0 +1 @@
+ ["!", ",", ".", ";", "<BOS>", "<EOS>", "?", "AA0", "AA1", "AA2", "AE0", "AE1", "AE2", "AH0", "AH1", "AH2", "AO0", "AO1", "AO2", "AW0", "AW1", "AW2", "AY0", "AY1", "AY2", "B", "CH", "D", "DH", "EH0", "EH1", "EH2", "ER0", "ER1", "ER2", "EY0", "EY1", "EY2", "F", "G", "HH", "IH0", "IH1", "IH2", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG", "OW0", "OW1", "OW2", "OY0", "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH0", "UH1", "UH2", "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH", "|"]
data_gen/singing/binarize.py ADDED
@@ -0,0 +1,398 @@
+ import os
+ import random
+ from copy import deepcopy
+ import pandas as pd
+ import logging
+ from tqdm import tqdm
+ import json
+ import glob
+ import re
+ from resemblyzer import VoiceEncoder
+ import traceback
+ import numpy as np
+ import pretty_midi
+ import librosa
+ from scipy.interpolate import interp1d
+ import torch
+ from textgrid import TextGrid
+
+ from utils.hparams import hparams
+ from data_gen.tts.data_gen_utils import build_phone_encoder, get_pitch
+ from utils.pitch_utils import f0_to_coarse
+ from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
+ from data_gen.tts.binarizer_zh import ZhBinarizer
+ from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU
+ from vocoders.base_vocoder import VOCODERS
+
+
+ class SingingBinarizer(BaseBinarizer):
+     def __init__(self, processed_data_dir=None):
+         if processed_data_dir is None:
+             processed_data_dir = hparams['processed_data_dir']
+         self.processed_data_dirs = processed_data_dir.split(",")
+         self.binarization_args = hparams['binarization_args']
+         self.pre_align_args = hparams['pre_align_args']
+         self.item2txt = {}
+         self.item2ph = {}
+         self.item2wavfn = {}
+         self.item2f0fn = {}
+         self.item2tgfn = {}
+         self.item2spk = {}
+
+     def split_train_test_set(self, item_names):
+         item_names = deepcopy(item_names)
+         test_item_names = [x for x in item_names if any([ts in x for ts in hparams['test_prefixes']])]
+         train_item_names = [x for x in item_names if x not in set(test_item_names)]
+         logging.info("train {}".format(len(train_item_names)))
+         logging.info("test {}".format(len(test_item_names)))
+         return train_item_names, test_item_names
+
+     def load_meta_data(self):
+         for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
+             wav_suffix = '_wf0.wav'
+             txt_suffix = '.txt'
+             ph_suffix = '_ph.txt'
+             tg_suffix = '.TextGrid'
+             all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}')
+
+             for piece_path in all_wav_pieces:
+                 item_name = raw_item_name = piece_path[len(processed_data_dir)+1:].replace('/', '-')[:-len(wav_suffix)]
+                 if len(self.processed_data_dirs) > 1:
+                     item_name = f'ds{ds_id}_{item_name}'
+                 self.item2txt[item_name] = open(f'{piece_path.replace(wav_suffix, txt_suffix)}').readline()
+                 self.item2ph[item_name] = open(f'{piece_path.replace(wav_suffix, ph_suffix)}').readline()
+                 self.item2wavfn[item_name] = piece_path
+
+                 self.item2spk[item_name] = re.split('-|#', piece_path.split('/')[-2])[0]
+                 if len(self.processed_data_dirs) > 1:
+                     self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
+                 self.item2tgfn[item_name] = piece_path.replace(wav_suffix, tg_suffix)
+         print('spkers: ', set(self.item2spk.values()))
+         self.item_names = sorted(list(self.item2txt.keys()))
+         if self.binarization_args['shuffle']:
+             random.seed(1234)
+             random.shuffle(self.item_names)
+         self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
+
+     @property
+     def train_item_names(self):
+         return self._train_item_names
+
+     @property
+     def valid_item_names(self):
+         return self._test_item_names
+
+     @property
+     def test_item_names(self):
+         return self._test_item_names
+
+     def process(self):
+         self.load_meta_data()
+         os.makedirs(hparams['binary_data_dir'], exist_ok=True)
+         self.spk_map = self.build_spk_map()
+         print("| spk_map: ", self.spk_map)
+         spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
+         json.dump(self.spk_map, open(spk_map_fn, 'w'))
+
+         self.phone_encoder = self._phone_encoder()
+         self.process_data('valid')
+         self.process_data('test')
+         self.process_data('train')
+
+     def _phone_encoder(self):
+         ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
+         ph_set = []
+         if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
+             for ph_sent in self.item2ph.values():
+                 ph_set += ph_sent.split(' ')
+             ph_set = sorted(set(ph_set))
+             json.dump(ph_set, open(ph_set_fn, 'w'))
+             print("| Build phone set: ", ph_set)
+         else:
+             ph_set = json.load(open(ph_set_fn, 'r'))
+             print("| Load phone set: ", ph_set)
+         return build_phone_encoder(hparams['binary_data_dir'])
+
+     # @staticmethod
+     # def get_pitch(wav_fn, spec, res):
+     #     wav_suffix = '_wf0.wav'
+     #     f0_suffix = '_f0.npy'
+     #     f0fn = wav_fn.replace(wav_suffix, f0_suffix)
+     #     pitch_info = np.load(f0fn)
+     #     f0 = [x[1] for x in pitch_info]
+     #     spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)]
+     #     f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)]
+     #     f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)]
+     #     # f0_x_coor = np.arange(0, 1, 1 / len(f0))
+     #     # f0_x_coor[-1] = 1
+     #     # f0 = interp1d(f0_x_coor, f0, 'nearest')(spec_x_coor)[:len(spec)]
+     #     if sum(f0) == 0:
+     #         raise BinarizationError("Empty f0")
+     #     assert len(f0) == len(spec), (len(f0), len(spec))
+     #     pitch_coarse = f0_to_coarse(f0)
+     #
+     #     # vis f0
+     #     # import matplotlib.pyplot as plt
+     #     # from textgrid import TextGrid
+     #     # tg_fn = wav_fn.replace(wav_suffix, '.TextGrid')
+     #     # fig = plt.figure(figsize=(12, 6))
+     #     # plt.pcolor(spec.T, vmin=-5, vmax=0)
+     #     # ax = plt.gca()
+     #     # ax2 = ax.twinx()
+     #     # ax2.plot(f0, color='red')
+     #     # ax2.set_ylim(0, 800)
+     #     # itvs = TextGrid.fromFile(tg_fn)[0]
+     #     # for itv in itvs:
+     #     #     x = itv.maxTime * hparams['audio_sample_rate'] / hparams['hop_size']
+     #     #     plt.vlines(x=x, ymin=0, ymax=80, color='black')
+     #     #     plt.text(x=x, y=20, s=itv.mark, color='black')
+     #     # plt.savefig('tmp/20211229_singing_plots_test.png')
+     #
+     #     res['f0'] = f0
+     #     res['pitch'] = pitch_coarse
+
+     @classmethod
+     def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
+         if hparams['vocoder'] in VOCODERS:
+             wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
+         else:
+             wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
+         res = {
+             'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
+             'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
+         }
+         try:
+             if binarization_args['with_f0']:
+                 # cls.get_pitch(wav_fn, mel, res)
+                 cls.get_pitch(wav, mel, res)
+             if binarization_args['with_txt']:
+                 try:
+                     # print(ph)
+                     phone_encoded = res['phone'] = encoder.encode(ph)
+                 except:
+                     traceback.print_exc()
+                     raise BinarizationError(f"Empty phoneme")
+                 if binarization_args['with_align']:
+                     cls.get_align(tg_fn, ph, mel, phone_encoded, res)
+         except BinarizationError as e:
+             print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
+             return None
+         return res
+
+
+ class MidiSingingBinarizer(SingingBinarizer):
+     item2midi = {}
+     item2midi_dur = {}
+     item2is_slur = {}
+     item2ph_durs = {}
+     item2wdb = {}
+
+     def load_meta_data(self):
+         for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
+             meta_midi = json.load(open(os.path.join(processed_data_dir, 'meta.json')))  # [list of dict]
+
+             for song_item in meta_midi:
+                 item_name = raw_item_name = song_item['item_name']
+                 if len(self.processed_data_dirs) > 1:
+                     item_name = f'ds{ds_id}_{item_name}'
+                 self.item2wavfn[item_name] = song_item['wav_fn']
+                 self.item2txt[item_name] = song_item['txt']
+
+                 self.item2ph[item_name] = ' '.join(song_item['phs'])
+                 self.item2wdb[item_name] = [1 if x in ALL_YUNMU + ['AP', 'SP', '<SIL>'] else 0 for x in song_item['phs']]
+                 self.item2ph_durs[item_name] = song_item['ph_dur']
+
+                 self.item2midi[item_name] = song_item['notes']
+                 self.item2midi_dur[item_name] = song_item['notes_dur']
+                 self.item2is_slur[item_name] = song_item['is_slur']
+                 self.item2spk[item_name] = 'pop-cs'
+                 if len(self.processed_data_dirs) > 1:
+                     self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
+
+         print('spkers: ', set(self.item2spk.values()))
+         self.item_names = sorted(list(self.item2txt.keys()))
+         if self.binarization_args['shuffle']:
+             random.seed(1234)
+             random.shuffle(self.item_names)
+         self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
+
+     @staticmethod
+     def get_pitch(wav_fn, wav, spec, ph, res):
+         wav_suffix = '.wav'
+         # midi_suffix = '.mid'
+         wav_dir = 'wavs'
+         f0_dir = 'f0'
+
+         item_name = '/'.join(os.path.splitext(wav_fn)[0].split('/')[-2:]).replace('_wf0', '')
+         res['pitch_midi'] = np.asarray(MidiSingingBinarizer.item2midi[item_name])
+         res['midi_dur'] = np.asarray(MidiSingingBinarizer.item2midi_dur[item_name])
+         res['is_slur'] = np.asarray(MidiSingingBinarizer.item2is_slur[item_name])
+         res['word_boundary'] = np.asarray(MidiSingingBinarizer.item2wdb[item_name])
+         assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
+             res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
+
+         # gt f0.
+         gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
+         if sum(gt_f0) == 0:
+             raise BinarizationError("Empty **gt** f0")
+         res['f0'] = gt_f0
+         res['pitch'] = gt_pitch_coarse
+
+     @staticmethod
+     def get_align(ph_durs, mel, phone_encoded, res, hop_size=hparams['hop_size'], audio_sample_rate=hparams['audio_sample_rate']):
+         mel2ph = np.zeros([mel.shape[0]], int)
+         startTime = 0
+
+         for i_ph in range(len(ph_durs)):
+             start_frame = int(startTime * audio_sample_rate / hop_size + 0.5)
+             end_frame = int((startTime + ph_durs[i_ph]) * audio_sample_rate / hop_size + 0.5)
+             mel2ph[start_frame:end_frame] = i_ph + 1
+             startTime = startTime + ph_durs[i_ph]
+
+         # print('ph durs: ', ph_durs)
+         # print('mel2ph: ', mel2ph, len(mel2ph))
+         res['mel2ph'] = mel2ph
+         # res['dur'] = None
+
+     @classmethod
+     def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
+         if hparams['vocoder'] in VOCODERS:
+             wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
+         else:
+             wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
+         res = {
+             'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
+             'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
+         }
+         try:
+             if binarization_args['with_f0']:
+                 cls.get_pitch(wav_fn, wav, mel, ph, res)
+             if binarization_args['with_txt']:
+                 try:
+                     phone_encoded = res['phone'] = encoder.encode(ph)
+                 except:
+                     traceback.print_exc()
+                     raise BinarizationError(f"Empty phoneme")
+                 if binarization_args['with_align']:
+                     cls.get_align(MidiSingingBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
+         except BinarizationError as e:
+             print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
+             return None
+         return res
+
+
+ class ZhSingingBinarizer(ZhBinarizer, SingingBinarizer):
+     pass
+
+
+ class OpencpopBinarizer(MidiSingingBinarizer):
+     item2midi = {}
+     item2midi_dur = {}
+     item2is_slur = {}
+     item2ph_durs = {}
+     item2wdb = {}
+
+     def split_train_test_set(self, item_names):
+         item_names = deepcopy(item_names)
+         test_item_names = [x for x in item_names if any([x.startswith(ts) for ts in hparams['test_prefixes']])]
+         train_item_names = [x for x in item_names if x not in set(test_item_names)]
+         logging.info("train {}".format(len(train_item_names)))
+         logging.info("test {}".format(len(test_item_names)))
+         return train_item_names, test_item_names
+
+     def load_meta_data(self):
+         raw_data_dir = hparams['raw_data_dir']
+         # meta_midi = json.load(open(os.path.join(raw_data_dir, 'meta.json')))  # [list of dict]
+         utterance_labels = open(os.path.join(raw_data_dir, 'transcriptions.txt')).readlines()
+
+         for utterance_label in utterance_labels:
+             song_info = utterance_label.split('|')
+             item_name = raw_item_name = song_info[0]
+             self.item2wavfn[item_name] = f'{raw_data_dir}/wavs/{item_name}.wav'
+             self.item2txt[item_name] = song_info[1]
+
+             self.item2ph[item_name] = song_info[2]
+             # self.item2wdb[item_name] = list(np.nonzero([1 if x in ALL_YUNMU + ['AP', 'SP'] else 0 for x in song_info[2].split()])[0])
+             self.item2wdb[item_name] = [1 if x in ALL_YUNMU + ['AP', 'SP'] else 0 for x in song_info[2].split()]
+             self.item2ph_durs[item_name] = [float(x) for x in song_info[5].split(" ")]
+
+             self.item2midi[item_name] = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
+                                          for x in song_info[3].split(" ")]
+             self.item2midi_dur[item_name] = [float(x) for x in song_info[4].split(" ")]
+             self.item2is_slur[item_name] = [int(x) for x in song_info[6].split(" ")]
+             self.item2spk[item_name] = 'opencpop'
+
+         print('spkers: ', set(self.item2spk.values()))
+         self.item_names = sorted(list(self.item2txt.keys()))
+         if self.binarization_args['shuffle']:
+             random.seed(1234)
+             random.shuffle(self.item_names)
+         self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
+
+     @staticmethod
+     def get_pitch(wav_fn, wav, spec, ph, res):
+         wav_suffix = '.wav'
+         # midi_suffix = '.mid'
+         wav_dir = 'wavs'
+         f0_dir = 'text_f0_align'
+
+         item_name = os.path.splitext(os.path.basename(wav_fn))[0]
+         res['pitch_midi'] = np.asarray(OpencpopBinarizer.item2midi[item_name])
+         res['midi_dur'] = np.asarray(OpencpopBinarizer.item2midi_dur[item_name])
+         res['is_slur'] = np.asarray(OpencpopBinarizer.item2is_slur[item_name])
+         res['word_boundary'] = np.asarray(OpencpopBinarizer.item2wdb[item_name])
+         assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
+
+         # gt f0.
+         # f0 = None
+         # f0_suffix = '_f0.npy'
+         # f0fn = wav_fn.replace(wav_suffix, f0_suffix).replace(wav_dir, f0_dir)
+         # pitch_info = np.load(f0fn)
+         # f0 = [x[1] for x in pitch_info]
+         # spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)]
+         #
+         # f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)]
+         # f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)]
+         # if sum(f0) == 0:
+         #     raise BinarizationError("Empty **gt** f0")
+         #
+         # pitch_coarse = f0_to_coarse(f0)
+         # res['f0'] = f0
+         # res['pitch'] = pitch_coarse
+
+         # gt f0.
+         gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
+         if sum(gt_f0) == 0:
+             raise BinarizationError("Empty **gt** f0")
+         res['f0'] = gt_f0
+         res['pitch'] = gt_pitch_coarse
+
+     @classmethod
+     def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
+         if hparams['vocoder'] in VOCODERS:
+             wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
+         else:
+             wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
+         res = {
+             'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
+             'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
+         }
+         try:
+             if binarization_args['with_f0']:
+                 cls.get_pitch(wav_fn, wav, mel, ph, res)
+             if binarization_args['with_txt']:
+                 try:
+                     phone_encoded = res['phone'] = encoder.encode(ph)
+                 except:
+                     traceback.print_exc()
+                     raise BinarizationError(f"Empty phoneme")
+                 if binarization_args['with_align']:
+                     cls.get_align(OpencpopBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
+         except BinarizationError as e:
+             print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
+             return None
+         return res
+
+
+ if __name__ == "__main__":
+     SingingBinarizer().process()
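
For orientation, `OpencpopBinarizer.load_meta_data` above indexes the pipe-separated fields of `transcriptions.txt` purely by position: item name | text | phonemes | note names | note durations | phoneme durations | slur flags. The toy line below (values invented for illustration) is parsed with the same field layout as the code above.

```python
import librosa

line = "2001000001|demo|k e|rest C4|0.5 0.5|0.2 0.8|0 0"    # hypothetical transcription line
song_info = line.split('|')
item_name = song_info[0]
phs = song_info[2].split(' ')
midi = [librosa.note_to_midi(x.split('/')[0]) if x != 'rest' else 0
        for x in song_info[3].split(' ')]                    # -> [0, 60]
midi_dur = [float(x) for x in song_info[4].split(' ')]       # note durations in seconds
ph_dur = [float(x) for x in song_info[5].split(' ')]         # phoneme durations in seconds
is_slur = [int(x) for x in song_info[6].split(' ')]
```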
data_gen/tts/base_binarizer.py ADDED
@@ -0,0 +1,224 @@
+ import os
+ os.environ["OMP_NUM_THREADS"] = "1"
+
+ from utils.multiprocess_utils import chunked_multiprocess_run
+ import random
+ import traceback
+ import json
+ from resemblyzer import VoiceEncoder
+ from tqdm import tqdm
+ from data_gen.tts.data_gen_utils import get_mel2ph, get_pitch, build_phone_encoder
+ from utils.hparams import set_hparams, hparams
+ import numpy as np
+ from utils.indexed_datasets import IndexedDatasetBuilder
+ from vocoders.base_vocoder import VOCODERS
+ import pandas as pd
+
+
+ class BinarizationError(Exception):
+     pass
+
+
+ class BaseBinarizer:
+     def __init__(self, processed_data_dir=None):
+         if processed_data_dir is None:
+             processed_data_dir = hparams['processed_data_dir']
+         self.processed_data_dirs = processed_data_dir.split(",")
+         self.binarization_args = hparams['binarization_args']
+         self.pre_align_args = hparams['pre_align_args']
+         self.forced_align = self.pre_align_args['forced_align']
+         tg_dir = None
+         if self.forced_align == 'mfa':
+             tg_dir = 'mfa_outputs'
+         if self.forced_align == 'kaldi':
+             tg_dir = 'kaldi_outputs'
+         self.item2txt = {}
+         self.item2ph = {}
+         self.item2wavfn = {}
+         self.item2tgfn = {}
+         self.item2spk = {}
+         for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
+             self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str)
+             for r_idx, r in self.meta_df.iterrows():
+                 item_name = raw_item_name = r['item_name']
+                 if len(self.processed_data_dirs) > 1:
+                     item_name = f'ds{ds_id}_{item_name}'
+                 self.item2txt[item_name] = r['txt']
+                 self.item2ph[item_name] = r['ph']
+                 self.item2wavfn[item_name] = os.path.join(hparams['raw_data_dir'], 'wavs', os.path.basename(r['wav_fn']).split('_')[1])
+                 self.item2spk[item_name] = r.get('spk', 'SPK1')
+                 if len(self.processed_data_dirs) > 1:
+                     self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
+                 if tg_dir is not None:
+                     self.item2tgfn[item_name] = f"{processed_data_dir}/{tg_dir}/{raw_item_name}.TextGrid"
+         self.item_names = sorted(list(self.item2txt.keys()))
+         if self.binarization_args['shuffle']:
+             random.seed(1234)
+             random.shuffle(self.item_names)
+
+     @property
+     def train_item_names(self):
+         return self.item_names[hparams['test_num']+hparams['valid_num']:]
+
+     @property
+     def valid_item_names(self):
+         return self.item_names[0: hparams['test_num']+hparams['valid_num']]  #
+
+     @property
+     def test_item_names(self):
+         return self.item_names[0: hparams['test_num']]  # Audios for MOS testing are in 'test_ids'
+
+     def build_spk_map(self):
+         spk_map = set()
+         for item_name in self.item_names:
+             spk_name = self.item2spk[item_name]
+             spk_map.add(spk_name)
+         spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))}
+         assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
+         return spk_map
+
+     def item_name2spk_id(self, item_name):
+         return self.spk_map[self.item2spk[item_name]]
+
+     def _phone_encoder(self):
+         ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
+         ph_set = []
+         if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
+             for processed_data_dir in self.processed_data_dirs:
+                 ph_set += [x.split(' ')[0] for x in open(f'{processed_data_dir}/dict.txt').readlines()]
+             ph_set = sorted(set(ph_set))
+             json.dump(ph_set, open(ph_set_fn, 'w'))
+         else:
+             ph_set = json.load(open(ph_set_fn, 'r'))
+         print("| phone set: ", ph_set)
+         return build_phone_encoder(hparams['binary_data_dir'])
+
+     def meta_data(self, prefix):
+         if prefix == 'valid':
+             item_names = self.valid_item_names
+         elif prefix == 'test':
+             item_names = self.test_item_names
+         else:
+             item_names = self.train_item_names
+         for item_name in item_names:
+             ph = self.item2ph[item_name]
+             txt = self.item2txt[item_name]
+             tg_fn = self.item2tgfn.get(item_name)
+             wav_fn = self.item2wavfn[item_name]
+             spk_id = self.item_name2spk_id(item_name)
+             yield item_name, ph, txt, tg_fn, wav_fn, spk_id
+
+     def process(self):
+         os.makedirs(hparams['binary_data_dir'], exist_ok=True)
+         self.spk_map = self.build_spk_map()
+         print("| spk_map: ", self.spk_map)
+         spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
+         json.dump(self.spk_map, open(spk_map_fn, 'w'))
+
+         self.phone_encoder = self._phone_encoder()
+         self.process_data('valid')
+         self.process_data('test')
+         self.process_data('train')
+
+     def process_data(self, prefix):
+         data_dir = hparams['binary_data_dir']
+         args = []
+         builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
+         lengths = []
+         f0s = []
+         total_sec = 0
+         if self.binarization_args['with_spk_embed']:
+             voice_encoder = VoiceEncoder().cuda()
+
+         meta_data = list(self.meta_data(prefix))
+         for m in meta_data:
+             args.append(list(m) + [self.phone_encoder, self.binarization_args])
+         num_workers = int(os.getenv('N_PROC', os.cpu_count() // 3))
+         for f_id, (_, item) in enumerate(
+                 zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))):
+             if item is None:
+                 continue
+             item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
+                 if self.binarization_args['with_spk_embed'] else None
+             if not self.binarization_args['with_wav'] and 'wav' in item:
+                 print("del wav")
+                 del item['wav']
+             builder.add_item(item)
+             lengths.append(item['len'])
+             total_sec += item['sec']
+             if item.get('f0') is not None:
+                 f0s.append(item['f0'])
+         builder.finalize()
+         np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
+         if len(f0s) > 0:
+             f0s = np.concatenate(f0s, 0)
+             f0s = f0s[f0s != 0]
+             np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()])
+         print(f"| {prefix} total duration: {total_sec:.3f}s")
+
+     @classmethod
+     def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
+         if hparams['vocoder'] in VOCODERS:
+             wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
+         else:
+             wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
+         res = {
+             'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
+             'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
+         }
+         try:
+             if binarization_args['with_f0']:
+                 cls.get_pitch(wav, mel, res)
+                 if binarization_args['with_f0cwt']:
+                     cls.get_f0cwt(res['f0'], res)
+             if binarization_args['with_txt']:
+                 try:
+                     phone_encoded = res['phone'] = encoder.encode(ph)
+                 except:
+                     traceback.print_exc()
+                     raise BinarizationError(f"Empty phoneme")
+                 if binarization_args['with_align']:
+                     cls.get_align(tg_fn, ph, mel, phone_encoded, res)
+         except BinarizationError as e:
+             print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
+             return None
+         return res
+
+     @staticmethod
+     def get_align(tg_fn, ph, mel, phone_encoded, res):
+         if tg_fn is not None and os.path.exists(tg_fn):
+             mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams)
+         else:
+             raise BinarizationError(f"Align not found")
+         if mel2ph.max() - 1 >= len(phone_encoded):
+             raise BinarizationError(
+                 f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}")
+         res['mel2ph'] = mel2ph
+         res['dur'] = dur
+
+     @staticmethod
+     def get_pitch(wav, mel, res):
+         f0, pitch_coarse = get_pitch(wav, mel, hparams)
+         if sum(f0) == 0:
+             raise BinarizationError("Empty f0")
+         res['f0'] = f0
+         res['pitch'] = pitch_coarse
+
+     @staticmethod
+     def get_f0cwt(f0, res):
+         from utils.cwt import get_cont_lf0, get_lf0_cwt
+         uv, cont_lf0_lpf = get_cont_lf0(f0)
+         logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf)
+         cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org
+         Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)
+         if np.any(np.isnan(Wavelet_lf0)):
+             raise BinarizationError("NaN CWT")
+         res['cwt_spec'] = Wavelet_lf0
+         res['cwt_scales'] = scales
+         res['f0_mean'] = logf0s_mean_org
+         res['f0_std'] = logf0s_std_org
+
+
+ if __name__ == "__main__":
+     set_hparams()
+     BaseBinarizer().process()
data_gen/tts/bin/binarize.py ADDED
@@ -0,0 +1,20 @@
+ import os
+
+ os.environ["OMP_NUM_THREADS"] = "1"
+
+ import importlib
+ from utils.hparams import set_hparams, hparams
+
+
+ def binarize():
+     binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
+     pkg = ".".join(binarizer_cls.split(".")[:-1])
+     cls_name = binarizer_cls.split(".")[-1]
+     binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
+     print("| Binarizer: ", binarizer_cls)
+     binarizer_cls().process()
+
+
+ if __name__ == '__main__':
+     set_hparams()
+     binarize()
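
`binarize()` above picks the binarizer class purely from the dotted `binarizer_cls` string in the merged config (e.g. `data_gen.singing.binarize.SingingBinarizer` from configs/singing/base.yaml). The same importlib pattern in isolation:

```python
import importlib

def locate(dotted_path: str):
    """Resolve 'pkg.module.ClassName' to the class object, as binarize() does."""
    pkg, cls_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(pkg), cls_name)

# e.g. locate('data_gen.tts.binarizer_zh.ZhBinarizer')().process()
```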
data_gen/tts/binarizer_zh.py ADDED
@@ -0,0 +1,59 @@
+ import os
+
+ os.environ["OMP_NUM_THREADS"] = "1"
+
+ from data_gen.tts.txt_processors.zh_g2pM import ALL_SHENMU
+ from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
+ from data_gen.tts.data_gen_utils import get_mel2ph
+ from utils.hparams import set_hparams, hparams
+ import numpy as np
+
+
+ class ZhBinarizer(BaseBinarizer):
+     @staticmethod
+     def get_align(tg_fn, ph, mel, phone_encoded, res):
+         if tg_fn is not None and os.path.exists(tg_fn):
+             _, dur = get_mel2ph(tg_fn, ph, mel, hparams)
+         else:
+             raise BinarizationError(f"Align not found")
+         ph_list = ph.split(" ")
+         assert len(dur) == len(ph_list)
+         mel2ph = []
+         # assign the duration of separators to the preceding final (yunmu)
+         dur_cumsum = np.pad(np.cumsum(dur), [1, 0], mode='constant', constant_values=0)
+         for i in range(len(dur)):
+             p = ph_list[i]
+             if p[0] != '<' and not p[0].isalpha():
+                 uv_ = res['f0'][dur_cumsum[i]:dur_cumsum[i + 1]] == 0
+                 j = 0
+                 while j < len(uv_) and not uv_[j]:
+                     j += 1
+                 dur[i - 1] += j
+                 dur[i] -= j
+                 if dur[i] < 100:
+                     dur[i - 1] += dur[i]
+                     dur[i] = 0
+         # give the initial (shengmu) and the final (yunmu) equal lengths
+         for i in range(len(dur)):
+             p = ph_list[i]
+             if p in ALL_SHENMU:
+                 p_next = ph_list[i + 1]
+                 if not (dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU):
+                     print(f"assert dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU, "
+                           f"dur[i]: {dur[i]}, p: {p}, p_next: {p_next}.")
+                     continue
+                 total = dur[i + 1] + dur[i]
+                 dur[i] = total // 2
+                 dur[i + 1] = total - dur[i]
+         for i in range(len(dur)):
+             mel2ph += [i + 1] * dur[i]
+         mel2ph = np.array(mel2ph)
+         if mel2ph.max() - 1 >= len(phone_encoded):
+             raise BinarizationError(f"| Align does not match: {(mel2ph.max() - 1, len(phone_encoded))}")
+         res['mel2ph'] = mel2ph
+         res['dur'] = dur
+
+
+ if __name__ == "__main__":
+     set_hparams()
+     ZhBinarizer().process()
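
The final loop of `get_align` above expands the per-phoneme frame counts into the frame-level `mel2ph` index (1-based, one entry per mel frame). A tiny worked example of just that expansion:

```python
import numpy as np

dur = [3, 2, 4]                     # frames assigned to phonemes 1..3
mel2ph = []
for i in range(len(dur)):
    mel2ph += [i + 1] * dur[i]
mel2ph = np.array(mel2ph)
print(mel2ph)                       # [1 1 1 2 2 3 3 3 3]
```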
data_gen/tts/data_gen_utils.py ADDED
@@ -0,0 +1,347 @@
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
+ import parselmouth
+ import os
+ import torch
+ from skimage.transform import resize
+ from utils.text_encoder import TokenTextEncoder
+ from utils.pitch_utils import f0_to_coarse
+ import struct
+ import webrtcvad
+ from scipy.ndimage.morphology import binary_dilation
+ import librosa
+ import numpy as np
+ from utils import audio
+ import pyloudnorm as pyln
+ import re
+ import json
+ from collections import OrderedDict
+
+ PUNCS = '!,.?;:'
+
+ int16_max = (2 ** 15) - 1
+
+
+ def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12):
+     """
+     Ensures that segments without voice in the waveform remain no longer than a
+     threshold determined by the VAD parameters in params.py.
+     :param wav: the raw waveform as a numpy array of floats
+     :param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have.
+     :return: the same waveform with silences trimmed away (length <= original wav length)
+     """
+
+     ## Voice Activation Detection
+     # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+     # This sets the granularity of the VAD. Should not need to be changed.
+     sampling_rate = 16000
+     wav_raw, sr = librosa.core.load(path, sr=sr)
+
+     if norm:
+         meter = pyln.Meter(sr)  # create BS.1770 meter
+         loudness = meter.integrated_loudness(wav_raw)
+         wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0)
+         if np.abs(wav_raw).max() > 1.0:
+             wav_raw = wav_raw / np.abs(wav_raw).max()
+
+     wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best')
+
+     vad_window_length = 30  # In milliseconds
+     # Number of frames to average together when performing the moving average smoothing.
+     # The larger this value, the larger the VAD variations must be to not get smoothed out.
+     vad_moving_average_width = 8
+
+     # Compute the voice detection window size
+     samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+     # Trim the end of the audio to have a multiple of the window size
+     wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+     # Convert the float waveform to 16-bit mono PCM
+     pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+     # Perform voice activation detection
+     voice_flags = []
+     vad = webrtcvad.Vad(mode=3)
+     for window_start in range(0, len(wav), samples_per_window):
+         window_end = window_start + samples_per_window
+         voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                                          sample_rate=sampling_rate))
+     voice_flags = np.array(voice_flags)
+
+     # Smooth the voice detection with a moving average
+     def moving_average(array, width):
+         array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+         ret = np.cumsum(array_padded, dtype=float)
+         ret[width:] = ret[width:] - ret[:-width]
+         return ret[width - 1:] / width
+
+     audio_mask = moving_average(voice_flags, vad_moving_average_width)
+     audio_mask = np.round(audio_mask).astype(np.bool)
+
+     # Dilate the voiced regions
+     audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+     audio_mask = np.repeat(audio_mask, samples_per_window)
+     audio_mask = resize(audio_mask, (len(wav_raw),)) > 0
+     if return_raw_wav:
+         return wav_raw, audio_mask, sr
+     return wav_raw[audio_mask], audio_mask, sr
+
+
+ def process_utterance(wav_path,
+                       fft_size=1024,
+                       hop_size=256,
+                       win_length=1024,
+                       window="hann",
+                       num_mels=80,
+                       fmin=80,
+                       fmax=7600,
+                       eps=1e-6,
+                       sample_rate=22050,
+                       loud_norm=False,
+                       min_level_db=-100,
+                       return_linear=False,
+                       trim_long_sil=False, vocoder='pwg'):
+     if isinstance(wav_path, str):
+         if trim_long_sil:
+             wav, _, _ = trim_long_silences(wav_path, sample_rate)
+         else:
+             wav, _ = librosa.core.load(wav_path, sr=sample_rate)
+     else:
+         wav = wav_path
+
+     if loud_norm:
+         meter = pyln.Meter(sample_rate)  # create BS.1770 meter
+         loudness = meter.integrated_loudness(wav)
+         wav = pyln.normalize.loudness(wav, loudness, -22.0)
+         if np.abs(wav).max() > 1:
+             wav = wav / np.abs(wav).max()
+
+     # get amplitude spectrogram
+     x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
+                           win_length=win_length, window=window, pad_mode="constant")
+     spc = np.abs(x_stft)  # (n_bins, T)
+
+     # get mel basis
+     fmin = 0 if fmin == -1 else fmin
+     fmax = sample_rate / 2 if fmax == -1 else fmax
+     mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax)
+     mel = mel_basis @ spc
+
+     if vocoder == 'pwg':
+         mel = np.log10(np.maximum(eps, mel))  # (n_mel_bins, T)
+     else:
136
+ assert False, f'"{vocoder}" is not in ["pwg"].'
137
+
138
+ l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1)
139
+ wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
140
+ wav = wav[:mel.shape[1] * hop_size]
141
+
142
+ if not return_linear:
143
+ return wav, mel
144
+ else:
145
+ spc = audio.amp_to_db(spc)
146
+ spc = audio.normalize(spc, {'min_level_db': min_level_db})
147
+ return wav, mel, spc
148
+
149
+
150
+ def get_pitch(wav_data, mel, hparams):
151
+ """
152
+
153
+ :param wav_data: [T]
154
+ :param mel: [T, 80]
155
+ :param hparams:
156
+ :return:
157
+ """
158
+ time_step = hparams['hop_size'] / hparams['audio_sample_rate'] * 1000
159
+ f0_min = 80
160
+ f0_max = 750
161
+
162
+ if hparams['hop_size'] == 128:
163
+ pad_size = 4
164
+ elif hparams['hop_size'] == 256:
165
+ pad_size = 2
166
+ else:
167
+ assert False
168
+
169
+ f0 = parselmouth.Sound(wav_data, hparams['audio_sample_rate']).to_pitch_ac(
170
+ time_step=time_step / 1000, voicing_threshold=0.6,
171
+ pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
172
+ lpad = pad_size * 2
173
+ rpad = len(mel) - len(f0) - lpad
174
+ f0 = np.pad(f0, [[lpad, rpad]], mode='constant')
175
+ # mel and f0 are extracted by 2 different libraries. we should force them to have the same length.
176
+ # Attention: we find that newer versions of some libraries could cause ``rpad'' to be a negative value...
177
+ # Just to be sure, we recommend users set up the same environment as specified in requirements_auto.txt (via Anaconda)
178
+ delta_l = len(mel) - len(f0)
179
+ assert np.abs(delta_l) <= 8
180
+ if delta_l > 0:
181
+ f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
182
+ f0 = f0[:len(mel)]
183
+ pitch_coarse = f0_to_coarse(f0)
184
+ return f0, pitch_coarse
185
+
186
+
187
+ def remove_empty_lines(text):
188
+ """remove empty lines"""
189
+ assert (len(text) > 0)
190
+ assert (isinstance(text, list))
191
+ text = [t.strip() for t in text]
192
+ if "" in text:
193
+ text.remove("")
194
+ return text
195
+
196
+
197
+ class TextGrid(object):
198
+ def __init__(self, text):
199
+ text = remove_empty_lines(text)
200
+ self.text = text
201
+ self.line_count = 0
202
+ self._get_type()
203
+ self._get_time_intval()
204
+ self._get_size()
205
+ self.tier_list = []
206
+ self._get_item_list()
207
+
208
+ def _extract_pattern(self, pattern, inc):
209
+ """
210
+ Parameters
211
+ ----------
212
+ pattern : regex to extract pattern
213
+ inc : increment of line count after extraction
214
+ Returns
215
+ -------
216
+ group : extracted info
217
+ """
218
+ try:
219
+ group = re.match(pattern, self.text[self.line_count]).group(1)
220
+ self.line_count += inc
221
+ except AttributeError:
222
+ raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count]))
223
+ return group
224
+
225
+ def _get_type(self):
226
+ self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)
227
+
228
+ def _get_time_intval(self):
229
+ self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
230
+ self.xmax = self._extract_pattern(r"xmax = (.*)", 2)
231
+
232
+ def _get_size(self):
233
+ self.size = int(self._extract_pattern(r"size = (.*)", 2))
234
+
235
+ def _get_item_list(self):
236
+ """Only supports IntervalTier currently"""
237
+ for itemIdx in range(1, self.size + 1):
238
+ tier = OrderedDict()
239
+ item_list = []
240
+ tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
241
+ tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
242
+ if tier_class != "IntervalTier":
243
+ raise NotImplementedError("Only IntervalTier class is supported currently")
244
+ tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
245
+ tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
246
+ tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
247
+ tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
248
+ for i in range(int(tier_size)):
249
+ item = OrderedDict()
250
+ item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
251
+ item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
252
+ item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
253
+ item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
254
+ item_list.append(item)
255
+ tier["idx"] = tier_idx
256
+ tier["class"] = tier_class
257
+ tier["name"] = tier_name
258
+ tier["xmin"] = tier_xmin
259
+ tier["xmax"] = tier_xmax
260
+ tier["size"] = tier_size
261
+ tier["items"] = item_list
262
+ self.tier_list.append(tier)
263
+
264
+ def toJson(self):
265
+ _json = OrderedDict()
266
+ _json["file_type"] = self.file_type
267
+ _json["xmin"] = self.xmin
268
+ _json["xmax"] = self.xmax
269
+ _json["size"] = self.size
270
+ _json["tiers"] = self.tier_list
271
+ return json.dumps(_json, ensure_ascii=False, indent=2)
272
+
273
+
274
+ def get_mel2ph(tg_fn, ph, mel, hparams):
275
+ ph_list = ph.split(" ")
276
+ with open(tg_fn, "r") as f:
277
+ tg = f.readlines()
278
+ tg = remove_empty_lines(tg)
279
+ tg = TextGrid(tg)
280
+ tg = json.loads(tg.toJson())
281
+ split = np.ones(len(ph_list) + 1, np.float) * -1
282
+ tg_idx = 0
283
+ ph_idx = 0
284
+ tg_align = [x for x in tg['tiers'][-1]['items']]
285
+ tg_align_ = []
286
+ for x in tg_align:
287
+ x['xmin'] = float(x['xmin'])
288
+ x['xmax'] = float(x['xmax'])
289
+ if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']:
290
+ x['text'] = ''
291
+ if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
292
+ tg_align_[-1]['xmax'] = x['xmax']
293
+ continue
294
+ tg_align_.append(x)
295
+ tg_align = tg_align_
296
+ tg_len = len([x for x in tg_align if x['text'] != ''])
297
+ ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
298
+ assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn)
299
+ while tg_idx < len(tg_align) or ph_idx < len(ph_list):
300
+ if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
301
+ split[ph_idx] = 1e8
302
+ ph_idx += 1
303
+ continue
304
+ x = tg_align[tg_idx]
305
+ if x['text'] == '' and ph_idx == len(ph_list):
306
+ tg_idx += 1
307
+ continue
308
+ assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn)
309
+ ph = ph_list[ph_idx]
310
+ if x['text'] == '' and not is_sil_phoneme(ph):
311
+ assert False, (ph_list, tg_align)
312
+ if x['text'] != '' and is_sil_phoneme(ph):
313
+ ph_idx += 1
314
+ else:
315
+ assert (x['text'] == '' and is_sil_phoneme(ph)) \
316
+ or x['text'].lower() == ph.lower() \
317
+ or x['text'].lower() == 'sil', (x['text'], ph)
318
+ split[ph_idx] = x['xmin']
319
+ if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]):
320
+ split[ph_idx - 1] = split[ph_idx]
321
+ ph_idx += 1
322
+ tg_idx += 1
323
+ assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
324
+ assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn)
325
+ mel2ph = np.zeros([mel.shape[0]], np.int)
326
+ split[0] = 0
327
+ split[-1] = 1e8
328
+ for i in range(len(split) - 1):
329
+ assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],)
330
+ split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split]
331
+ for ph_idx in range(len(ph_list)):
332
+ mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1
333
+ mel2ph_torch = torch.from_numpy(mel2ph)
334
+ T_t = len(ph_list)
335
+ dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch))
336
+ dur = dur[1:].numpy()
337
+ return mel2ph, dur
338
+
339
+
340
+ def build_phone_encoder(data_dir):
341
+ phone_list_file = os.path.join(data_dir, 'phone_set.json')
342
+ phone_list = json.load(open(phone_list_file))
343
+ return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
344
+
345
+
346
+ def is_sil_phoneme(p):
347
+ return not p[0].isalpha()
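The helpers above (`process_utterance`, `get_pitch`) are typically composed as follows during binarization. A minimal sketch, assuming an example wav path and hparams values that match the defaults used in this file:

```python
# Sketch: extract a mel-spectrogram and an aligned F0 curve for one utterance
# with the helpers from data_gen/tts/data_gen_utils.py. The wav path and the
# hparams values are assumptions for illustration only.
from data_gen.tts.data_gen_utils import process_utterance, get_pitch

hparams = {'hop_size': 256, 'audio_sample_rate': 22050}  # must match the STFT settings below

# (padded) waveform plus a log10 mel-spectrogram of shape [n_mels, T]
wav, mel = process_utterance('example.wav', hop_size=256, sample_rate=22050, vocoder='pwg')

# get_pitch expects the mel as [T, n_mels]; it returns frame-level F0 and a coarse pitch index
f0, pitch_coarse = get_pitch(wav, mel.T, hparams)
assert len(f0) == mel.shape[1]
```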
data_gen/tts/txt_processors/base_text_processor.py ADDED
@@ -0,0 +1,8 @@
1
+ class BaseTxtProcessor:
2
+ @staticmethod
3
+ def sp_phonemes():
4
+ return ['|']
5
+
6
+ @classmethod
7
+ def process(cls, txt, pre_align_args):
8
+ raise NotImplementedError
data_gen/tts/txt_processors/en.py ADDED
@@ -0,0 +1,78 @@
1
+ import re
2
+ from data_gen.tts.data_gen_utils import PUNCS
3
+ from g2p_en import G2p
4
+ import unicodedata
5
+ from g2p_en.expand import normalize_numbers
6
+ from nltk import pos_tag
7
+ from nltk.tokenize import TweetTokenizer
8
+
9
+ from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
10
+
11
+
12
+ class EnG2p(G2p):
13
+ word_tokenize = TweetTokenizer().tokenize
14
+
15
+ def __call__(self, text):
16
+ # preprocessing
17
+ words = EnG2p.word_tokenize(text)
18
+ tokens = pos_tag(words) # tuples of (word, tag)
19
+
20
+ # steps
21
+ prons = []
22
+ for word, pos in tokens:
23
+ if re.search("[a-z]", word) is None:
24
+ pron = [word]
25
+
26
+ elif word in self.homograph2features: # Check homograph
27
+ pron1, pron2, pos1 = self.homograph2features[word]
28
+ if pos.startswith(pos1):
29
+ pron = pron1
30
+ else:
31
+ pron = pron2
32
+ elif word in self.cmu: # lookup CMU dict
33
+ pron = self.cmu[word][0]
34
+ else: # predict for oov
35
+ pron = self.predict(word)
36
+
37
+ prons.extend(pron)
38
+ prons.extend([" "])
39
+
40
+ return prons[:-1]
41
+
42
+
43
+ class TxtProcessor(BaseTxtProcessor):
44
+ g2p = EnG2p()
45
+
46
+ @staticmethod
47
+ def preprocess_text(text):
48
+ text = normalize_numbers(text)
49
+ text = ''.join(char for char in unicodedata.normalize('NFD', text)
50
+ if unicodedata.category(char) != 'Mn') # Strip accents
51
+ text = text.lower()
52
+ text = re.sub("[\'\"()]+", "", text)
53
+ text = re.sub("[-]+", " ", text)
54
+ text = re.sub(f"[^ a-z{PUNCS}]", "", text)
55
+ text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> !
56
+ text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> !
57
+ text = text.replace("i.e.", "that is")
58
+ text = text.replace("i.e.", "that is")
59
+ text = text.replace("etc.", "etc")
60
+ text = re.sub(f"([{PUNCS}])", r" \1 ", text)
61
+ text = re.sub(rf"\s+", r" ", text)
62
+ return text
63
+
64
+ @classmethod
65
+ def process(cls, txt, pre_align_args):
66
+ txt = cls.preprocess_text(txt).strip()
67
+ phs = cls.g2p(txt)
68
+ phs_ = []
69
+ n_word_sep = 0
70
+ for p in phs:
71
+ if p.strip() == '':
72
+ phs_ += ['|']
73
+ n_word_sep += 1
74
+ else:
75
+ phs_ += p.split(" ")
76
+ phs = phs_
77
+ assert n_word_sep + 1 == len(txt.split(" ")), (phs, f"\"{txt}\"")
78
+ return phs, txt
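A minimal usage sketch for the English text processor above; the input sentence is arbitrary and the ARPABET output in the comment is approximate:

```python
# Sketch: English grapheme-to-phoneme conversion with data_gen/tts/txt_processors/en.py.
from data_gen.tts.txt_processors.en import TxtProcessor

# pre_align_args is not used by the English processor, so an empty dict is enough here.
phs, txt = TxtProcessor.process("hello world", pre_align_args={})
print(txt)  # "hello world"
print(phs)  # roughly ['HH', 'AH0', 'L', 'OW1', '|', 'W', 'ER1', 'L', 'D']; '|' marks a word boundary
```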
data_gen/tts/txt_processors/zh.py ADDED
@@ -0,0 +1,41 @@
1
+ import re
2
+ from pypinyin import pinyin, Style
3
+ from data_gen.tts.data_gen_utils import PUNCS
4
+ from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
5
+ from utils.text_norm import NSWNormalizer
6
+
7
+
8
+ class TxtProcessor(BaseTxtProcessor):
9
+ table = {ord(f): ord(t) for f, t in zip(
10
+ u':,。!?【】()%#@&1234567890',
11
+ u':,.!?[]()%#@&1234567890')}
12
+
13
+ @staticmethod
14
+ def preprocess_text(text):
15
+ text = text.translate(TxtProcessor.table)
16
+ text = NSWNormalizer(text).normalize(remove_punc=False)
17
+ text = re.sub("[\'\"()]+", "", text)
18
+ text = re.sub("[-]+", " ", text)
19
+ text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}]", "", text)
20
+ text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> !
21
+ text = re.sub(f"([{PUNCS}])", r" \1 ", text)
22
+ text = re.sub(rf"\s+", r"", text)
23
+ return text
24
+
25
+ @classmethod
26
+ def process(cls, txt, pre_align_args):
27
+ txt = cls.preprocess_text(txt)
28
+ shengmu = pinyin(txt, style=Style.INITIALS) # https://blog.csdn.net/zhoulei124/article/details/89055403
29
+ yunmu_finals = pinyin(txt, style=Style.FINALS)
30
+ yunmu_tone3 = pinyin(txt, style=Style.FINALS_TONE3)
31
+ yunmu = [[t[0] + '5'] if t[0] == f[0] else t for f, t in zip(yunmu_finals, yunmu_tone3)] \
32
+ if pre_align_args['use_tone'] else yunmu_finals
33
+
34
+ assert len(shengmu) == len(yunmu)
35
+ phs = ["|"]
36
+ for a, b, c in zip(shengmu, yunmu, yunmu_finals):
37
+ if a[0] == c[0]:
38
+ phs += [a[0], "|"]
39
+ else:
40
+ phs += [a[0], b[0], "|"]
41
+ return phs, txt
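Likewise, a minimal usage sketch for the pypinyin-based Chinese processor above; the expected output in the comment is approximate:

```python
# Sketch: decompose Chinese text into initials/finals (shengmu/yunmu) with tone numbers,
# using data_gen/tts/txt_processors/zh.py.
from data_gen.tts.txt_processors.zh import TxtProcessor

phs, txt = TxtProcessor.process('你好', {'use_tone': True})
print(txt)  # '你好'
print(phs)  # roughly ['|', 'n', 'i3', '|', 'h', 'ao3', '|']; '|' marks a character boundary
```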
data_gen/tts/txt_processors/zh_g2pM.py ADDED
@@ -0,0 +1,72 @@
1
+ import re
2
+ import jieba
3
+ from pypinyin import pinyin, Style
4
+ from data_gen.tts.data_gen_utils import PUNCS
5
+ from data_gen.tts.txt_processors import zh
6
+ from g2pM import G2pM
7
+
8
+ ALL_SHENMU = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j',
9
+ 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w']
10
+ ALL_YUNMU = ['a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia', 'ian',
11
+ 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'iu', 'ng', 'o', 'ong', 'ou',
12
+ 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn']
13
+
14
+
15
+ class TxtProcessor(zh.TxtProcessor):
16
+ model = G2pM()
17
+
18
+ @staticmethod
19
+ def sp_phonemes():
20
+ return ['|', '#']
21
+
22
+ @classmethod
23
+ def process(cls, txt, pre_align_args):
24
+ txt = cls.preprocess_text(txt)
25
+ ph_list = cls.model(txt, tone=pre_align_args['use_tone'], char_split=True)
26
+ seg_list = '#'.join(jieba.cut(txt))
27
+ assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list)
28
+
29
+ # Insert word-boundary markers '#'
30
+ ph_list_ = []
31
+ seg_idx = 0
32
+ for p in ph_list:
33
+ p = p.replace("u:", "v")
34
+ if seg_list[seg_idx] == '#':
35
+ ph_list_.append('#')
36
+ seg_idx += 1
37
+ else:
38
+ ph_list_.append("|")
39
+ seg_idx += 1
40
+ if re.findall('[\u4e00-\u9fff]', p):
41
+ if pre_align_args['use_tone']:
42
+ p = pinyin(p, style=Style.TONE3, strict=True)[0][0]
43
+ if p[-1] not in ['1', '2', '3', '4', '5']:
44
+ p = p + '5'
45
+ else:
46
+ p = pinyin(p, style=Style.NORMAL, strict=True)[0][0]
47
+
48
+ finished = False
49
+ if len([c.isalpha() for c in p]) > 1:
50
+ for shenmu in ALL_SHENMU:
51
+ if p.startswith(shenmu) and not p.lstrip(shenmu).isnumeric():
52
+ ph_list_ += [shenmu, p.lstrip(shenmu)]
53
+ finished = True
54
+ break
55
+ if not finished:
56
+ ph_list_.append(p)
57
+
58
+ ph_list = ph_list_
59
+
60
+ # Remove word-boundary markers adjacent to silence symbols, e.g. [..., '#', ',', '#', ...]
61
+ sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes()
62
+ ph_list_ = []
63
+ for i in range(0, len(ph_list), 1):
64
+ if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes):
65
+ ph_list_.append(ph_list[i])
66
+ ph_list = ph_list_
67
+ return ph_list, txt
68
+
69
+
70
+ if __name__ == '__main__':
71
+ phs, txt = TxtProcessor.process('他来到了,网易杭研大厦', {'use_tone': True})
72
+ print(phs)
docs/README-SVS-opencpop-cascade.md ADDED
@@ -0,0 +1,111 @@
1
+ # DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
2
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
3
+ [![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
4
+ [![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
5
+
6
+ ## DiffSinger (MIDI version SVS)
7
+ ### 0. Data Acquisition
8
+ For the Opencpop dataset: please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We are not authorized to grant access to Opencpop.
9
+
10
+ The pipeline below is designed for the Opencpop dataset:
11
+
12
+ ### 1. Preparation
13
+
14
+ #### Data Preparation
15
+ a) Download and extract Opencpop, then create a link to the dataset folder: `ln -s /xxx/opencpop data/raw/`
16
+
17
+ b) Run the following scripts to pack the dataset for training/inference.
18
+
19
+ ```sh
20
+ export PYTHONPATH=.
21
+ CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config usr/configs/midi/cascade/opencs/aux_rel.yaml
22
+
23
+ # `data/binary/opencpop-midi-dp` will be generated.
24
+ ```
25
+
26
+ #### Vocoder Preparation
27
+ We provide the pre-trained model of [HifiGAN-Singing](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0109_hifigan_bigpopcs_hop128.zip) which is specially designed for SVS with NSF mechanism.
28
+ Please unzip this file into `checkpoints` before training your acoustic model.
29
+
30
+ (Update: You can also move [a ckpt with more training steps](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/model_ckpt_steps_1512000.ckpt) into this vocoder directory)
31
+
32
+ This singing vocoder is trained on ~70 hours of singing data, so it can be viewed as a universal vocoder.
33
+
34
+ #### Exp Name Preparation
35
+ ```bash
36
+ export MY_FS_EXP_NAME=0302_opencpop_fs_midi
37
+ export MY_DS_EXP_NAME=0303_opencpop_ds58_midi
38
+ ```
39
+
40
+ ```
41
+ .
42
+ |--data
43
+ |--raw
44
+ |--opencpop
45
+ |--segments
46
+ |--transcriptions.txt
47
+ |--wavs
48
+ |--checkpoints
49
+ |--MY_FS_EXP_NAME (optional)
50
+ |--MY_DS_EXP_NAME (optional)
51
+ |--0109_hifigan_bigpopcs_hop128
52
+ |--model_ckpt_steps_1512000.ckpt
53
+ |--config.yaml
54
+ ```
55
+
56
+ ### 2. Training Example
57
+ First, you need a pre-trained FFT-Singer checkpoint. You can use the pre-trained model, or train FFT-Singer from scratch by running:
58
+ ```sh
59
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/aux_rel.yaml --exp_name $MY_FS_EXP_NAME --reset
60
+ ```
61
+
62
+ Then, to train DiffSinger, run:
63
+
64
+ ```sh
65
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME --reset
66
+ ```
67
+
68
+ Remember to adjust the "fs2_ckpt" parameter in `usr/configs/midi/cascade/opencs/ds60_rel.yaml` to fit your path.
69
+
70
+ ### 3. Inference Example
71
+ ```sh
72
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME --reset --infer
73
+ ```
74
+
75
+ We also provide:
76
+ - the pre-trained model of DiffSinger;
77
+ - the pre-trained model of FFT-Singer;
78
+
79
+ They can be found [here](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/adjust-receptive-field.zip).
80
+
81
+ Remember to put the pre-trained models in the `checkpoints` directory.
82
+
83
+ ### 4. Inference from raw inputs
84
+ ```sh
85
+ python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME
86
+ ```
87
+ Raw inputs:
88
+ ```
89
+ inp = {
90
+ 'text': '小酒窝长睫毛AP是你最美的记号',
91
+ 'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
92
+ 'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
93
+ 'input_type': 'word'
94
+ } # user input: Chinese characters
95
+ or,
96
+ inp = {
97
+ 'text': '小酒窝长睫毛AP是你最美的记号',
98
+ 'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
99
+ 'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
100
+ 'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
101
+ 'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
102
+ 'input_type': 'phoneme'
103
+ } # input like Opencpop dataset.
104
+ ```
105
+
106
+ ### 5. Some issues.
107
+ a) The HifiGAN-Singing vocoder is trained on our [vocoder dataset](https://dl.acm.org/doi/abs/10.1145/3474085.3475437) and the training set of [PopCS](https://arxiv.org/abs/2105.02446). Opencpop is an out-of-domain dataset (unseen speaker), which may degrade audio quality; we are considering fine-tuning this vocoder on the training set of Opencpop.
108
+
109
+ b) In this version of the code, we use the melody frontend ([lyric + MIDI] -> [F0 + ph_dur]) to predict the F0 contour and phoneme durations.
110
+
111
+ c) Generated audio demos can be found in [MY_DS_EXP_NAME](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/adjust-receptive-field.zip).
docs/README-SVS-opencpop-e2e.md ADDED
@@ -0,0 +1,106 @@
1
+ # DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
2
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
3
+ [![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
4
+ [![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
5
+
6
+ Substantial update: We 1) **abandon** the explicit prediction of the F0 curve; 2) increase the receptive field of the denoiser; 3) make the linguistic encoder more robust.
7
+ **By doing so, 1) the synthesized recordings are more natural in terms of pitch; 2) the pipeline is simpler.**
8
+
9
+ In short, we let the generative model capture the dynamics of the F0 curve, instead of constraining log-domain F0 with an MSE loss as before.
10
+
11
+ ## DiffSinger (MIDI version SVS)
12
+ ### 0. Data Acquisition
13
+ For the Opencpop dataset: please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We are not authorized to grant access to Opencpop.
14
+
15
+ The pipeline below is designed for the Opencpop dataset:
16
+
17
+ ### 1. Preparation
18
+
19
+ #### Data Preparation
20
+ a) Download and extract Opencpop, then create a link to the dataset folder: `ln -s /xxx/opencpop data/raw/`
21
+
22
+ b) Run the following scripts to pack the dataset for training/inference.
23
+
24
+ ```sh
25
+ export PYTHONPATH=.
26
+ CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config usr/configs/midi/cascade/opencs/aux_rel.yaml
27
+
28
+ # `data/binary/opencpop-midi-dp` will be generated.
29
+ ```
30
+
31
+ #### Vocoder Preparation
32
+ We provide the pre-trained model of [HifiGAN-Singing](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0109_hifigan_bigpopcs_hop128.zip) which is specially designed for SVS with NSF mechanism.
33
+
34
+ Also, please unzip the pre-trained vocoder and [this companion pitch extractor for the vocoder](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0102_xiaoma_pe.zip) into `checkpoints` before training your acoustic model.
35
+
36
+ (Update: You can also move [a ckpt with more training steps](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/model_ckpt_steps_1512000.ckpt) into this vocoder directory)
37
+
38
+ This singing vocoder is trained on ~70 hours of singing data, so it can be viewed as a universal vocoder.
39
+
40
+ #### Exp Name Preparation
41
+ ```bash
42
+ export MY_DS_EXP_NAME=0228_opencpop_ds100_rel
43
+ ```
44
+
45
+ ```
46
+ .
47
+ |--data
48
+ |--raw
49
+ |--opencpop
50
+ |--segments
51
+ |--transcriptions.txt
52
+ |--wavs
53
+ |--checkpoints
54
+ |--MY_DS_EXP_NAME (optional)
55
+ |--0109_hifigan_bigpopcs_hop128 (vocoder)
56
+ |--model_ckpt_steps_1512000.ckpt
57
+ |--config.yaml
58
+ ```
59
+
60
+ ### 2. Training Example
61
+ ```sh
62
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name $MY_DS_EXP_NAME --reset
63
+ ```
64
+
65
+ ### 3. Inference from packed test set
66
+ ```sh
67
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name $MY_DS_EXP_NAME --reset --infer
68
+ ```
69
+
70
+ We also provide:
71
+ - the pre-trained model of DiffSinger;
72
+
73
+ They can be found [here](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0228_opencpop_ds100_rel.zip).
74
+
75
+ Remember to put the pre-trained models in the `checkpoints` directory.
76
+
77
+ ### 4. Inference from raw inputs
78
+ ```sh
79
+ python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name $MY_DS_EXP_NAME
80
+ ```
81
+ Raw inputs:
82
+ ```
83
+ inp = {
84
+ 'text': '小酒窝长睫毛AP是你最美的记号',
85
+ 'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
86
+ 'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
87
+ 'input_type': 'word'
88
+ } # user input: Chinese characters
89
+ or,
90
+ inp = {
91
+ 'text': '小酒窝长睫毛AP是你最美的记号',
92
+ 'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
93
+ 'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
94
+ 'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
95
+ 'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
96
+ 'input_type': 'phoneme'
97
+ } # input like Opencpop dataset.
98
+ ```
99
+
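The raw-input dictionaries above can also be fed to the inference wrapper from Python instead of the command line. A minimal sketch follows; the class name `DiffSingerE2EInfer` and the exact `set_hparams` arguments are assumptions based on `inference/svs/ds_e2e.py` and the `BaseSVSInfer` class added later in this commit, so treat it as illustrative rather than a documented API:

```python
# Sketch: programmatic inference with the word-level raw input shown above.
# DiffSingerE2EInfer and the set_hparams keyword arguments are assumed names.
from utils.hparams import set_hparams, hparams
from inference.svs.ds_e2e import DiffSingerE2EInfer  # assumed class name

set_hparams(config='usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml',
            exp_name='0228_opencpop_ds100_rel')  # mirrors the CLI flags above

inp = {
    'text': '小酒窝长睫毛AP是你最美的记号',
    'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
    'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
    'input_type': 'word',
}
infer_ins = DiffSingerE2EInfer(hparams)  # BaseSVSInfer subclass (see inference/svs/base_svs_infer.py)
wav_out = infer_ins.infer_once(inp)      # preprocess -> acoustic model -> vocoder waveform
```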
100
+ ### 5. Some issues.
101
+ a) The HifiGAN-Singing vocoder is trained on our [vocoder dataset](https://dl.acm.org/doi/abs/10.1145/3474085.3475437) and the training set of [PopCS](https://arxiv.org/abs/2105.02446). Opencpop is an out-of-domain dataset (unseen speaker), which may degrade audio quality; we are considering fine-tuning this vocoder on the training set of Opencpop.
102
+
103
+ b) In this version of the code, we use the melody frontend ([lyric + MIDI] -> [ph_dur]) to predict phoneme durations; the F0 curve is predicted implicitly together with the mel-spectrogram.
104
+
105
+ c) Example [generated audio](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/demos_0221/DS/).
106
+ More generated audio demos can be found in [DiffSinger](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0228_opencpop_ds100_rel.zip).
docs/README-SVS-popcs.md ADDED
@@ -0,0 +1,63 @@
1
+ ## DiffSinger (SVS version)
2
+
3
+ ### 0. Data Acquisition
4
+ - See the [apply_form](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/apply_form.md).
5
+ - Dataset [preview](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_preview.zip).
6
+
7
+ ### 1. Preparation
8
+ #### Data Preparation
9
+ a) Download and extract PopCS, then create a link to the dataset folder: `ln -s /xxx/popcs/ data/processed/popcs`
10
+
11
+ b) Run the following scripts to pack the dataset for training/inference.
12
+ ```sh
13
+ export PYTHONPATH=.
14
+ CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config usr/configs/popcs_ds_beta6.yaml
15
+ # `data/binary/popcs-pmf0` will be generated.
16
+ ```
17
+
18
+ #### Vocoder Preparation
19
+ We provide the pre-trained model of [HifiGAN-Singing](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0109_hifigan_bigpopcs_hop128.zip) which is specially designed for SVS with NSF mechanism.
20
+ Please unzip this file into `checkpoints` before training your acoustic model.
21
+
22
+ (Update: You can also move [a ckpt with more training steps](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/model_ckpt_steps_1512000.ckpt) into this vocoder directory)
23
+
24
+ This singing vocoder is trained on ~70 hours of singing data, so it can be viewed as a universal vocoder.
25
+
26
+ ### 2. Training Example
27
+ First, you need a pre-trained FFT-Singer checkpoint. You can use the [pre-trained model](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_fs2_pmf0_1230.zip), or train FFT-Singer from scratch by running:
28
+
29
+ ```sh
30
+ # First, train fft-singer;
31
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_fs2.yaml --exp_name popcs_fs2_pmf0_1230 --reset
32
+ # Then, infer fft-singer;
33
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_fs2.yaml --exp_name popcs_fs2_pmf0_1230 --reset --infer
34
+ ```
35
+
36
+ Then, to train DiffSinger, run:
37
+ ```sh
38
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_ds_beta6_offline.yaml --exp_name popcs_ds_beta6_offline_pmf0_1230 --reset
39
+ ```
40
+
41
+ Remember to adjust the "fs2_ckpt" parameter in `usr/configs/popcs_ds_beta6_offline.yaml` to fit your path.
42
+
43
+ ### 3. Inference Example
44
+ ```sh
45
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_ds_beta6_offline.yaml --exp_name popcs_ds_beta6_offline_pmf0_1230 --reset --infer
46
+ ```
47
+
48
+ We also provide:
49
+ - the pre-trained model of [DiffSinger](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_ds_beta6_offline_pmf0_1230.zip);
50
+ - the pre-trained model of [FFT-Singer](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_fs2_pmf0_1230.zip) for the shallow diffusion mechanism in DiffSinger;
51
+
52
+ Remember to put the pre-trained models in the `checkpoints` directory.
53
+
54
+ *Note that:*
55
+
56
+ - *the original PWG vocoder used in the paper has been put into commercial use, so we provide this HifiGAN vocoder as a substitute.*
57
+ - *we assume the ground-truth F0 is given as the pitch information, following [1][2][3]. If you want to conduct experiments on MIDI data, you need an external F0 predictor (like [MIDI-old-version](README-SVS-opencpop-cascade.md)) or a joint prediction with spectrograms (like [MIDI-new-version](README-SVS-opencpop-e2e.md)).*
58
+
59
+ [1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.
60
+
61
+ [2] SEQUENCE-TO-SEQUENCE SINGING SYNTHESIS USING THE FEED-FORWARD TRANSFORMER. ICASSP 2020.
62
+
63
+ [3] DeepSinger : Singing Voice Synthesis with Data Mined From the Web. KDD 2020.
docs/README-SVS.md ADDED
@@ -0,0 +1,44 @@
1
+ ## DiffSinger (SVS version)
2
+
3
+ ### PART1. [Run DiffSinger on PopCS](README-SVS-popcs.md)
4
+ In this part, we focus only on spectrum modeling (the acoustic model) and assume the ground-truth (GT) F0 is given as the pitch information, following these papers [1][2][3].
5
+
6
+ Thus, the pipeline of this part can be summarized as:
7
+
8
+ ```
9
+ [lyrics] -> [linguistic representation] (Frontend)
10
+ [linguistic representation] + [GT F0] + [GT phoneme duration] -> [mel-spectrogram] (Acoustic model)
11
+ [mel-spectrogram] + [GT F0] -> [waveform] (Vocoder)
12
+ ```
13
+
14
+
15
+ [1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.
16
+
17
+ [2] SEQUENCE-TO-SEQUENCE SINGING SYNTHESIS USING THE FEED-FORWARD TRANSFORMER. ICASSP 2020.
18
+
19
+ [3] DeepSinger : Singing Voice Synthesis with Data Mined From the Web. KDD 2020.
20
+
21
+ ### PART2. [Run DiffSinger on Opencpop](README-SVS-opencpop-cascade.md)
22
+ Thanks to the [Opencpop team](https://wenet.org.cn/opencpop/) for releasing their SVS dataset with MIDI labels (**Jan. 20, 2022**). (Also thanks to my co-author [Yi Ren](https://github.com/RayeRen), who applied for the dataset and did some of the preprocessing work for this part.)
23
+
24
+ Since there are elaborately annotated MIDI labels, we are able to supplement the pipeline in PART 1 by adding a naive melody frontend.
25
+
26
+ #### 2.1
27
+ Thus, the pipeline of [this part](README-SVS-opencpop-cascade.md) can be summarized as:
28
+
29
+ ```
30
+ [lyrics] + [MIDI] -> [linguistic representation (with MIDI information)] + [predicted F0] + [predicted phoneme duration] (Melody frontend)
31
+ [linguistic representation] + [predicted F0] + [predicted phoneme duration] -> [mel-spectrogram] (Acoustic model)
32
+ [mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
33
+ ```
34
+
35
+ #### 2.2
36
+ In 2.1, we found that predicting F0 explicitly in the melody frontend leads to many bad cases in unvoiced/voiced (uv/v) prediction. We therefore abandon the explicit prediction of the F0 curve in the melody frontend and instead predict it jointly with the spectrogram.
37
+
38
+ Thus, the pipeline of [this part](README-SVS-opencpop-e2e.md) can be summarized as:
39
+ ```
40
+ [lyrics] + [MIDI] -> [linguistic representation] + [predicted phoneme duration] (Melody frontend)
41
+ [linguistic representation (with MIDI information)] + [predicted phoneme duration] -> [mel-spectrogram] (Acoustic model)
42
+ [mel-spectrogram] -> [predicted F0] (Pitch extractor)
43
+ [mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
44
+ ```
docs/README-TTS.md ADDED
@@ -0,0 +1,63 @@
1
+ ## DiffSpeech (TTS version)
2
+ ### 1. Preparation
3
+
4
+ #### Data Preparation
5
+ a) Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/), then create a link to the dataset folder: `ln -s /xxx/LJSpeech-1.1/ data/raw/`
6
+
7
+ b) Download and unzip the [ground-truth durations](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/mfa_outputs.tar) extracted by [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz): `tar -xvf mfa_outputs.tar; mv mfa_outputs data/processed/ljspeech/`
8
+
9
+ c) Run the following scripts to pack the dataset for training/inference.
10
+
11
+ ```sh
12
+ export PYTHONPATH=.
13
+ CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config configs/tts/lj/fs2.yaml
14
+
15
+ # `data/binary/ljspeech` will be generated.
16
+ ```
17
+
18
+ #### Vocoder Preparation
19
+ We provide the pre-trained model of [HifiGAN](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0414_hifi_lj_1.zip) vocoder.
20
+ Please unzip this file into `checkpoints` before training your acoustic model.
21
+
22
+ ### 2. Training Example
23
+
24
+ First, you need a pre-trained FastSpeech2 checkpoint. You can use the [pre-trained model](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/fs2_lj_1.zip), or train FastSpeech2 from scratch by running:
25
+ ```sh
26
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config configs/tts/lj/fs2.yaml --exp_name fs2_lj_1 --reset
27
+ ```
28
+ Then, to train DiffSpeech, run:
29
+ ```sh
30
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/lj_ds_beta6.yaml --exp_name lj_ds_beta6_1213 --reset
31
+ ```
32
+
33
+ Remember to adjust the "fs2_ckpt" parameter in `usr/configs/lj_ds_beta6.yaml` to fit your path.
34
+
35
+ ### 3. Inference Example
36
+
37
+ ```sh
38
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/lj_ds_beta6.yaml --exp_name lj_ds_beta6_1213 --reset --infer
39
+ ```
40
+
41
+ We also provide:
42
+ - the pre-trained model of [DiffSpeech](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/lj_ds_beta6_1213.zip);
43
+ - the individual pre-trained model of [FastSpeech 2](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/fs2_lj_1.zip) for the shallow diffusion mechanism in DiffSpeech;
44
+
45
+ Remember to put the pre-trained models in the `checkpoints` directory.
46
+
47
+ ## Mel Visualization
48
+ Along the vertical axis, DiffSpeech occupies bins [0-80] and FastSpeech 2 occupies bins [80-160].
49
+
50
+ <table style="width:100%">
51
+ <tr>
52
+ <th>DiffSpeech vs. FastSpeech 2</th>
53
+ </tr>
54
+ <tr>
55
+ <td><img src="resources/diffspeech-fs2.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
56
+ </tr>
57
+ <tr>
58
+ <td><img src="resources/diffspeech-fs2-1.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
59
+ </tr>
60
+ <tr>
61
+ <td><img src="resources/diffspeech-fs2-2.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
62
+ </tr>
63
+ </table>
docs/README-zh.md ADDED
@@ -0,0 +1,212 @@
1
+ # DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
2
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
3
+ [![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
4
+ [![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
5
+ | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
6
+ | [English README](../README.md)
7
+
8
+ 本仓库包含了我们的AAAI-2022 [论文](https://arxiv.org/abs/2105.02446)中提出的DiffSpeech (用于语音合成) 与 DiffSinger (用于歌声合成) 的官方Pytorch实现。
9
+
10
+ <table style="width:100%">
11
+ <tr>
12
+ <th>DiffSinger/DiffSpeech训练阶段</th>
13
+ <th>DiffSinger/DiffSpeech推理阶段</th>
14
+ </tr>
15
+ <tr>
16
+ <td><img src="resources/model_a.png" alt="Training" height="300"></td>
17
+ <td><img src="resources/model_b.png" alt="Inference" height="300"></td>
18
+ </tr>
19
+ </table>
20
+
21
+ :tada: :tada: :tada: **一些重要更新**:
22
+ - Mar.2, 2022: [MIDI-新版](README-SVS-opencpop-e2e.md): 重大更新 :sparkles:
23
+ - Mar.1, 2022: [NeuralSVB](https://github.com/MoonInTheRiver/NeuralSVB), 为了歌声美化任务的代码,开源了 :sparkles: :sparkles: :sparkles: .
24
+ - Feb.13, 2022: [NATSpeech](https://github.com/NATSpeech/NATSpeech), 一个升级后的代码框架, 包含了DiffSpeech和我们NeurIPS-2021的工作[PortaSpeech](https://openreview.net/forum?id=xmJsuh8xlq) 已经开源! :sparkles: :sparkles: :sparkles:.
25
+ - Jan.29, 2022: 支持了[MIDI-旧版](README-SVS-opencpop-cascade.md) 版本的歌声合成系统.
26
+ - Jan.13, 2022: 支持了歌声合成系统, 开源了PopCS数据集.
27
+ - Dec.19, 2021: 支持了语音合成系统. [HuggingFace🤗 Demo](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
28
+
29
+ :rocket: **新闻**:
30
+ - Feb.24, 2022: 我们的新工作`NeuralSVB` 被 ACL-2022 接收 [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2202.13277). [音频演示](https://neuralsvb.github.io).
31
+ - Dec.01, 2021: DiffSinger被AAAI-2022接收.
32
+ - Sep.29, 2021: 我们的新工作`PortaSpeech: Portable and High-Quality Generative Text-to-Speech` 被NeurIPS-2021接收 [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2109.15166) .
33
+ - May.06, 2021: 我们把这篇DiffSinger提交到了公开论文网站: Arxiv [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446).
34
+
35
+ ## 安装依赖
36
+ ```sh
37
+ conda create -n your_env_name python=3.8
38
+ source activate your_env_name
39
+ pip install -r requirements_2080.txt (GPU 2080Ti, CUDA 10.2)
40
+ or pip install -r requirements_3090.txt (GPU 3090, CUDA 11.4)
41
+ ```
42
+
43
+ ## DiffSpeech (语音合成的版本)
44
+ ### 1. 准备工作
45
+
46
+ #### 数据准备
47
+ a) 下载并解压 [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/), 创建软链接: `ln -s /xxx/LJSpeech-1.1/ data/raw/`
48
+
49
+ b) 下载并解压 [我们用MFA预处理好的对齐](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/mfa_outputs.tar): `tar -xvf mfa_outputs.tar; mv mfa_outputs data/processed/ljspeech/`
50
+
51
+ c) 按照如下脚本给数据集打包,打包后的二进制文件用于后续的训练和推理.
52
+
53
+ ```sh
54
+ export PYTHONPATH=.
55
+ CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config configs/tts/lj/fs2.yaml
56
+
57
+ # `data/binary/ljspeech` will be generated.
58
+ ```
59
+
60
+ #### 声码器准备
61
+ 我们提供了[HifiGAN](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0414_hifi_lj_1.zip)声码器的预训练模型.
62
+ 请在训练声学模型前,先把声码器文件解压到`checkpoints`里。
63
+
64
+ ### 2. 训练样例
65
+
66
+ 首先你需要一个预训练好的FastSpeech2存档点. 你可以用[我们预训练好的模型](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/fs2_lj_1.zip), 或者跑下面这个指令从零开始训练FastSpeech2:
67
+ ```sh
68
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config configs/tts/lj/fs2.yaml --exp_name fs2_lj_1 --reset
69
+ ```
70
+ 然后为了训练DiffSpeech, 运行:
71
+ ```sh
72
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/lj_ds_beta6.yaml --exp_name lj_ds_beta6_1213 --reset
73
+ ```
74
+
75
+ 记得针对你的路径修改`usr/configs/lj_ds_beta6.yaml`里"fs2_ckpt"这个参数.
76
+
77
+ ### 3. 推理样例
78
+
79
+ ```sh
80
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/lj_ds_beta6.yaml --exp_name lj_ds_beta6_1213 --reset --infer
81
+ ```
82
+
83
+ 我们也提供了:
84
+ - [DiffSpeech](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/lj_ds_beta6_1213.zip)的预训练模型;
85
+ - [FastSpeech 2](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/fs2_lj_1.zip)的预训练模型, 这是为了DiffSpeech里的浅扩散机制;
86
+
87
+ 记得把预训练模型放在 `checkpoints` 目录.
88
+
89
+ ## DiffSinger (歌声合成的版本)
90
+
91
+ ### 0. 数据获取
92
+ - 见 [申请表](https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/apply_form.md).
93
+ - 数据集 [预览](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_preview.zip).
94
+
95
+ ### 1. Preparation
96
+ #### 数据准备
97
+ a) 下载并解压PopCS, 创建软链接: `ln -s /xxx/popcs/ data/processed/popcs`
98
+
99
+ b) 按照如下脚本给数据集打包,打包后的二进制文件用于后续的训练和推理.
100
+ ```sh
101
+ export PYTHONPATH=.
102
+ CUDA_VISIBLE_DEVICES=0 python data_gen/tts/bin/binarize.py --config usr/configs/popcs_ds_beta6.yaml
103
+ # `data/binary/popcs-pmf0` 会生成出来.
104
+ ```
105
+
106
+ #### 声码器准备
107
+ 我们提供了[HifiGAN-Singing](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0109_hifigan_bigpopcs_hop128.zip)的预训练模型, 它专门为了歌声合成系统设计, 采用了NSF的技术。
108
+ 请在训练声学模型前,先把声码器文件解压到`checkpoints`里。
109
+
110
+ (更新: 你也可以将我们提供的[训练更多步数的存档点](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/model_ckpt_steps_1512000.ckpt)放到声码器的文件夹里)
111
+
112
+ 这个声码器是在大约70小时的较大数据集上训练的, 可以被认为是一个通用声码器。
113
+
114
+ ### 2. 训练样例
115
+ 首先你需要一个预训练好的FFT-Singer. 你可以用[我们预训练好的模型](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_fs2_pmf0_1230.zip), 或者用如下脚本从零训练FFT-Singer:
116
+
117
+ ```sh
118
+ # First, train fft-singer;
119
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_fs2.yaml --exp_name popcs_fs2_pmf0_1230 --reset
120
+ # Then, infer fft-singer;
121
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_fs2.yaml --exp_name popcs_fs2_pmf0_1230 --reset --infer
122
+ ```
123
+
124
+ 然后, 为了训练DiffSinger, 运行:
125
+ ```sh
126
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_ds_beta6_offline.yaml --exp_name popcs_ds_beta6_offline_pmf0_1230 --reset
127
+ ```
128
+
129
+ 记得针对你的路径修改`usr/configs/popcs_ds_beta6_offline.yaml`里"fs2_ckpt"这个参数.
130
+
131
+ ### 3. 推理样例
132
+ ```sh
133
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/popcs_ds_beta6_offline.yaml --exp_name popcs_ds_beta6_offline_pmf0_1230 --reset --infer
134
+ ```
135
+
136
+ 我们也提供了:
137
+ - [DiffSinger](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_ds_beta6_offline_pmf0_1230.zip)的预训练模型;
138
+ - [FFT-Singer](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/popcs_fs2_pmf0_1230.zip)的预训练模型, 这是为了DiffSinger里的浅扩散机制;
139
+
140
+ 记得把预训练模型放在 `checkpoints` 目录.
141
+
142
+ *请注意:*
143
+
144
+ -*我们原始论文中的PWG版本声码器已投入商业使用,因此我们提供此HifiGAN版本声码器作为替代品。*
145
+
146
+ -*我们这篇论文假设提供真实的F0来进行实验,如[1][2][3]等前作所做的那样,重点在频谱建模上,而非F0曲线的预测。如果你想对MIDI数据进行实验,从MIDI和歌词预测F0曲线(显式或隐式),请查看文档[MIDI-old-version](README-SVS-opencpop-cascade.md) 或 [MIDI-new-version](README-SVS-opencpop-e2e.md)。目前已经支持的MIDI数据集有: Opencpop*
147
+
148
+ [1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.
149
+
150
+ [2] SEQUENCE-TO-SEQUENCE SINGING SYNTHESIS USING THE FEED-FORWARD TRANSFORMER. ICASSP 2020.
151
+
152
+ [3] DeepSinger : Singing Voice Synthesis with Data Mined From the Web. KDD 2020.
153
+
154
+ ## Tensorboard
155
+ ```sh
156
+ tensorboard --logdir_spec exp_name
157
+ ```
158
+ <table style="width:100%">
159
+ <tr>
160
+ <td><img src="resources/tfb.png" alt="Tensorboard" height="250"></td>
161
+ </tr>
162
+ </table>
163
+
164
+ ## Mel 可视化
165
+ 沿着纵轴, DiffSpeech: [0-80]; FastSpeech2: [80-160].
166
+
167
+ <table style="width:100%">
168
+ <tr>
169
+ <th>DiffSpeech vs. FastSpeech 2</th>
170
+ </tr>
171
+ <tr>
172
+ <td><img src="resources/diffspeech-fs2.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
173
+ </tr>
174
+ <tr>
175
+ <td><img src="resources/diffspeech-fs2-1.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
176
+ </tr>
177
+ <tr>
178
+ <td><img src="resources/diffspeech-fs2-2.png" alt="DiffSpeech-vs-FastSpeech2" height="250"></td>
179
+ </tr>
180
+ </table>
181
+
182
+ ## Audio Demos
183
+ 音频样本可以看我们的[样例页](https://diffsinger.github.io/).
184
+
185
+ 我们也放了部分由DiffSpeech+HifiGAN (标记为[P]) 和 GTmel+HifiGAN (标记为[G]) 生成的测试集音频样例在:[resources/demos_1213](../resources/demos_1213).
186
+
187
+ (对应这个预训练参数:[DiffSpeech](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/lj_ds_beta6_1213.zip))
188
+
189
+ ---
190
+ :rocket: :rocket: :rocket: **更新:**
191
+
192
+ 新生成的歌声样例在:[resources/demos_0112](../resources/demos_0112).
193
+
194
+ ## Citation
195
+ 如果本仓库对你的研究和工作有用,请引用以下论文:
196
+
197
+ @article{liu2021diffsinger,
198
+ title={Diffsinger: Singing voice synthesis via shallow diffusion mechanism},
199
+ author={Liu, Jinglin and Li, Chengxi and Ren, Yi and Chen, Feiyang and Liu, Peng and Zhao, Zhou},
200
+ journal={arXiv preprint arXiv:2105.02446},
201
+ volume={2},
202
+ year={2021}}
203
+
204
+
205
+ ## 鸣谢
206
+ 我们的代码基于如下仓库:
207
+ * [denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch)
208
+ * [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
209
+ * [ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN)
210
+ * [HifiGAN](https://github.com/jik876/hifi-gan)
211
+ * [espnet](https://github.com/espnet/espnet)
212
+ * [DiffWave](https://github.com/lmnt-com/diffwave)
inference/svs/base_svs_infer.py ADDED
@@ -0,0 +1,265 @@
1
+ import os
2
+
3
+ import torch
4
+ import numpy as np
5
+ from modules.hifigan.hifigan import HifiGanGenerator
6
+ from vocoders.hifigan import HifiGAN
7
+ from inference.svs.opencpop.map import cpop_pinyin2ph_func
8
+
9
+ from utils import load_ckpt
10
+ from utils.hparams import set_hparams, hparams
11
+ from utils.text_encoder import TokenTextEncoder
12
+ from pypinyin import pinyin, lazy_pinyin, Style
13
+ import librosa
14
+ import glob
15
+ import re
16
+
17
+
18
+ class BaseSVSInfer:
19
+ def __init__(self, hparams, device=None):
20
+ if device is None:
21
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
+ self.hparams = hparams
23
+ self.device = device
24
+
25
+ phone_list = ["AP", "SP", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er", "f", "g",
26
+ "h", "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iu", "j", "k", "l", "m", "n", "o",
27
+ "ong", "ou", "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan", "uang", "ui", "un", "uo", "v",
28
+ "van", "ve", "vn", "w", "x", "y", "z", "zh"]
29
+ self.ph_encoder = TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
30
+ self.pinyin2phs = cpop_pinyin2ph_func()
31
+ self.spk_map = {'opencpop': 0}
32
+
33
+ self.model = self.build_model()
34
+ self.model.eval()
35
+ self.model.to(self.device)
36
+ self.vocoder = self.build_vocoder()
37
+ self.vocoder.eval()
38
+ self.vocoder.to(self.device)
39
+
40
+ def build_model(self):
41
+ raise NotImplementedError
42
+
43
+ def forward_model(self, inp):
44
+ raise NotImplementedError
45
+
46
+ def build_vocoder(self):
47
+ base_dir = hparams['vocoder_ckpt']
48
+ config_path = f'{base_dir}/config.yaml'
49
+ ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
50
+ lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
51
+ print('| load HifiGAN: ', ckpt)
52
+ ckpt_dict = torch.load(ckpt, map_location="cpu")
53
+ config = set_hparams(config_path, global_hparams=False)
54
+ state = ckpt_dict["state_dict"]["model_gen"]
55
+ vocoder = HifiGanGenerator(config)
56
+ vocoder.load_state_dict(state, strict=True)
57
+ vocoder.remove_weight_norm()
58
+ vocoder = vocoder.eval().to(self.device)
59
+ return vocoder
60
+
61
+ def run_vocoder(self, c, **kwargs):
62
+ c = c.transpose(2, 1) # [B, 80, T]
63
+ f0 = kwargs.get('f0') # [B, T]
64
+ if f0 is not None and hparams.get('use_nsf'):
65
+ # f0 = torch.FloatTensor(f0).to(self.device)
66
+ y = self.vocoder(c, f0).view(-1)
67
+ else:
68
+ y = self.vocoder(c).view(-1)
69
+ # [T]
70
+ return y[None]
71
+
72
+ def preprocess_word_level_input(self, inp):
73
+ # Pypinyin can't resolve polyphonic characters, so a few known cases are patched below
74
+ text_raw = inp['text'].replace('最长', '最常').replace('长睫毛', '常睫毛') \
75
+ .replace('那么长', '那么常').replace('多长', '多常') \
76
+ .replace('很长', '很常') # We hope someone could provide a better g2p module for us by opening pull requests.
77
+
78
+ # lyric
79
+ pinyins = lazy_pinyin(text_raw, strict=False)
80
+ ph_per_word_lst = [self.pinyin2phs[pinyin.strip()] for pinyin in pinyins if pinyin.strip() in self.pinyin2phs]
81
+
82
+ # Note
83
+ note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != '']
84
+ mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']
85
+
86
+ if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
87
+ print('Pass word-notes check.')
88
+ else:
89
+ print('The number of words doesn\'t match the number of notes\' windows. ',
90
+ 'You should split the note(s) for each word by | mark.')
91
+ print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
92
+ print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst))
93
+ return None
94
+
95
+ note_lst = []
96
+ ph_lst = []
97
+ midi_dur_lst = []
98
+ is_slur = []
99
+ for idx, ph_per_word in enumerate(ph_per_word_lst):
100
+ # for phs in one word:
101
+ # single ph like ['ai'] or multiple phs like ['n', 'i']
102
+ ph_in_this_word = ph_per_word.split()
103
+
104
+ # for notes in one word:
105
+ # single note like ['D4'] or multiple notes like ['D4', 'E4'] which means a 'slur' here.
106
+ note_in_this_word = note_per_word_lst[idx].split()
107
+ midi_dur_in_this_word = mididur_per_word_lst[idx].split()
108
+ # process for the model input
109
+ # Step 1.
110
+ # Deal with note of 'not slur' case or the first note of 'slur' case
111
+ # j ie
112
+ # F#4/Gb4 F#4/Gb4
113
+ # 0 0
114
+ for ph in ph_in_this_word:
115
+ ph_lst.append(ph)
116
+ note_lst.append(note_in_this_word[0])
117
+ midi_dur_lst.append(midi_dur_in_this_word[0])
118
+ is_slur.append(0)
119
+ # step 2.
120
+ # Deal with the 2nd, 3rd... notes of 'slur' case
121
+ # j ie ie
122
+ # F#4/Gb4 F#4/Gb4 C#4/Db4
123
+ # 0 0 1
124
+ if len(note_in_this_word) > 1: # is_slur = True, we should repeat the YUNMU to match the 2nd, 3rd... notes.
125
+ for idx in range(1, len(note_in_this_word)):
126
+ ph_lst.append(ph_in_this_word[1])
127
+ note_lst.append(note_in_this_word[idx])
128
+ midi_dur_lst.append(midi_dur_in_this_word[idx])
129
+ is_slur.append(1)
130
+ ph_seq = ' '.join(ph_lst)
131
+
132
+ if len(ph_lst) == len(note_lst) == len(midi_dur_lst):
133
+ print(len(ph_lst), len(note_lst), len(midi_dur_lst))
134
+ print('Pass word-notes check.')
135
+ else:
136
+ print('The number of words doesn\'t match the number of notes\' windows. ',
137
+ 'You should split the note(s) for each word by | mark.')
138
+ return None
139
+ return ph_seq, note_lst, midi_dur_lst, is_slur
140
+
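+ # Worked example (a sketch derived from the example inputs below, not part of the original logic):
+ # the character '睫' is sung over two notes, so ph_per_word = 'j ie', notes = 'F#4/Gb4 C#4/Db4',
+ # durations = '0.315400 0.235020'. Step 1 emits ('j', 'F#4/Gb4', '0.315400', 0) and
+ # ('ie', 'F#4/Gb4', '0.315400', 0); Step 2 repeats the final 'ie' for the slurred note,
+ # emitting ('ie', 'C#4/Db4', '0.235020', 1).
+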
141
+ def preprocess_phoneme_level_input(self, inp):
142
+ ph_seq = inp['ph_seq']
143
+ note_lst = inp['note_seq'].split()
144
+ midi_dur_lst = inp['note_dur_seq'].split()
145
+ is_slur = [int(x) for x in inp['is_slur_seq'].split()]  # cast to int so np.asarray / LongTensor behave as in the word-level path
146
+ print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
147
+ if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
148
+ print('Pass word-notes check.')
149
+ else:
150
+ print('The number of phonemes doesn\'t match the number of notes/durations. ',
151
+ 'Please check ph_seq, note_seq and note_dur_seq.')
152
+ return None
153
+ return ph_seq, note_lst, midi_dur_lst, is_slur
154
+
155
+ def preprocess_input(self, inp, input_type='word'):
156
+ """
157
+
158
+ :param inp: word-level input: {'text', 'notes', 'notes_duration', ...}; phoneme-level input: {'text', 'ph_seq', 'note_seq', 'note_dur_seq', 'is_slur_seq', ...}; both may include optional 'item_name' and 'spk_name'.
159
+ :return: a preprocessed item dict, or None if the input fails validation.
160
+ """
161
+
162
+ item_name = inp.get('item_name', '<ITEM_NAME>')
163
+ spk_name = inp.get('spk_name', 'opencpop')
164
+
165
+ # single spk
166
+ spk_id = self.spk_map[spk_name]
167
+
168
+ # get ph seq, note lst, midi dur lst, is slur lst.
169
+ if input_type == 'word':
170
+ ret = self.preprocess_word_level_input(inp)
171
+ elif input_type == 'phoneme': # like transcriptions.txt in Opencpop dataset.
172
+ ret = self.preprocess_phoneme_level_input(inp)
173
+ else:
174
+ print('Invalid input type.')
175
+ return None
176
+
177
+ if ret:
178
+ ph_seq, note_lst, midi_dur_lst, is_slur = ret
179
+ else:
180
+ print('==========> Preprocess_word_level or phone_level input wrong.')
181
+ return None
182
+
183
+ # convert note lst to midi id; convert note dur lst to midi duration
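+ # For reference: librosa.note_to_midi('A4') == 69 and librosa.note_to_midi('C#4') == 61
+ # (standard MIDI numbering with C4 = 60); 'rest' is not a pitch, so it is mapped to 0 here.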
184
+ try:
185
+ midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
186
+ for x in note_lst]
187
+ midi_dur_lst = [float(x) for x in midi_dur_lst]
188
+ except Exception as e:
189
+ print(e)
190
+ print('Invalid note or note-duration input.')
191
+ return None
192
+
193
+ ph_token = self.ph_encoder.encode(ph_seq)
194
+ item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
195
+ 'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
196
+ 'is_slur': np.asarray(is_slur), }
197
+ item['ph_len'] = len(item['ph_token'])
198
+ return item
199
+
200
+ def input_to_batch(self, item):
201
+ item_names = [item['item_name']]
202
+ text = [item['text']]
203
+ ph = [item['ph']]
204
+ txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
205
+ txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
206
+ spk_ids = torch.LongTensor(item['spk_id'])[None, :].to(self.device)
207
+
208
+ pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :hparams['max_frames']].to(self.device)
209
+ midi_dur = torch.FloatTensor(item['midi_dur'])[None, :hparams['max_frames']].to(self.device)
210
+ is_slur = torch.LongTensor(item['is_slur'])[None, :hparams['max_frames']].to(self.device)
211
+
212
+ batch = {
213
+ 'item_name': item_names,
214
+ 'text': text,
215
+ 'ph': ph,
216
+ 'txt_tokens': txt_tokens,
217
+ 'txt_lengths': txt_lengths,
218
+ 'spk_ids': spk_ids,
219
+ 'pitch_midi': pitch_midi,
220
+ 'midi_dur': midi_dur,
221
+ 'is_slur': is_slur
222
+ }
223
+ return batch
224
+
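+ # Shape sketch (illustration): for the 29-phoneme example used elsewhere in this repo,
+ # txt_tokens is [1, 29], txt_lengths is [1], and pitch_midi / midi_dur / is_slur are each [1, 29]
+ # (those three are truncated to hparams['max_frames'] along the time axis).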
225
+ def postprocess_output(self, output):
226
+ return output
227
+
228
+ def infer_once(self, inp):
229
+ inp = self.preprocess_input(inp, input_type=inp['input_type'] if inp.get('input_type') else 'word')
230
+ output = self.forward_model(inp)
231
+ output = self.postprocess_output(output)
232
+ return output
233
+
234
+ @classmethod
235
+ def example_run(cls, inp):
236
+ from utils.audio import save_wav
237
+ set_hparams(print_hparams=False)
238
+ infer_ins = cls(hparams)
239
+ out = infer_ins.infer_once(inp)
240
+ os.makedirs('infer_out', exist_ok=True)
241
+ save_wav(out, f'infer_out/example_out.wav', hparams['audio_sample_rate'])
242
+
243
+
244
+ # if __name__ == '__main__':
245
+ # debug
246
+ # a = BaseSVSInfer(hparams)
247
+ # a.preprocess_input({'text': '你 说 你 不 SP 懂 为 何 在 这 时 牵 手 AP',
248
+ # 'notes': 'D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | rest | D#4/Eb4 | D4 | D4 | D4 | D#4/Eb4 | F4 | D#4/Eb4 | D4 | rest',
249
+ # 'notes_duration': '0.113740 | 0.329060 | 0.287950 | 0.133480 | 0.150900 | 0.484730 | 0.242010 | 0.180820 | 0.343570 | 0.152050 | 0.266720 | 0.280310 | 0.633300 | 0.444590'
250
+ # })
251
+
252
+ # b = {
253
+ # 'text': '小酒窝长睫毛AP是你最美的记号',
254
+ # 'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
255
+ # 'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340'
256
+ # }
257
+ # c = {
258
+ # 'text': '小酒窝长睫毛AP是你最美的记号',
259
+ # 'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
260
+ # 'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
261
+ # 'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
262
+ # 'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0'
263
+ # } # input like Opencpop dataset.
264
+ # a.preprocess_input(b)
265
+ # a.preprocess_input(c, input_type='phoneme')
inference/svs/ds_cascade.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ # from inference.tts.fs import FastSpeechInfer
3
+ # from modules.tts.fs2_orig import FastSpeech2Orig
4
+ from inference.svs.base_svs_infer import BaseSVSInfer
5
+ from utils import load_ckpt
6
+ from utils.hparams import hparams
7
+ from usr.diff.shallow_diffusion_tts import GaussianDiffusion
8
+ from usr.diffsinger_task import DIFF_DECODERS
9
+
10
+ class DiffSingerCascadeInfer(BaseSVSInfer):
11
+ def build_model(self):
12
+ model = GaussianDiffusion(
13
+ phone_encoder=self.ph_encoder,
14
+ out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
15
+ timesteps=hparams['timesteps'],
16
+ K_step=hparams['K_step'],
17
+ loss_type=hparams['diff_loss_type'],
18
+ spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
19
+ )
20
+ model.eval()
21
+ load_ckpt(model, hparams['work_dir'], 'model')
22
+ return model
23
+
24
+ def forward_model(self, inp):
25
+ sample = self.input_to_batch(inp)
26
+ txt_tokens = sample['txt_tokens'] # [B, T_t]
27
+ spk_id = sample.get('spk_ids')
28
+ with torch.no_grad():
29
+ output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
30
+ pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
31
+ is_slur=sample['is_slur'])
32
+ mel_out = output['mel_out'] # [B, T,80]
33
+ f0_pred = output['f0_denorm']
34
+ wav_out = self.run_vocoder(mel_out, f0=f0_pred)
35
+ wav_out = wav_out.cpu().numpy()
36
+ return wav_out[0]
37
+
38
+
39
+ if __name__ == '__main__':
40
+ inp = {
41
+ 'text': '小酒窝长睫毛AP是你最美的记号',
42
+ 'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
43
+ 'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
44
+ 'input_type': 'word'
45
+ } # user input: Chinese characters
46
+ c = {
47
+ 'text': '小酒窝长睫毛AP是你最美的记号',
48
+ 'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
49
+ 'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
50
+ 'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
51
+ 'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
52
+ 'input_type': 'phoneme'
53
+ } # input like Opencpop dataset.
54
+ DiffSingerCascadeInfer.example_run(inp)
inference/svs/ds_e2e.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ # from inference.tts.fs import FastSpeechInfer
3
+ # from modules.tts.fs2_orig import FastSpeech2Orig
4
+ from inference.svs.base_svs_infer import BaseSVSInfer
5
+ from utils import load_ckpt
6
+ from utils.hparams import hparams
7
+ from usr.diff.shallow_diffusion_tts import GaussianDiffusion
8
+ from usr.diffsinger_task import DIFF_DECODERS
9
+ from modules.fastspeech.pe import PitchExtractor
10
+ import utils
11
+
12
+
13
+ class DiffSingerE2EInfer(BaseSVSInfer):
14
+ def build_model(self):
15
+ model = GaussianDiffusion(
16
+ phone_encoder=self.ph_encoder,
17
+ out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
18
+ timesteps=hparams['timesteps'],
19
+ K_step=hparams['K_step'],
20
+ loss_type=hparams['diff_loss_type'],
21
+ spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
22
+ )
23
+ model.eval()
24
+ load_ckpt(model, hparams['work_dir'], 'model')
25
+
26
+ if hparams.get('pe_enable') is not None and hparams['pe_enable']:
27
+ self.pe = PitchExtractor().cuda()
28
+ utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
29
+ self.pe.eval()
30
+ return model
31
+
32
+ def forward_model(self, inp):
33
+ sample = self.input_to_batch(inp)
34
+ txt_tokens = sample['txt_tokens'] # [B, T_t]
35
+ spk_id = sample.get('spk_ids')
36
+ with torch.no_grad():
37
+ output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
38
+ pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
39
+ is_slur=sample['is_slur'])
40
+ mel_out = output['mel_out'] # [B, T,80]
41
+ if hparams.get('pe_enable') is not None and hparams['pe_enable']:
42
+ f0_pred = self.pe(mel_out)['f0_denorm_pred'] # pe predict from Pred mel
43
+ else:
44
+ f0_pred = output['f0_denorm']
45
+ wav_out = self.run_vocoder(mel_out, f0=f0_pred)
46
+ wav_out = wav_out.cpu().numpy()
47
+ return wav_out[0]
48
+
49
+ if __name__ == '__main__':
50
+ inp = {
51
+ 'text': '小酒窝长睫毛AP是你最美的记号',
52
+ 'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
53
+ 'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
54
+ 'input_type': 'word'
55
+ } # user input: Chinese characters
56
+ c = {
57
+ 'text': '小酒窝长睫毛AP是你最美的记号',
58
+ 'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
59
+ 'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
60
+ 'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
61
+ 'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
62
+ 'input_type': 'phoneme'
63
+ } # input like Opencpop dataset.
64
+ DiffSingerE2EInfer.example_run(inp)
65
+
66
+
67
+ # python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
inference/svs/gradio/gradio_settings.yaml ADDED
@@ -0,0 +1,19 @@
1
+ title: 'DiffSinger'
2
+ description: |
3
+ Gradio demo for DiffSinger.
4
+
5
+ 请给每个汉字分配音高和时值, 每个字对应的音高和时值需要用|分隔符隔开。需要保证分隔符分割出来的音符窗口与汉字个数(AP或SP也算一个汉字)一致。
6
+
7
+ article: |
8
+ Link to <a href='https://github.com/MoonInTheRiver/DiffSinger' style='color:blue;' target='_blank\'>Github REPO</a>
9
+ example_inputs:
10
+ - |-
11
+ 你 说 你 不 SP 懂 为 何 在 这 时 牵 手 AP<sep>D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | D#4/Eb4 | rest | D#4/Eb4 | D4 | D4 | D4 | D#4/Eb4 | F4 | D#4/Eb4 | D4 | rest<sep>0.113740 | 0.329060 | 0.287950 | 0.133480 | 0.150900 | 0.484730 | 0.242010 | 0.180820 | 0.343570 | 0.152050 | 0.266720 | 0.280310 | 0.633300 | 0.444590
12
+ - |-
13
+ 小酒窝长睫毛AP是你最美的记号<sep>C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4<sep>0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340
14
+
15
+ #inference_cls: inference.svs.ds_cascade.DiffSingerCascadeInfer
16
+ #exp_name: 0303_opencpop_ds58_midi
17
+
18
+ inference_cls: inference.svs.ds_e2e.DiffSingerE2EInfer
19
+ exp_name: 0228_opencpop_ds100_rel
inference/svs/gradio/infer.py ADDED
@@ -0,0 +1,91 @@
1
+ import importlib
2
+ import re
3
+
4
+ import gradio as gr
5
+ import yaml
6
+ from gradio.inputs import Textbox
7
+
8
+ from inference.svs.base_svs_infer import BaseSVSInfer
9
+ from utils.hparams import set_hparams
10
+ from utils.hparams import hparams as hp
11
+ import numpy as np
12
+
13
+
14
+ class GradioInfer:
15
+ def __init__(self, exp_name, inference_cls, title, description, article, example_inputs):
16
+ self.exp_name = exp_name
17
+ self.title = title
18
+ self.description = description
19
+ self.article = article
20
+ self.example_inputs = example_inputs
21
+ pkg = ".".join(inference_cls.split(".")[:-1])
22
+ cls_name = inference_cls.split(".")[-1]
23
+ self.inference_cls = getattr(importlib.import_module(pkg), cls_name)
24
+
25
+ def greet(self, text, notes, notes_duration):
26
+ PUNCS = '。?;:'
27
+ sents = re.split(rf'([{PUNCS}])', text.replace('\n', ','))
28
+ sents_notes = re.split(rf'([{PUNCS}])', notes.replace('\n', ','))
29
+ sents_notes_dur = re.split(rf'([{PUNCS}])', notes_duration.replace('\n', ','))
30
+
31
+ if sents[-1] not in list(PUNCS):
32
+ sents = sents + ['']
33
+ sents_notes = sents_notes + ['']
34
+ sents_notes_dur = sents_notes_dur + ['']
35
+
36
+ audio_outs = []
37
+ s, n, n_dur = "", "", ""
38
+ for i in range(0, len(sents), 2):
39
+ if len(sents[i]) > 0:
40
+ s += sents[i] + sents[i + 1]
41
+ n += sents_notes[i] + sents_notes[i+1]
42
+ n_dur += sents_notes_dur[i] + sents_notes_dur[i+1]
43
+ if len(s) >= 400 or (i >= len(sents) - 2 and len(s) > 0):
44
+ audio_out = self.infer_ins.infer_once({
45
+ 'text': s,
46
+ 'notes': n,
47
+ 'notes_duration': n_dur,
48
+ })
49
+ audio_out = audio_out * 32767
50
+ audio_out = audio_out.astype(np.int16)
51
+ audio_outs.append(audio_out)
52
+ audio_outs.append(np.zeros(int(hp['audio_sample_rate'] * 0.3)).astype(np.int16))
53
+ s = ""
54
+ n = ""
55
+ audio_outs = np.concatenate(audio_outs)
56
+ return hp['audio_sample_rate'], audio_outs
57
+
58
+ def run(self):
59
+ set_hparams(exp_name=self.exp_name, print_hparams=False)
60
+ infer_cls = self.inference_cls
61
+ self.infer_ins: BaseSVSInfer = infer_cls(hp)
62
+ example_inputs = self.example_inputs
63
+ for i in range(len(example_inputs)):
64
+ text, notes, notes_dur = example_inputs[i].split('<sep>')
65
+ example_inputs[i] = [text, notes, notes_dur]
66
+
67
+ iface = gr.Interface(fn=self.greet,
68
+ inputs=[
69
+ Textbox(lines=2, placeholder=None, default=example_inputs[0][0], label="input text"),
70
+ Textbox(lines=2, placeholder=None, default=example_inputs[0][1], label="input note"),
71
+ Textbox(lines=2, placeholder=None, default=example_inputs[0][2], label="input duration")]
72
+ ,
73
+ outputs="audio",
74
+ allow_flagging="never",
75
+ title=self.title,
76
+ description=self.description,
77
+ article=self.article,
78
+ examples=example_inputs,
79
+ enable_queue=True)
80
+ iface.launch(share=True)  # cache_examples=True
81
+
82
+
83
+ if __name__ == '__main__':
84
+ gradio_config = yaml.safe_load(open('inference/svs/gradio/gradio_settings.yaml'))
85
+ g = GradioInfer(**gradio_config)
86
+ g.run()
87
+
88
+
89
+ # python inference/svs/gradio/infer.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
90
+ # python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
91
+ # CUDA_VISIBLE_DEVICES=3 python inference/svs/gradio/infer.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
inference/svs/opencpop/cpop_pinyin2ph.txt ADDED
@@ -0,0 +1,418 @@
1
+ | a | a |
2
+ | ai | ai |
3
+ | an | an |
4
+ | ang | ang |
5
+ | ao | ao |
6
+ | ba | b a |
7
+ | bai | b ai |
8
+ | ban | b an |
9
+ | bang | b ang |
10
+ | bao | b ao |
11
+ | bei | b ei |
12
+ | ben | b en |
13
+ | beng | b eng |
14
+ | bi | b i |
15
+ | bian | b ian |
16
+ | biao | b iao |
17
+ | bie | b ie |
18
+ | bin | b in |
19
+ | bing | b ing |
20
+ | bo | b o |
21
+ | bu | b u |
22
+ | ca | c a |
23
+ | cai | c ai |
24
+ | can | c an |
25
+ | cang | c ang |
26
+ | cao | c ao |
27
+ | ce | c e |
28
+ | cei | c ei |
29
+ | cen | c en |
30
+ | ceng | c eng |
31
+ | cha | ch a |
32
+ | chai | ch ai |
33
+ | chan | ch an |
34
+ | chang | ch ang |
35
+ | chao | ch ao |
36
+ | che | ch e |
37
+ | chen | ch en |
38
+ | cheng | ch eng |
39
+ | chi | ch i |
40
+ | chong | ch ong |
41
+ | chou | ch ou |
42
+ | chu | ch u |
43
+ | chua | ch ua |
44
+ | chuai | ch uai |
45
+ | chuan | ch uan |
46
+ | chuang | ch uang |
47
+ | chui | ch ui |
48
+ | chun | ch un |
49
+ | chuo | ch uo |
50
+ | ci | c i |
51
+ | cong | c ong |
52
+ | cou | c ou |
53
+ | cu | c u |
54
+ | cuan | c uan |
55
+ | cui | c ui |
56
+ | cun | c un |
57
+ | cuo | c uo |
58
+ | da | d a |
59
+ | dai | d ai |
60
+ | dan | d an |
61
+ | dang | d ang |
62
+ | dao | d ao |
63
+ | de | d e |
64
+ | dei | d ei |
65
+ | den | d en |
66
+ | deng | d eng |
67
+ | di | d i |
68
+ | dia | d ia |
69
+ | dian | d ian |
70
+ | diao | d iao |
71
+ | die | d ie |
72
+ | ding | d ing |
73
+ | diu | d iu |
74
+ | dong | d ong |
75
+ | dou | d ou |
76
+ | du | d u |
77
+ | duan | d uan |
78
+ | dui | d ui |
79
+ | dun | d un |
80
+ | duo | d uo |
81
+ | e | e |
82
+ | ei | ei |
83
+ | en | en |
84
+ | eng | eng |
85
+ | er | er |
86
+ | fa | f a |
87
+ | fan | f an |
88
+ | fang | f ang |
89
+ | fei | f ei |
90
+ | fen | f en |
91
+ | feng | f eng |
92
+ | fo | f o |
93
+ | fou | f ou |
94
+ | fu | f u |
95
+ | ga | g a |
96
+ | gai | g ai |
97
+ | gan | g an |
98
+ | gang | g ang |
99
+ | gao | g ao |
100
+ | ge | g e |
101
+ | gei | g ei |
102
+ | gen | g en |
103
+ | geng | g eng |
104
+ | gong | g ong |
105
+ | gou | g ou |
106
+ | gu | g u |
107
+ | gua | g ua |
108
+ | guai | g uai |
109
+ | guan | g uan |
110
+ | guang | g uang |
111
+ | gui | g ui |
112
+ | gun | g un |
113
+ | guo | g uo |
114
+ | ha | h a |
115
+ | hai | h ai |
116
+ | han | h an |
117
+ | hang | h ang |
118
+ | hao | h ao |
119
+ | he | h e |
120
+ | hei | h ei |
121
+ | hen | h en |
122
+ | heng | h eng |
123
+ | hm | h m |
124
+ | hng | h ng |
125
+ | hong | h ong |
126
+ | hou | h ou |
127
+ | hu | h u |
128
+ | hua | h ua |
129
+ | huai | h uai |
130
+ | huan | h uan |
131
+ | huang | h uang |
132
+ | hui | h ui |
133
+ | hun | h un |
134
+ | huo | h uo |
135
+ | ji | j i |
136
+ | jia | j ia |
137
+ | jian | j ian |
138
+ | jiang | j iang |
139
+ | jiao | j iao |
140
+ | jie | j ie |
141
+ | jin | j in |
142
+ | jing | j ing |
143
+ | jiong | j iong |
144
+ | jiu | j iu |
145
+ | ju | j v |
146
+ | juan | j van |
147
+ | jue | j ve |
148
+ | jun | j vn |
149
+ | ka | k a |
150
+ | kai | k ai |
151
+ | kan | k an |
152
+ | kang | k ang |
153
+ | kao | k ao |
154
+ | ke | k e |
155
+ | kei | k ei |
156
+ | ken | k en |
157
+ | keng | k eng |
158
+ | kong | k ong |
159
+ | kou | k ou |
160
+ | ku | k u |
161
+ | kua | k ua |
162
+ | kuai | k uai |
163
+ | kuan | k uan |
164
+ | kuang | k uang |
165
+ | kui | k ui |
166
+ | kun | k un |
167
+ | kuo | k uo |
168
+ | la | l a |
169
+ | lai | l ai |
170
+ | lan | l an |
171
+ | lang | l ang |
172
+ | lao | l ao |
173
+ | le | l e |
174
+ | lei | l ei |
175
+ | leng | l eng |
176
+ | li | l i |
177
+ | lia | l ia |
178
+ | lian | l ian |
179
+ | liang | l iang |
180
+ | liao | l iao |
181
+ | lie | l ie |
182
+ | lin | l in |
183
+ | ling | l ing |
184
+ | liu | l iu |
185
+ | lo | l o |
186
+ | long | l ong |
187
+ | lou | l ou |
188
+ | lu | l u |
189
+ | luan | l uan |
190
+ | lun | l un |
191
+ | luo | l uo |
192
+ | lv | l v |
193
+ | lve | l ve |
194
+ | m | m |
195
+ | ma | m a |
196
+ | mai | m ai |
197
+ | man | m an |
198
+ | mang | m ang |
199
+ | mao | m ao |
200
+ | me | m e |
201
+ | mei | m ei |
202
+ | men | m en |
203
+ | meng | m eng |
204
+ | mi | m i |
205
+ | mian | m ian |
206
+ | miao | m iao |
207
+ | mie | m ie |
208
+ | min | m in |
209
+ | ming | m ing |
210
+ | miu | m iu |
211
+ | mo | m o |
212
+ | mou | m ou |
213
+ | mu | m u |
214
+ | n | n |
215
+ | na | n a |
216
+ | nai | n ai |
217
+ | nan | n an |
218
+ | nang | n ang |
219
+ | nao | n ao |
220
+ | ne | n e |
221
+ | nei | n ei |
222
+ | nen | n en |
223
+ | neng | n eng |
224
+ | ng | n g |
225
+ | ni | n i |
226
+ | nian | n ian |
227
+ | niang | n iang |
228
+ | niao | n iao |
229
+ | nie | n ie |
230
+ | nin | n in |
231
+ | ning | n ing |
232
+ | niu | n iu |
233
+ | nong | n ong |
234
+ | nou | n ou |
235
+ | nu | n u |
236
+ | nuan | n uan |
237
+ | nun | n un |
238
+ | nuo | n uo |
239
+ | nv | n v |
240
+ | nve | n ve |
241
+ | o | o |
242
+ | ou | ou |
243
+ | pa | p a |
244
+ | pai | p ai |
245
+ | pan | p an |
246
+ | pang | p ang |
247
+ | pao | p ao |
248
+ | pei | p ei |
249
+ | pen | p en |
250
+ | peng | p eng |
251
+ | pi | p i |
252
+ | pian | p ian |
253
+ | piao | p iao |
254
+ | pie | p ie |
255
+ | pin | p in |
256
+ | ping | p ing |
257
+ | po | p o |
258
+ | pou | p ou |
259
+ | pu | p u |
260
+ | qi | q i |
261
+ | qia | q ia |
262
+ | qian | q ian |
263
+ | qiang | q iang |
264
+ | qiao | q iao |
265
+ | qie | q ie |
266
+ | qin | q in |
267
+ | qing | q ing |
268
+ | qiong | q iong |
269
+ | qiu | q iu |
270
+ | qu | q v |
271
+ | quan | q van |
272
+ | que | q ve |
273
+ | qun | q vn |
274
+ | ran | r an |
275
+ | rang | r ang |
276
+ | rao | r ao |
277
+ | re | r e |
278
+ | ren | r en |
279
+ | reng | r eng |
280
+ | ri | r i |
281
+ | rong | r ong |
282
+ | rou | r ou |
283
+ | ru | r u |
284
+ | rua | r ua |
285
+ | ruan | r uan |
286
+ | rui | r ui |
287
+ | run | r un |
288
+ | ruo | r uo |
289
+ | sa | s a |
290
+ | sai | s ai |
291
+ | san | s an |
292
+ | sang | s ang |
293
+ | sao | s ao |
294
+ | se | s e |
295
+ | sen | s en |
296
+ | seng | s eng |
297
+ | sha | sh a |
298
+ | shai | sh ai |
299
+ | shan | sh an |
300
+ | shang | sh ang |
301
+ | shao | sh ao |
302
+ | she | sh e |
303
+ | shei | sh ei |
304
+ | shen | sh en |
305
+ | sheng | sh eng |
306
+ | shi | sh i |
307
+ | shou | sh ou |
308
+ | shu | sh u |
309
+ | shua | sh ua |
310
+ | shuai | sh uai |
311
+ | shuan | sh uan |
312
+ | shuang | sh uang |
313
+ | shui | sh ui |
314
+ | shun | sh un |
315
+ | shuo | sh uo |
316
+ | si | s i |
317
+ | song | s ong |
318
+ | sou | s ou |
319
+ | su | s u |
320
+ | suan | s uan |
321
+ | sui | s ui |
322
+ | sun | s un |
323
+ | suo | s uo |
324
+ | ta | t a |
325
+ | tai | t ai |
326
+ | tan | t an |
327
+ | tang | t ang |
328
+ | tao | t ao |
329
+ | te | t e |
330
+ | tei | t ei |
331
+ | teng | t eng |
332
+ | ti | t i |
333
+ | tian | t ian |
334
+ | tiao | t iao |
335
+ | tie | t ie |
336
+ | ting | t ing |
337
+ | tong | t ong |
338
+ | tou | t ou |
339
+ | tu | t u |
340
+ | tuan | t uan |
341
+ | tui | t ui |
342
+ | tun | t un |
343
+ | tuo | t uo |
344
+ | wa | w a |
345
+ | wai | w ai |
346
+ | wan | w an |
347
+ | wang | w ang |
348
+ | wei | w ei |
349
+ | wen | w en |
350
+ | weng | w eng |
351
+ | wo | w o |
352
+ | wu | w u |
353
+ | xi | x i |
354
+ | xia | x ia |
355
+ | xian | x ian |
356
+ | xiang | x iang |
357
+ | xiao | x iao |
358
+ | xie | x ie |
359
+ | xin | x in |
360
+ | xing | x ing |
361
+ | xiong | x iong |
362
+ | xiu | x iu |
363
+ | xu | x v |
364
+ | xuan | x van |
365
+ | xue | x ve |
366
+ | xun | x vn |
367
+ | ya | y a |
368
+ | yan | y an |
369
+ | yang | y ang |
370
+ | yao | y ao |
371
+ | ye | y e |
372
+ | yi | y i |
373
+ | yin | y in |
374
+ | ying | y ing |
375
+ | yo | y o |
376
+ | yong | y ong |
377
+ | you | y ou |
378
+ | yu | y v |
379
+ | yuan | y van |
380
+ | yue | y ve |
381
+ | yun | y vn |
382
+ | za | z a |
383
+ | zai | z ai |
384
+ | zan | z an |
385
+ | zang | z ang |
386
+ | zao | z ao |
387
+ | ze | z e |
388
+ | zei | z ei |
389
+ | zen | z en |
390
+ | zeng | z eng |
391
+ | zha | zh a |
392
+ | zhai | zh ai |
393
+ | zhan | zh an |
394
+ | zhang | zh ang |
395
+ | zhao | zh ao |
396
+ | zhe | zh e |
397
+ | zhei | zh ei |
398
+ | zhen | zh en |
399
+ | zheng | zh eng |
400
+ | zhi | zh i |
401
+ | zhong | zh ong |
402
+ | zhou | zh ou |
403
+ | zhu | zh u |
404
+ | zhua | zh ua |
405
+ | zhuai | zh uai |
406
+ | zhuan | zh uan |
407
+ | zhuang | zh uang |
408
+ | zhui | zh ui |
409
+ | zhun | zh un |
410
+ | zhuo | zh uo |
411
+ | zi | z i |
412
+ | zong | z ong |
413
+ | zou | z ou |
414
+ | zu | z u |
415
+ | zuan | z uan |
416
+ | zui | z ui |
417
+ | zun | z un |
418
+ | zuo | z uo |
inference/svs/opencpop/map.py ADDED
@@ -0,0 +1,8 @@
1
+ def cpop_pinyin2ph_func():
2
+ # The Opencpop dataset's README defines a "pinyin to phoneme mapping table"; load it into a dict here.
3
+ pinyin2phs = {'AP': 'AP', 'SP': 'SP'}
4
+ with open('inference/svs/opencpop/cpop_pinyin2ph.txt') as rf:
5
+ for line in rf.readlines():
6
+ elements = [x.strip() for x in line.split('|') if x.strip() != '']
7
+ pinyin2phs[elements[0]] = elements[1]
8
+ return pinyin2phs
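+
+ # Usage sketch (assuming the repo root is the working directory, since the path above is relative):
+ # pinyin2phs = cpop_pinyin2ph_func()
+ # pinyin2phs['jie'] -> 'j ie'; pinyin2phs['ai'] -> 'ai'; pinyin2phs['AP'] -> 'AP'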
modules/__init__.py ADDED
File without changes
modules/commons/common_layers.py ADDED
@@ -0,0 +1,668 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import Parameter
5
+ import torch.onnx.operators
6
+ import torch.nn.functional as F
7
+ import utils
8
+
9
+
10
+ class Reshape(nn.Module):
11
+ def __init__(self, *args):
12
+ super(Reshape, self).__init__()
13
+ self.shape = args
14
+
15
+ def forward(self, x):
16
+ return x.view(self.shape)
17
+
18
+
19
+ class Permute(nn.Module):
20
+ def __init__(self, *args):
21
+ super(Permute, self).__init__()
22
+ self.args = args
23
+
24
+ def forward(self, x):
25
+ return x.permute(self.args)
26
+
27
+
28
+ class LinearNorm(torch.nn.Module):
29
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
30
+ super(LinearNorm, self).__init__()
31
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
32
+
33
+ torch.nn.init.xavier_uniform_(
34
+ self.linear_layer.weight,
35
+ gain=torch.nn.init.calculate_gain(w_init_gain))
36
+
37
+ def forward(self, x):
38
+ return self.linear_layer(x)
39
+
40
+
41
+ class ConvNorm(torch.nn.Module):
42
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
43
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
44
+ super(ConvNorm, self).__init__()
45
+ if padding is None:
46
+ assert (kernel_size % 2 == 1)
47
+ padding = int(dilation * (kernel_size - 1) / 2)
48
+
49
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
50
+ kernel_size=kernel_size, stride=stride,
51
+ padding=padding, dilation=dilation,
52
+ bias=bias)
53
+
54
+ torch.nn.init.xavier_uniform_(
55
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
56
+
57
+ def forward(self, signal):
58
+ conv_signal = self.conv(signal)
59
+ return conv_signal
60
+
61
+
62
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None):
63
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
64
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
65
+ if padding_idx is not None:
66
+ nn.init.constant_(m.weight[padding_idx], 0)
67
+ return m
68
+
69
+
70
+ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
71
+ if not export and torch.cuda.is_available():
72
+ try:
73
+ from apex.normalization import FusedLayerNorm
74
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
75
+ except ImportError:
76
+ pass
77
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
78
+
79
+
80
+ def Linear(in_features, out_features, bias=True):
81
+ m = nn.Linear(in_features, out_features, bias)
82
+ nn.init.xavier_uniform_(m.weight)
83
+ if bias:
84
+ nn.init.constant_(m.bias, 0.)
85
+ return m
86
+
87
+
88
+ class SinusoidalPositionalEmbedding(nn.Module):
89
+ """This module produces sinusoidal positional embeddings of any length.
90
+
91
+ Padding symbols are ignored.
92
+ """
93
+
94
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
95
+ super().__init__()
96
+ self.embedding_dim = embedding_dim
97
+ self.padding_idx = padding_idx
98
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
99
+ init_size,
100
+ embedding_dim,
101
+ padding_idx,
102
+ )
103
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
104
+
105
+ @staticmethod
106
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
107
+ """Build sinusoidal embeddings.
108
+
109
+ This matches the implementation in tensor2tensor, but differs slightly
110
+ from the description in Section 3.5 of "Attention Is All You Need".
111
+ """
112
+ half_dim = embedding_dim // 2
113
+ emb = math.log(10000) / (half_dim - 1)
114
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
115
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
116
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
117
+ if embedding_dim % 2 == 1:
118
+ # zero pad
119
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
120
+ if padding_idx is not None:
121
+ emb[padding_idx, :] = 0
122
+ return emb
123
+
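+ # Shape sketch (illustration): get_embedding(num_embeddings=4, embedding_dim=6, padding_idx=0)
+ # returns a [4, 6] tensor whose row p is [sin(p*w0), sin(p*w1), sin(p*w2), cos(p*w0), cos(p*w1), cos(p*w2)],
+ # with row 0 (the padding index) zeroed out.
+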
124
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
125
+ """Input is expected to be of size [bsz x seqlen]."""
126
+ bsz, seq_len = input.shape[:2]
127
+ max_pos = self.padding_idx + 1 + seq_len
128
+ if self.weights is None or max_pos > self.weights.size(0):
129
+ # recompute/expand embeddings if needed
130
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
131
+ max_pos,
132
+ self.embedding_dim,
133
+ self.padding_idx,
134
+ )
135
+ self.weights = self.weights.to(self._float_tensor)
136
+
137
+ if incremental_state is not None:
138
+ # positions is the same for every token when decoding a single step
139
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
140
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
141
+
142
+ positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
143
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
144
+
145
+ def max_positions(self):
146
+ """Maximum number of supported positions."""
147
+ return int(1e5) # an arbitrary large number
148
+
149
+
150
+ class ConvTBC(nn.Module):
151
+ def __init__(self, in_channels, out_channels, kernel_size, padding=0):
152
+ super(ConvTBC, self).__init__()
153
+ self.in_channels = in_channels
154
+ self.out_channels = out_channels
155
+ self.kernel_size = kernel_size
156
+ self.padding = padding
157
+
158
+ self.weight = torch.nn.Parameter(torch.Tensor(
159
+ self.kernel_size, in_channels, out_channels))
160
+ self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
161
+
162
+ def forward(self, input):
163
+ return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
164
+
165
+
166
+ class MultiheadAttention(nn.Module):
167
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
168
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
169
+ encoder_decoder_attention=False):
170
+ super().__init__()
171
+ self.embed_dim = embed_dim
172
+ self.kdim = kdim if kdim is not None else embed_dim
173
+ self.vdim = vdim if vdim is not None else embed_dim
174
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
175
+
176
+ self.num_heads = num_heads
177
+ self.dropout = dropout
178
+ self.head_dim = embed_dim // num_heads
179
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
180
+ self.scaling = self.head_dim ** -0.5
181
+
182
+ self.self_attention = self_attention
183
+ self.encoder_decoder_attention = encoder_decoder_attention
184
+
185
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
186
+ 'value to be of the same size'
187
+
188
+ if self.qkv_same_dim:
189
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
190
+ else:
191
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
192
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
193
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
194
+
195
+ if bias:
196
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
197
+ else:
198
+ self.register_parameter('in_proj_bias', None)
199
+
200
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
201
+
202
+ if add_bias_kv:
203
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
204
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
205
+ else:
206
+ self.bias_k = self.bias_v = None
207
+
208
+ self.add_zero_attn = add_zero_attn
209
+
210
+ self.reset_parameters()
211
+
212
+ self.enable_torch_version = False
213
+ if hasattr(F, "multi_head_attention_forward"):
214
+ self.enable_torch_version = True
215
+ else:
216
+ self.enable_torch_version = False
217
+ self.last_attn_probs = None
218
+
219
+ def reset_parameters(self):
220
+ if self.qkv_same_dim:
221
+ nn.init.xavier_uniform_(self.in_proj_weight)
222
+ else:
223
+ nn.init.xavier_uniform_(self.k_proj_weight)
224
+ nn.init.xavier_uniform_(self.v_proj_weight)
225
+ nn.init.xavier_uniform_(self.q_proj_weight)
226
+
227
+ nn.init.xavier_uniform_(self.out_proj.weight)
228
+ if self.in_proj_bias is not None:
229
+ nn.init.constant_(self.in_proj_bias, 0.)
230
+ nn.init.constant_(self.out_proj.bias, 0.)
231
+ if self.bias_k is not None:
232
+ nn.init.xavier_normal_(self.bias_k)
233
+ if self.bias_v is not None:
234
+ nn.init.xavier_normal_(self.bias_v)
235
+
236
+ def forward(
237
+ self,
238
+ query, key, value,
239
+ key_padding_mask=None,
240
+ incremental_state=None,
241
+ need_weights=True,
242
+ static_kv=False,
243
+ attn_mask=None,
244
+ before_softmax=False,
245
+ need_head_weights=False,
246
+ enc_dec_attn_constraint_mask=None,
247
+ reset_attn_weight=None
248
+ ):
249
+ """Input shape: Time x Batch x Channel
250
+
251
+ Args:
252
+ key_padding_mask (ByteTensor, optional): mask to exclude
253
+ keys that are pads, of shape `(batch, src_len)`, where
254
+ padding elements are indicated by 1s.
255
+ need_weights (bool, optional): return the attention weights,
256
+ averaged over heads (default: True).
257
+ attn_mask (ByteTensor, optional): typically used to
258
+ implement causal attention, where the mask prevents the
259
+ attention from looking forward in time (default: None).
260
+ before_softmax (bool, optional): return the raw attention
261
+ weights and values before the attention softmax.
262
+ need_head_weights (bool, optional): return the attention
263
+ weights for each head. Implies *need_weights*. Default:
264
+ return the average attention weights over all heads.
265
+ """
266
+ if need_head_weights:
267
+ need_weights = True
268
+
269
+ tgt_len, bsz, embed_dim = query.size()
270
+ assert embed_dim == self.embed_dim
271
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
272
+
273
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
274
+ if self.qkv_same_dim:
275
+ return F.multi_head_attention_forward(query, key, value,
276
+ self.embed_dim, self.num_heads,
277
+ self.in_proj_weight,
278
+ self.in_proj_bias, self.bias_k, self.bias_v,
279
+ self.add_zero_attn, self.dropout,
280
+ self.out_proj.weight, self.out_proj.bias,
281
+ self.training, key_padding_mask, need_weights,
282
+ attn_mask)
283
+ else:
284
+ return F.multi_head_attention_forward(query, key, value,
285
+ self.embed_dim, self.num_heads,
286
+ torch.empty([0]),
287
+ self.in_proj_bias, self.bias_k, self.bias_v,
288
+ self.add_zero_attn, self.dropout,
289
+ self.out_proj.weight, self.out_proj.bias,
290
+ self.training, key_padding_mask, need_weights,
291
+ attn_mask, use_separate_proj_weight=True,
292
+ q_proj_weight=self.q_proj_weight,
293
+ k_proj_weight=self.k_proj_weight,
294
+ v_proj_weight=self.v_proj_weight)
295
+
296
+ if incremental_state is not None:
297
+ print('Not implemented error.')
298
+ exit()
299
+ else:
300
+ saved_state = None
301
+
302
+ if self.self_attention:
303
+ # self-attention
304
+ q, k, v = self.in_proj_qkv(query)
305
+ elif self.encoder_decoder_attention:
306
+ # encoder-decoder attention
307
+ q = self.in_proj_q(query)
308
+ if key is None:
309
+ assert value is None
310
+ k = v = None
311
+ else:
312
+ k = self.in_proj_k(key)
313
+ v = self.in_proj_v(key)
314
+
315
+ else:
316
+ q = self.in_proj_q(query)
317
+ k = self.in_proj_k(key)
318
+ v = self.in_proj_v(value)
319
+ q *= self.scaling
320
+
321
+ if self.bias_k is not None:
322
+ assert self.bias_v is not None
323
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
324
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
325
+ if attn_mask is not None:
326
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
327
+ if key_padding_mask is not None:
328
+ key_padding_mask = torch.cat(
329
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
330
+
331
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
332
+ if k is not None:
333
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
334
+ if v is not None:
335
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
336
+
337
+ if saved_state is not None:
338
+ print('Not implemented error.')
339
+ exit()
340
+
341
+ src_len = k.size(1)
342
+
343
+ # This is part of a workaround to get around fork/join parallelism
344
+ # not supporting Optional types.
345
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
346
+ key_padding_mask = None
347
+
348
+ if key_padding_mask is not None:
349
+ assert key_padding_mask.size(0) == bsz
350
+ assert key_padding_mask.size(1) == src_len
351
+
352
+ if self.add_zero_attn:
353
+ src_len += 1
354
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
355
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
356
+ if attn_mask is not None:
357
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
358
+ if key_padding_mask is not None:
359
+ key_padding_mask = torch.cat(
360
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
361
+
362
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
363
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
364
+
365
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
366
+
367
+ if attn_mask is not None:
368
+ if len(attn_mask.shape) == 2:
369
+ attn_mask = attn_mask.unsqueeze(0)
370
+ elif len(attn_mask.shape) == 3:
371
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
372
+ bsz * self.num_heads, tgt_len, src_len)
373
+ attn_weights = attn_weights + attn_mask
374
+
375
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
376
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
377
+ attn_weights = attn_weights.masked_fill(
378
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
379
+ -1e9,
380
+ )
381
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
382
+
383
+ if key_padding_mask is not None:
384
+ # don't attend to padding symbols
385
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
386
+ attn_weights = attn_weights.masked_fill(
387
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
388
+ -1e9,
389
+ )
390
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
391
+
392
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
393
+
394
+ if before_softmax:
395
+ return attn_weights, v
396
+
397
+ attn_weights_float = utils.softmax(attn_weights, dim=-1)
398
+ attn_weights = attn_weights_float.type_as(attn_weights)
399
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
400
+
401
+ if reset_attn_weight is not None:
402
+ if reset_attn_weight:
403
+ self.last_attn_probs = attn_probs.detach()
404
+ else:
405
+ assert self.last_attn_probs is not None
406
+ attn_probs = self.last_attn_probs
407
+ attn = torch.bmm(attn_probs, v)
408
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
409
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
410
+ attn = self.out_proj(attn)
411
+
412
+ if need_weights:
413
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
414
+ if not need_head_weights:
415
+ # average attention weights over heads
416
+ attn_weights = attn_weights.mean(dim=0)
417
+ else:
418
+ attn_weights = None
419
+
420
+ return attn, (attn_weights, attn_logits)
421
+
422
+ def in_proj_qkv(self, query):
423
+ return self._in_proj(query).chunk(3, dim=-1)
424
+
425
+ def in_proj_q(self, query):
426
+ if self.qkv_same_dim:
427
+ return self._in_proj(query, end=self.embed_dim)
428
+ else:
429
+ bias = self.in_proj_bias
430
+ if bias is not None:
431
+ bias = bias[:self.embed_dim]
432
+ return F.linear(query, self.q_proj_weight, bias)
433
+
434
+ def in_proj_k(self, key):
435
+ if self.qkv_same_dim:
436
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
437
+ else:
438
+ weight = self.k_proj_weight
439
+ bias = self.in_proj_bias
440
+ if bias is not None:
441
+ bias = bias[self.embed_dim:2 * self.embed_dim]
442
+ return F.linear(key, weight, bias)
443
+
444
+ def in_proj_v(self, value):
445
+ if self.qkv_same_dim:
446
+ return self._in_proj(value, start=2 * self.embed_dim)
447
+ else:
448
+ weight = self.v_proj_weight
449
+ bias = self.in_proj_bias
450
+ if bias is not None:
451
+ bias = bias[2 * self.embed_dim:]
452
+ return F.linear(value, weight, bias)
453
+
454
+ def _in_proj(self, input, start=0, end=None):
455
+ weight = self.in_proj_weight
456
+ bias = self.in_proj_bias
457
+ weight = weight[start:end, :]
458
+ if bias is not None:
459
+ bias = bias[start:end]
460
+ return F.linear(input, weight, bias)
461
+
462
+
463
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
464
+ return attn_weights
465
+
466
+
467
+ class Swish(torch.autograd.Function):
468
+ @staticmethod
469
+ def forward(ctx, i):
470
+ result = i * torch.sigmoid(i)
471
+ ctx.save_for_backward(i)
472
+ return result
473
+
474
+ @staticmethod
475
+ def backward(ctx, grad_output):
476
+ i = ctx.saved_variables[0]
477
+ sigmoid_i = torch.sigmoid(i)
478
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
479
+
480
+
481
+ class CustomSwish(nn.Module):
482
+ def forward(self, input_tensor):
483
+ return Swish.apply(input_tensor)
484
+
485
+
486
+ class TransformerFFNLayer(nn.Module):
487
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
488
+ super().__init__()
489
+ self.kernel_size = kernel_size
490
+ self.dropout = dropout
491
+ self.act = act
492
+ if padding == 'SAME':
493
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
494
+ elif padding == 'LEFT':
495
+ self.ffn_1 = nn.Sequential(
496
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
497
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
498
+ )
499
+ self.ffn_2 = Linear(filter_size, hidden_size)
500
+ if self.act == 'swish':
501
+ self.swish_fn = CustomSwish()
502
+
503
+ def forward(self, x, incremental_state=None):
504
+ # x: T x B x C
505
+ if incremental_state is not None:
506
+ assert incremental_state is None, 'Nar-generation does not allow this.'
507
+ exit(1)
508
+
509
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
510
+ x = x * self.kernel_size ** -0.5
511
+
512
+ if incremental_state is not None:
513
+ x = x[-1:]
514
+ if self.act == 'gelu':
515
+ x = F.gelu(x)
516
+ if self.act == 'relu':
517
+ x = F.relu(x)
518
+ if self.act == 'swish':
519
+ x = self.swish_fn(x)
520
+ x = F.dropout(x, self.dropout, training=self.training)
521
+ x = self.ffn_2(x)
522
+ return x
523
+
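+ # Padding note (for reference): padding='SAME' pads kernel_size // 2 frames on both sides, so each
+ # output frame can see future frames; padding='LEFT' pads only on the left, giving a causal
+ # convolution, which is what DecSALayer below requests for its FFN.
+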
524
+
525
+ class BatchNorm1dTBC(nn.Module):
526
+ def __init__(self, c):
527
+ super(BatchNorm1dTBC, self).__init__()
528
+ self.bn = nn.BatchNorm1d(c)
529
+
530
+ def forward(self, x):
531
+ """
532
+
533
+ :param x: [T, B, C]
534
+ :return: [T, B, C]
535
+ """
536
+ x = x.permute(1, 2, 0) # [B, C, T]
537
+ x = self.bn(x) # [B, C, T]
538
+ x = x.permute(2, 0, 1) # [T, B, C]
539
+ return x
540
+
541
+
542
+ class EncSALayer(nn.Module):
543
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
544
+ relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
545
+ super().__init__()
546
+ self.c = c
547
+ self.dropout = dropout
548
+ self.num_heads = num_heads
549
+ if num_heads > 0:
550
+ if norm == 'ln':
551
+ self.layer_norm1 = LayerNorm(c)
552
+ elif norm == 'bn':
553
+ self.layer_norm1 = BatchNorm1dTBC(c)
554
+ self.self_attn = MultiheadAttention(
555
+ self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
556
+ )
557
+ if norm == 'ln':
558
+ self.layer_norm2 = LayerNorm(c)
559
+ elif norm == 'bn':
560
+ self.layer_norm2 = BatchNorm1dTBC(c)
561
+ self.ffn = TransformerFFNLayer(
562
+ c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
563
+
564
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
565
+ layer_norm_training = kwargs.get('layer_norm_training', None)
566
+ if layer_norm_training is not None:
567
+ self.layer_norm1.training = layer_norm_training
568
+ self.layer_norm2.training = layer_norm_training
569
+ if self.num_heads > 0:
570
+ residual = x
571
+ x = self.layer_norm1(x)
572
+ x, _, = self.self_attn(
573
+ query=x,
574
+ key=x,
575
+ value=x,
576
+ key_padding_mask=encoder_padding_mask
577
+ )
578
+ x = F.dropout(x, self.dropout, training=self.training)
579
+ x = residual + x
580
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
581
+
582
+ residual = x
583
+ x = self.layer_norm2(x)
584
+ x = self.ffn(x)
585
+ x = F.dropout(x, self.dropout, training=self.training)
586
+ x = residual + x
587
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
588
+ return x
589
+
590
+
591
+ class DecSALayer(nn.Module):
592
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
593
+ super().__init__()
594
+ self.c = c
595
+ self.dropout = dropout
596
+ self.layer_norm1 = LayerNorm(c)
597
+ self.self_attn = MultiheadAttention(
598
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
599
+ )
600
+ self.layer_norm2 = LayerNorm(c)
601
+ self.encoder_attn = MultiheadAttention(
602
+ c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
603
+ )
604
+ self.layer_norm3 = LayerNorm(c)
605
+ self.ffn = TransformerFFNLayer(
606
+ c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
607
+
608
+ def forward(
609
+ self,
610
+ x,
611
+ encoder_out=None,
612
+ encoder_padding_mask=None,
613
+ incremental_state=None,
614
+ self_attn_mask=None,
615
+ self_attn_padding_mask=None,
616
+ attn_out=None,
617
+ reset_attn_weight=None,
618
+ **kwargs,
619
+ ):
620
+ layer_norm_training = kwargs.get('layer_norm_training', None)
621
+ if layer_norm_training is not None:
622
+ self.layer_norm1.training = layer_norm_training
623
+ self.layer_norm2.training = layer_norm_training
624
+ self.layer_norm3.training = layer_norm_training
625
+ residual = x
626
+ x = self.layer_norm1(x)
627
+ x, _ = self.self_attn(
628
+ query=x,
629
+ key=x,
630
+ value=x,
631
+ key_padding_mask=self_attn_padding_mask,
632
+ incremental_state=incremental_state,
633
+ attn_mask=self_attn_mask
634
+ )
635
+ x = F.dropout(x, self.dropout, training=self.training)
636
+ x = residual + x
637
+
638
+ residual = x
639
+ x = self.layer_norm2(x)
640
+ if encoder_out is not None:
641
+ x, attn = self.encoder_attn(
642
+ query=x,
643
+ key=encoder_out,
644
+ value=encoder_out,
645
+ key_padding_mask=encoder_padding_mask,
646
+ incremental_state=incremental_state,
647
+ static_kv=True,
648
+ enc_dec_attn_constraint_mask=None, #utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'),
649
+ reset_attn_weight=reset_attn_weight
650
+ )
651
+ attn_logits = attn[1]
652
+ else:
653
+ assert attn_out is not None
654
+ x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
655
+ attn_logits = None
656
+ x = F.dropout(x, self.dropout, training=self.training)
657
+ x = residual + x
658
+
659
+ residual = x
660
+ x = self.layer_norm3(x)
661
+ x = self.ffn(x, incremental_state=incremental_state)
662
+ x = F.dropout(x, self.dropout, training=self.training)
663
+ x = residual + x
664
+ # if len(attn_logits.size()) > 3:
665
+ # indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
666
+ # attn_logits = attn_logits.gather(1,
667
+ # indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
668
+ return x, attn_logits
modules/commons/espnet_positional_embedding.py ADDED
@@ -0,0 +1,113 @@
1
+ import math
2
+ import torch
3
+
4
+
5
+ class PositionalEncoding(torch.nn.Module):
6
+ """Positional encoding.
7
+ Args:
8
+ d_model (int): Embedding dimension.
9
+ dropout_rate (float): Dropout rate.
10
+ max_len (int): Maximum input length.
11
+ reverse (bool): Whether to reverse the input position.
12
+ """
13
+
14
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
15
+ """Construct an PositionalEncoding object."""
16
+ super(PositionalEncoding, self).__init__()
17
+ self.d_model = d_model
18
+ self.reverse = reverse
19
+ self.xscale = math.sqrt(self.d_model)
20
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
21
+ self.pe = None
22
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
23
+
24
+ def extend_pe(self, x):
25
+ """Reset the positional encodings."""
26
+ if self.pe is not None:
27
+ if self.pe.size(1) >= x.size(1):
28
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
29
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
30
+ return
31
+ pe = torch.zeros(x.size(1), self.d_model)
32
+ if self.reverse:
33
+ position = torch.arange(
34
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
35
+ ).unsqueeze(1)
36
+ else:
37
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
38
+ div_term = torch.exp(
39
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
40
+ * -(math.log(10000.0) / self.d_model)
41
+ )
42
+ pe[:, 0::2] = torch.sin(position * div_term)
43
+ pe[:, 1::2] = torch.cos(position * div_term)
44
+ pe = pe.unsqueeze(0)
45
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ """Add positional encoding.
49
+ Args:
50
+ x (torch.Tensor): Input tensor (batch, time, `*`).
51
+ Returns:
52
+ torch.Tensor: Encoded tensor (batch, time, `*`).
53
+ """
54
+ self.extend_pe(x)
55
+ x = x * self.xscale + self.pe[:, : x.size(1)]
56
+ return self.dropout(x)
57
+
58
+
59
+ class ScaledPositionalEncoding(PositionalEncoding):
60
+ """Scaled positional encoding module.
61
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
62
+ Args:
63
+ d_model (int): Embedding dimension.
64
+ dropout_rate (float): Dropout rate.
65
+ max_len (int): Maximum input length.
66
+ """
67
+
68
+ def __init__(self, d_model, dropout_rate, max_len=5000):
69
+ """Initialize class."""
70
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
71
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
72
+
73
+ def reset_parameters(self):
74
+ """Reset parameters."""
75
+ self.alpha.data = torch.tensor(1.0)
76
+
77
+ def forward(self, x):
78
+ """Add positional encoding.
79
+ Args:
80
+ x (torch.Tensor): Input tensor (batch, time, `*`).
81
+ Returns:
82
+ torch.Tensor: Encoded tensor (batch, time, `*`).
83
+ """
84
+ self.extend_pe(x)
85
+ x = x + self.alpha * self.pe[:, : x.size(1)]
86
+ return self.dropout(x)
87
+
88
+
89
+ class RelPositionalEncoding(PositionalEncoding):
90
+ """Relative positional encoding module.
91
+ See : Appendix B in https://arxiv.org/abs/1901.02860
92
+ Args:
93
+ d_model (int): Embedding dimension.
94
+ dropout_rate (float): Dropout rate.
95
+ max_len (int): Maximum input length.
96
+ """
97
+
98
+ def __init__(self, d_model, dropout_rate, max_len=5000):
99
+ """Initialize class."""
100
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
101
+
102
+ def forward(self, x):
103
+ """Compute positional encoding.
104
+ Args:
105
+ x (torch.Tensor): Input tensor (batch, time, `*`).
106
+ Returns:
107
+ torch.Tensor: Encoded tensor (batch, time, `*`).
108
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
109
+ """
110
+ self.extend_pe(x)
111
+ x = x * self.xscale
112
+ pos_emb = self.pe[:, : x.size(1)]
113
+ return self.dropout(x) + self.dropout(pos_emb)
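Illustrative usage (not part of the committed file): a minimal sketch of how these absolute and relative encodings apply to a `[B, T, H]` hidden sequence, assuming the repository root is on the Python path; the tensor sizes below are arbitrary.

```python
import torch
from modules.commons.espnet_positional_embedding import PositionalEncoding, RelPositionalEncoding

x = torch.randn(2, 100, 256)                       # [B, T, H] dummy hidden states
abs_pe = PositionalEncoding(d_model=256, dropout_rate=0.1)
y_abs = abs_pe(x)                                  # x * sqrt(d_model) + absolute table, then dropout

rel_pe = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
y_rel = rel_pe(x)                                  # scaled x plus the relative table (built with reverse=True)
print(y_abs.shape, y_rel.shape)                    # torch.Size([2, 100, 256]) twice
```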
modules/commons/ssim.py ADDED
@@ -0,0 +1,391 @@
1
+ # '''
2
+ # https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
3
+ # '''
4
+ #
5
+ # import torch
6
+ # import torch.jit
7
+ # import torch.nn.functional as F
8
+ #
9
+ #
10
+ # @torch.jit.script
11
+ # def create_window(window_size: int, sigma: float, channel: int):
12
+ # '''
13
+ # Create 1-D gauss kernel
14
+ # :param window_size: the size of gauss kernel
15
+ # :param sigma: sigma of normal distribution
16
+ # :param channel: input channel
17
+ # :return: 1D kernel
18
+ # '''
19
+ # coords = torch.arange(window_size, dtype=torch.float)
20
+ # coords -= window_size // 2
21
+ #
22
+ # g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
23
+ # g /= g.sum()
24
+ #
25
+ # g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
26
+ # return g
27
+ #
28
+ #
29
+ # @torch.jit.script
30
+ # def _gaussian_filter(x, window_1d, use_padding: bool):
31
+ # '''
32
+ # Blur input with 1-D kernel
33
+ # :param x: batch of tensors to be blurred
34
+ # :param window_1d: 1-D gauss kernel
35
+ # :param use_padding: padding image before conv
36
+ # :return: blurred tensors
37
+ # '''
38
+ # C = x.shape[1]
39
+ # padding = 0
40
+ # if use_padding:
41
+ # window_size = window_1d.shape[3]
42
+ # padding = window_size // 2
43
+ # out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
44
+ # out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
45
+ # return out
46
+ #
47
+ #
48
+ # @torch.jit.script
49
+ # def ssim(X, Y, window, data_range: float, use_padding: bool = False):
50
+ # '''
51
+ # Calculate ssim index for X and Y
52
+ # :param X: images [B, C, H, N_bins]
53
+ # :param Y: images [B, C, H, N_bins]
54
+ # :param window: 1-D gauss kernel
55
+ # :param data_range: value range of input images. (usually 1.0 or 255)
56
+ # :param use_padding: padding image before conv
57
+ # :return:
58
+ # '''
59
+ #
60
+ # K1 = 0.01
61
+ # K2 = 0.03
62
+ # compensation = 1.0
63
+ #
64
+ # C1 = (K1 * data_range) ** 2
65
+ # C2 = (K2 * data_range) ** 2
66
+ #
67
+ # mu1 = _gaussian_filter(X, window, use_padding)
68
+ # mu2 = _gaussian_filter(Y, window, use_padding)
69
+ # sigma1_sq = _gaussian_filter(X * X, window, use_padding)
70
+ # sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
71
+ # sigma12 = _gaussian_filter(X * Y, window, use_padding)
72
+ #
73
+ # mu1_sq = mu1.pow(2)
74
+ # mu2_sq = mu2.pow(2)
75
+ # mu1_mu2 = mu1 * mu2
76
+ #
77
+ # sigma1_sq = compensation * (sigma1_sq - mu1_sq)
78
+ # sigma2_sq = compensation * (sigma2_sq - mu2_sq)
79
+ # sigma12 = compensation * (sigma12 - mu1_mu2)
80
+ #
81
+ # cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
82
+ # # Fix for the issue where negative cs_map values caused ms_ssim to output NaN.
83
+ # cs_map = cs_map.clamp_min(0.)
84
+ # ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
85
+ #
86
+ # ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW
87
+ # cs = cs_map.mean(dim=(1, 2, 3))
88
+ #
89
+ # return ssim_val, cs
90
+ #
91
+ #
92
+ # @torch.jit.script
93
+ # def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
94
+ # '''
95
+ # interface of ms-ssim
96
+ # :param X: a batch of images, (N,C,H,W)
97
+ # :param Y: a batch of images, (N,C,H,W)
98
+ # :param window: 1-D gauss kernel
99
+ # :param data_range: value range of input images. (usually 1.0 or 255)
100
+ # :param weights: weights for different levels
101
+ # :param use_padding: padding image before conv
102
+ # :param eps: used to avoid NaN gradients.
103
+ # :return:
104
+ # '''
105
+ # levels = weights.shape[0]
106
+ # cs_vals = []
107
+ # ssim_vals = []
108
+ # for _ in range(levels):
109
+ # ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
110
+ # # Used to fix an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
111
+ # ssim_val = ssim_val.clamp_min(eps)
112
+ # cs = cs.clamp_min(eps)
113
+ # cs_vals.append(cs)
114
+ #
115
+ # ssim_vals.append(ssim_val)
116
+ # padding = (X.shape[2] % 2, X.shape[3] % 2)
117
+ # X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
118
+ # Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
119
+ #
120
+ # cs_vals = torch.stack(cs_vals, dim=0)
121
+ # ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
122
+ # return ms_ssim_val
123
+ #
124
+ #
125
+ # class SSIM(torch.jit.ScriptModule):
126
+ # __constants__ = ['data_range', 'use_padding']
127
+ #
128
+ # def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
129
+ # '''
130
+ # :param window_size: the size of gauss kernel
131
+ # :param window_sigma: sigma of normal distribution
132
+ # :param data_range: value range of input images. (usually 1.0 or 255)
133
+ # :param channel: input channels (default: 3)
134
+ # :param use_padding: padding image before conv
135
+ # '''
136
+ # super().__init__()
137
+ # assert window_size % 2 == 1, 'Window size must be odd.'
138
+ # window = create_window(window_size, window_sigma, channel)
139
+ # self.register_buffer('window', window)
140
+ # self.data_range = data_range
141
+ # self.use_padding = use_padding
142
+ #
143
+ # @torch.jit.script_method
144
+ # def forward(self, X, Y):
145
+ # r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
146
+ # return r[0]
147
+ #
148
+ #
149
+ # class MS_SSIM(torch.jit.ScriptModule):
150
+ # __constants__ = ['data_range', 'use_padding', 'eps']
151
+ #
152
+ # def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
153
+ # levels=None, eps=1e-8):
154
+ # '''
155
+ # class for ms-ssim
156
+ # :param window_size: the size of gauss kernel
157
+ # :param window_sigma: sigma of normal distribution
158
+ # :param data_range: value range of input images. (usually 1.0 or 255)
159
+ # :param channel: input channels
160
+ # :param use_padding: padding image before conv
161
+ # :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
162
+ # :param levels: number of downsampling
163
+ # :param eps: Used to fix an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
164
+ # '''
165
+ # super().__init__()
166
+ # assert window_size % 2 == 1, 'Window size must be odd.'
167
+ # self.data_range = data_range
168
+ # self.use_padding = use_padding
169
+ # self.eps = eps
170
+ #
171
+ # window = create_window(window_size, window_sigma, channel)
172
+ # self.register_buffer('window', window)
173
+ #
174
+ # if weights is None:
175
+ # weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
176
+ # weights = torch.tensor(weights, dtype=torch.float)
177
+ #
178
+ # if levels is not None:
179
+ # weights = weights[:levels]
180
+ # weights = weights / weights.sum()
181
+ #
182
+ # self.register_buffer('weights', weights)
183
+ #
184
+ # @torch.jit.script_method
185
+ # def forward(self, X, Y):
186
+ # return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
187
+ # use_padding=self.use_padding, eps=self.eps)
188
+ #
189
+ #
190
+ # if __name__ == '__main__':
191
+ # print('Simple Test')
192
+ # im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
193
+ # img1 = im / 255
194
+ # img2 = img1 * 0.5
195
+ #
196
+ # losser = SSIM(data_range=1.).cuda()
197
+ # loss = losser(img1, img2).mean()
198
+ #
199
+ # losser2 = MS_SSIM(data_range=1.).cuda()
200
+ # loss2 = losser2(img1, img2).mean()
201
+ #
202
+ # print(loss.item())
203
+ # print(loss2.item())
204
+ #
205
+ # if __name__ == '__main__':
206
+ # print('Training Test')
207
+ # import cv2
208
+ # import torch.optim
209
+ # import numpy as np
210
+ # import imageio
211
+ # import time
212
+ #
213
+ # out_test_video = False
214
+ # # Better not to write a GIF directly (it gets very large); write an mkv file first, then convert it to GIF with ffmpeg
215
+ # video_use_gif = False
216
+ #
217
+ # im = cv2.imread('test_img1.jpg', 1)
218
+ # t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
219
+ #
220
+ # if out_test_video:
221
+ # if video_use_gif:
222
+ # fps = 0.5
223
+ # out_wh = (im.shape[1] // 2, im.shape[0] // 2)
224
+ # suffix = '.gif'
225
+ # else:
226
+ # fps = 5
227
+ # out_wh = (im.shape[1], im.shape[0])
228
+ # suffix = '.mkv'
229
+ # video_last_time = time.perf_counter()
230
+ # video = imageio.get_writer('ssim_test' + suffix, fps=fps)
231
+ #
232
+ # # Test SSIM
233
+ # print('Training SSIM')
234
+ # rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
235
+ # rand_im.requires_grad = True
236
+ # optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
237
+ # losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
238
+ # ssim_score = 0
239
+ # while ssim_score < 0.999:
240
+ # optim.zero_grad()
241
+ # loss = losser(rand_im, t_im)
242
+ # (-loss).sum().backward()
243
+ # ssim_score = loss.item()
244
+ # optim.step()
245
+ # r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
246
+ # r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
247
+ #
248
+ # if out_test_video:
249
+ # if time.perf_counter() - video_last_time > 1. / fps:
250
+ # video_last_time = time.perf_counter()
251
+ # out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
252
+ # out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
253
+ # if isinstance(out_frame, cv2.UMat):
254
+ # out_frame = out_frame.get()
255
+ # video.append_data(out_frame)
256
+ #
257
+ # cv2.imshow('ssim', r_im)
258
+ # cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
259
+ # cv2.waitKey(1)
260
+ #
261
+ # if out_test_video:
262
+ # video.close()
263
+ #
264
+ # # Test MS-SSIM
265
+ # if out_test_video:
266
+ # if video_use_gif:
267
+ # fps = 0.5
268
+ # out_wh = (im.shape[1] // 2, im.shape[0] // 2)
269
+ # suffix = '.gif'
270
+ # else:
271
+ # fps = 5
272
+ # out_wh = (im.shape[1], im.shape[0])
273
+ # suffix = '.mkv'
274
+ # video_last_time = time.perf_counter()
275
+ # video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
276
+ #
277
+ # print('Training MS_SSIM')
278
+ # rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
279
+ # rand_im.requires_grad = True
280
+ # optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
281
+ # losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
282
+ # ssim_score = 0
283
+ # while ssim_score < 0.999:
284
+ # optim.zero_grad()
285
+ # loss = losser(rand_im, t_im)
286
+ # (-loss).sum().backward()
287
+ # ssim_score = loss.item()
288
+ # optim.step()
289
+ # r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
290
+ # r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
291
+ #
292
+ # if out_test_video:
293
+ # if time.perf_counter() - video_last_time > 1. / fps:
294
+ # video_last_time = time.perf_counter()
295
+ # out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
296
+ # out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
297
+ # if isinstance(out_frame, cv2.UMat):
298
+ # out_frame = out_frame.get()
299
+ # video.append_data(out_frame)
300
+ #
301
+ # cv2.imshow('ms_ssim', r_im)
302
+ # cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
303
+ # cv2.waitKey(1)
304
+ #
305
+ # if out_test_video:
306
+ # video.close()
307
+
308
+ """
309
+ Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
310
+ """
311
+
312
+ import torch
313
+ import torch.nn.functional as F
314
+ from torch.autograd import Variable
315
+ import numpy as np
316
+ from math import exp
317
+
318
+
319
+ def gaussian(window_size, sigma):
320
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
321
+ return gauss / gauss.sum()
322
+
323
+
324
+ def create_window(window_size, channel):
325
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
326
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
327
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
328
+ return window
329
+
330
+
331
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
332
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
333
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
334
+
335
+ mu1_sq = mu1.pow(2)
336
+ mu2_sq = mu2.pow(2)
337
+ mu1_mu2 = mu1 * mu2
338
+
339
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
340
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
341
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
342
+
343
+ C1 = 0.01 ** 2
344
+ C2 = 0.03 ** 2
345
+
346
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
347
+
348
+ if size_average:
349
+ return ssim_map.mean()
350
+ else:
351
+ return ssim_map.mean(1)
352
+
353
+
354
+ class SSIM(torch.nn.Module):
355
+ def __init__(self, window_size=11, size_average=True):
356
+ super(SSIM, self).__init__()
357
+ self.window_size = window_size
358
+ self.size_average = size_average
359
+ self.channel = 1
360
+ self.window = create_window(window_size, self.channel)
361
+
362
+ def forward(self, img1, img2):
363
+ (_, channel, _, _) = img1.size()
364
+
365
+ if channel == self.channel and self.window.data.type() == img1.data.type():
366
+ window = self.window
367
+ else:
368
+ window = create_window(self.window_size, channel)
369
+
370
+ if img1.is_cuda:
371
+ window = window.cuda(img1.get_device())
372
+ window = window.type_as(img1)
373
+
374
+ self.window = window
375
+ self.channel = channel
376
+
377
+ return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
378
+
379
+
380
+ window = None
381
+
382
+
383
+ def ssim(img1, img2, window_size=11, size_average=True):
384
+ (_, channel, _, _) = img1.size()
385
+ global window
386
+ if window is None:
387
+ window = create_window(window_size, channel)
388
+ if img1.is_cuda:
389
+ window = window.cuda(img1.get_device())
390
+ window = window.type_as(img1)
391
+ return _ssim(img1, img2, window, window_size, channel, size_average)
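Illustrative usage (not part of the committed file): the active, non-commented implementation computes SSIM with a fixed 11x11 Gaussian window. A minimal sketch on mel-spectrogram batches treated as single-channel images, assuming the values already lie in [0, 1]:

```python
import torch
from modules.commons.ssim import ssim

mel_a = torch.rand(4, 1, 80, 200)                  # [B, C=1, n_mels, frames], values in [0, 1]
mel_b = (mel_a + 0.01 * torch.randn_like(mel_a)).clamp(0, 1)
score = ssim(mel_a, mel_b, window_size=11)         # scalar, since size_average=True by default
print(score.item())                                # near 1.0 for almost-identical inputs
```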
modules/diffsinger_midi/fs2.py ADDED
@@ -0,0 +1,118 @@
1
+ from modules.commons.common_layers import *
2
+ from modules.commons.common_layers import Embedding
3
+ from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
4
+ EnergyPredictor, FastspeechEncoder
5
+ from utils.cwt import cwt2f0
6
+ from utils.hparams import hparams
7
+ from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
8
+ from modules.fastspeech.fs2 import FastSpeech2
9
+
10
+
11
+ class FastspeechMIDIEncoder(FastspeechEncoder):
12
+ def forward_embedding(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
13
+ # embed tokens and positions
14
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
15
+ x = x + midi_embedding + midi_dur_embedding + slur_embedding
16
+ if hparams['use_pos_embed']:
17
+ if hparams.get('rel_pos') is not None and hparams['rel_pos']:
18
+ x = self.embed_positions(x)
19
+ else:
20
+ positions = self.embed_positions(txt_tokens)
21
+ x = x + positions
22
+ x = F.dropout(x, p=self.dropout, training=self.training)
23
+ return x
24
+
25
+ def forward(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
26
+ """
27
+
28
+ :param txt_tokens: [B, T]
29
+ :return: {
30
+ 'encoder_out': [T x B x C]
31
+ }
32
+ """
33
+ encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
34
+ x = self.forward_embedding(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding) # [B, T, H]
35
+ x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
36
+ return x
37
+
38
+
39
+ FS_ENCODERS = {
40
+ 'fft': lambda hp, embed_tokens, d: FastspeechMIDIEncoder(
41
+ embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
42
+ num_heads=hp['num_heads']),
43
+ }
44
+
45
+
46
+ class FastSpeech2MIDI(FastSpeech2):
47
+ def __init__(self, dictionary, out_dims=None):
48
+ super().__init__(dictionary, out_dims)
49
+ del self.encoder
50
+ self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
51
+ self.midi_embed = Embedding(300, self.hidden_size, self.padding_idx)
52
+ self.midi_dur_layer = Linear(1, self.hidden_size)
53
+ self.is_slur_embed = Embedding(2, self.hidden_size)
54
+
55
+ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
56
+ ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
57
+ spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
58
+ ret = {}
59
+
60
+ midi_embedding = self.midi_embed(kwargs['pitch_midi'])
61
+ midi_dur_embedding, slur_embedding = 0, 0
62
+ if kwargs.get('midi_dur') is not None:
63
+ midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None]) # [B, T, 1] -> [B, T, H]
64
+ if kwargs.get('is_slur') is not None:
65
+ slur_embedding = self.is_slur_embed(kwargs['is_slur'])
66
+ encoder_out = self.encoder(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding) # [B, T, C]
67
+ src_nonpadding = (txt_tokens > 0).float()[:, :, None]
68
+
69
+ # add ref style embed
70
+ # Not implemented
71
+ # variance encoder
72
+ var_embed = 0
73
+
74
+ # encoder_out_dur denotes encoder outputs for duration predictor
75
+ # in speech adaptation, duration predictor use old speaker embedding
76
+ if hparams['use_spk_embed']:
77
+ spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
78
+ elif hparams['use_spk_id']:
79
+ spk_embed_id = spk_embed
80
+ if spk_embed_dur_id is None:
81
+ spk_embed_dur_id = spk_embed_id
82
+ if spk_embed_f0_id is None:
83
+ spk_embed_f0_id = spk_embed_id
84
+ spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
85
+ spk_embed_dur = spk_embed_f0 = spk_embed
86
+ if hparams['use_split_spk_id']:
87
+ spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
88
+ spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
89
+ else:
90
+ spk_embed_dur = spk_embed_f0 = spk_embed = 0
91
+
92
+ # add dur
93
+ dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
94
+
95
+ mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
96
+
97
+ decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
98
+
99
+ mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
100
+ decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
101
+
102
+ tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
103
+
104
+ # add pitch and energy embed
105
+ pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
106
+ if hparams['use_pitch_embed']:
107
+ pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
108
+ decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
109
+ if hparams['use_energy_embed']:
110
+ decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
111
+
112
+ ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
113
+
114
+ if skip_decoder:
115
+ return ret
116
+ ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
117
+
118
+ return ret
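Illustrative only (not part of the committed file): the sketch below mirrors how FastspeechMIDIEncoder folds the score conditioning (MIDI pitch, note duration, slur flag) into the scaled phoneme embeddings, using standalone dummy modules and made-up sizes instead of the hparams-driven ones above.

```python
import math
import torch
import torch.nn as nn

B, T, H = 2, 16, 256
embed_tokens = nn.Embedding(64, H, padding_idx=0)    # dummy phoneme embedding table
midi_embed = nn.Embedding(300, H, padding_idx=0)     # MIDI pitch id per phoneme
midi_dur_layer = nn.Linear(1, H)                     # note duration (seconds) -> H
is_slur_embed = nn.Embedding(2, H)                   # 0/1 slur flag

txt_tokens = torch.randint(1, 64, (B, T))
pitch_midi = torch.randint(40, 90, (B, T))
midi_dur = torch.rand(B, T)
is_slur = torch.randint(0, 2, (B, T))

x = math.sqrt(H) * embed_tokens(txt_tokens)          # embed_scale * token embedding
x = x + midi_embed(pitch_midi) \
      + midi_dur_layer(midi_dur[:, :, None]) \
      + is_slur_embed(is_slur)
print(x.shape)                                       # torch.Size([2, 16, 256]), the encoder's FFT-block input
```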
modules/fastspeech/fs2.py ADDED
@@ -0,0 +1,255 @@
1
+ from modules.commons.common_layers import *
2
+ from modules.commons.common_layers import Embedding
3
+ from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
4
+ EnergyPredictor, FastspeechEncoder
5
+ from utils.cwt import cwt2f0
6
+ from utils.hparams import hparams
7
+ from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
8
+
9
+ FS_ENCODERS = {
10
+ 'fft': lambda hp, embed_tokens, d: FastspeechEncoder(
11
+ embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
12
+ num_heads=hp['num_heads']),
13
+ }
14
+
15
+ FS_DECODERS = {
16
+ 'fft': lambda hp: FastspeechDecoder(
17
+ hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
18
+ }
19
+
20
+
21
+ class FastSpeech2(nn.Module):
22
+ def __init__(self, dictionary, out_dims=None):
23
+ super().__init__()
24
+ self.dictionary = dictionary
25
+ self.padding_idx = dictionary.pad()
26
+ self.enc_layers = hparams['enc_layers']
27
+ self.dec_layers = hparams['dec_layers']
28
+ self.hidden_size = hparams['hidden_size']
29
+ self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
30
+ self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
31
+ self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
32
+ self.out_dims = out_dims
33
+ if out_dims is None:
34
+ self.out_dims = hparams['audio_num_mel_bins']
35
+ self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
36
+
37
+ if hparams['use_spk_id']:
38
+ self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
39
+ if hparams['use_split_spk_id']:
40
+ self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
41
+ self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
42
+ elif hparams['use_spk_embed']:
43
+ self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
44
+ predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
45
+ self.dur_predictor = DurationPredictor(
46
+ self.hidden_size,
47
+ n_chans=predictor_hidden,
48
+ n_layers=hparams['dur_predictor_layers'],
49
+ dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
50
+ kernel_size=hparams['dur_predictor_kernel'])
51
+ self.length_regulator = LengthRegulator()
52
+ if hparams['use_pitch_embed']:
53
+ self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
54
+ if hparams['pitch_type'] == 'cwt':
55
+ h = hparams['cwt_hidden_size']
56
+ cwt_out_dims = 10
57
+ if hparams['use_uv']:
58
+ cwt_out_dims = cwt_out_dims + 1
59
+ self.cwt_predictor = nn.Sequential(
60
+ nn.Linear(self.hidden_size, h),
61
+ PitchPredictor(
62
+ h,
63
+ n_chans=predictor_hidden,
64
+ n_layers=hparams['predictor_layers'],
65
+ dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
66
+ padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
67
+ self.cwt_stats_layers = nn.Sequential(
68
+ nn.Linear(self.hidden_size, h), nn.ReLU(),
69
+ nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
70
+ )
71
+ else:
72
+ self.pitch_predictor = PitchPredictor(
73
+ self.hidden_size,
74
+ n_chans=predictor_hidden,
75
+ n_layers=hparams['predictor_layers'],
76
+ dropout_rate=hparams['predictor_dropout'],
77
+ odim=2 if hparams['pitch_type'] == 'frame' else 1,
78
+ padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
79
+ if hparams['use_energy_embed']:
80
+ self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
81
+ self.energy_predictor = EnergyPredictor(
82
+ self.hidden_size,
83
+ n_chans=predictor_hidden,
84
+ n_layers=hparams['predictor_layers'],
85
+ dropout_rate=hparams['predictor_dropout'], odim=1,
86
+ padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
87
+
88
+ def build_embedding(self, dictionary, embed_dim):
89
+ num_embeddings = len(dictionary)
90
+ emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
91
+ return emb
92
+
93
+ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
94
+ ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
95
+ spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
96
+ ret = {}
97
+ encoder_out = self.encoder(txt_tokens) # [B, T, C]
98
+ src_nonpadding = (txt_tokens > 0).float()[:, :, None]
99
+
100
+ # add ref style embed
101
+ # Not implemented
102
+ # variance encoder
103
+ var_embed = 0
104
+
105
+ # encoder_out_dur denotes encoder outputs for duration predictor
106
+ # in speech adaptation, duration predictor use old speaker embedding
107
+ if hparams['use_spk_embed']:
108
+ spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
109
+ elif hparams['use_spk_id']:
110
+ spk_embed_id = spk_embed
111
+ if spk_embed_dur_id is None:
112
+ spk_embed_dur_id = spk_embed_id
113
+ if spk_embed_f0_id is None:
114
+ spk_embed_f0_id = spk_embed_id
115
+ spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
116
+ spk_embed_dur = spk_embed_f0 = spk_embed
117
+ if hparams['use_split_spk_id']:
118
+ spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
119
+ spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
120
+ else:
121
+ spk_embed_dur = spk_embed_f0 = spk_embed = 0
122
+
123
+ # add dur
124
+ dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
125
+
126
+ mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
127
+
128
+ decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
129
+
130
+ mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
131
+ decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
132
+
133
+ tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
134
+
135
+ # add pitch and energy embed
136
+ pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
137
+ if hparams['use_pitch_embed']:
138
+ pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
139
+ decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
140
+ if hparams['use_energy_embed']:
141
+ decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
142
+
143
+ ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
144
+
145
+ if skip_decoder:
146
+ return ret
147
+ ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
148
+
149
+ return ret
150
+
151
+ def add_dur(self, dur_input, mel2ph, txt_tokens, ret):
152
+ """
153
+
154
+ :param dur_input: [B, T_txt, H]
155
+ :param mel2ph: [B, T_mel]
156
+ :param txt_tokens: [B, T_txt]
157
+ :param ret:
158
+ :return:
159
+ """
160
+ src_padding = txt_tokens == 0
161
+ dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
162
+ if mel2ph is None:
163
+ dur, xs = self.dur_predictor.inference(dur_input, src_padding)
164
+ ret['dur'] = xs
165
+ ret['dur_choice'] = dur
166
+ mel2ph = self.length_regulator(dur, src_padding).detach()
167
+ # from modules.fastspeech.fake_modules import FakeLengthRegulator
168
+ # fake_lr = FakeLengthRegulator()
169
+ # fake_mel2ph = fake_lr(dur, (1 - src_padding.long()).sum(-1))[..., 0].detach()
170
+ # print(mel2ph == fake_mel2ph)
171
+ else:
172
+ ret['dur'] = self.dur_predictor(dur_input, src_padding)
173
+ ret['mel2ph'] = mel2ph
174
+ return mel2ph
175
+
176
+ def add_energy(self, decoder_inp, energy, ret):
177
+ decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
178
+ ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
179
+ if energy is None:
180
+ energy = energy_pred
181
+ energy = torch.clamp(energy * 256 // 4, max=255).long()
182
+ energy_embed = self.energy_embed(energy)
183
+ return energy_embed
184
+
185
+ def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
186
+ if hparams['pitch_type'] == 'ph':
187
+ pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
188
+ pitch_padding = encoder_out.sum().abs() == 0
189
+ ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
190
+ if f0 is None:
191
+ f0 = pitch_pred[:, :, 0]
192
+ ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
193
+ pitch = f0_to_coarse(f0_denorm) # start from 0 [B, T_txt]
194
+ pitch = F.pad(pitch, [1, 0])
195
+ pitch = torch.gather(pitch, 1, mel2ph) # [B, T_mel]
196
+ pitch_embed = self.pitch_embed(pitch)
197
+ return pitch_embed
198
+ decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
199
+
200
+ pitch_padding = mel2ph == 0
201
+
202
+ if hparams['pitch_type'] == 'cwt':
203
+ pitch_padding = None
204
+ ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
205
+ stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2]
206
+ mean = ret['f0_mean'] = stats_out[:, 0]
207
+ std = ret['f0_std'] = stats_out[:, 1]
208
+ cwt_spec = cwt_out[:, :, :10]
209
+ if f0 is None:
210
+ std = std * hparams['cwt_std_scale']
211
+ f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
212
+ if hparams['use_uv']:
213
+ assert cwt_out.shape[-1] == 11
214
+ uv = cwt_out[:, :, -1] > 0
215
+ elif hparams['pitch_ar']:
216
+ ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
217
+ if f0 is None:
218
+ f0 = pitch_pred[:, :, 0]
219
+ else:
220
+ ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
221
+ if f0 is None:
222
+ f0 = pitch_pred[:, :, 0]
223
+ if hparams['use_uv'] and uv is None:
224
+ uv = pitch_pred[:, :, 1] > 0
225
+ ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
226
+ if pitch_padding is not None:
227
+ f0[pitch_padding] = 0
228
+
229
+ pitch = f0_to_coarse(f0_denorm) # start from 0
230
+ pitch_embed = self.pitch_embed(pitch)
231
+ return pitch_embed
232
+
233
+ def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
234
+ x = decoder_inp # [B, T, H]
235
+ x = self.decoder(x)
236
+ x = self.mel_out(x)
237
+ return x * tgt_nonpadding
238
+
239
+ def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
240
+ f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
241
+ f0 = torch.cat(
242
+ [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
243
+ f0_norm = norm_f0(f0, None, hparams)
244
+ return f0_norm
245
+
246
+ def out2mel(self, out):
247
+ return out
248
+
249
+ @staticmethod
250
+ def mel_norm(x):
251
+ return (x + 5.5) / (6.3 / 2) - 1
252
+
253
+ @staticmethod
254
+ def mel_denorm(x):
255
+ return (x + 1) * (6.3 / 2) - 5.5
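A note on the recurring `x.detach() + hparams['predictor_grad'] * (x - x.detach())` pattern in `add_dur`, `add_pitch` and `add_energy`: it leaves the forward value untouched but scales the gradient that the duration/pitch/energy predictors send back into the encoder. A minimal sketch (illustrative only, with a hard-coded 0.1 standing in for the hparam):

```python
import torch

predictor_grad = 0.1                        # stand-in for hparams['predictor_grad']
x = torch.ones(3, requires_grad=True)
scaled = x.detach() + predictor_grad * (x - x.detach())  # same values as x in the forward pass
scaled.sum().backward()
print(x.grad)                               # tensor([0.1000, 0.1000, 0.1000]): gradient scaled by 0.1
```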
modules/fastspeech/pe.py ADDED
@@ -0,0 +1,149 @@
1
+ from modules.commons.common_layers import *
2
+ from utils.hparams import hparams
3
+ from modules.fastspeech.tts_modules import PitchPredictor
4
+ from utils.pitch_utils import denorm_f0
5
+
6
+
7
+ class Prenet(nn.Module):
8
+ def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
9
+ super(Prenet, self).__init__()
10
+ padding = kernel // 2
11
+ self.layers = []
12
+ self.strides = strides if strides is not None else [1] * n_layers
13
+ for l in range(n_layers):
14
+ self.layers.append(nn.Sequential(
15
+ nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
16
+ nn.ReLU(),
17
+ nn.BatchNorm1d(out_dim)
18
+ ))
19
+ in_dim = out_dim
20
+ self.layers = nn.ModuleList(self.layers)
21
+ self.out_proj = nn.Linear(out_dim, out_dim)
22
+
23
+ def forward(self, x):
24
+ """
25
+
26
+ :param x: [B, T, 80]
27
+ :return: [L, B, T, H], [B, T, H]
28
+ """
29
+ padding_mask = x.abs().sum(-1).eq(0).data # [B, T]
30
+ nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :] # [B, 1, T]
31
+ x = x.transpose(1, 2)
32
+ hiddens = []
33
+ for i, l in enumerate(self.layers):
34
+ nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]
35
+ x = l(x) * nonpadding_mask_TB
36
+ hiddens.append(x)
37
+ hiddens = torch.stack(hiddens, 0) # [L, B, H, T]
38
+ hiddens = hiddens.transpose(2, 3) # [L, B, T, H]
39
+ x = self.out_proj(x.transpose(1, 2)) # [B, T, H]
40
+ x = x * nonpadding_mask_TB.transpose(1, 2)
41
+ return hiddens, x
42
+
43
+
44
+ class ConvBlock(nn.Module):
45
+ def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
46
+ super().__init__()
47
+ self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
48
+ self.norm = norm
49
+ if self.norm == 'bn':
50
+ self.norm = nn.BatchNorm1d(n_chans)
51
+ elif self.norm == 'in':
52
+ self.norm = nn.InstanceNorm1d(n_chans, affine=True)
53
+ elif self.norm == 'gn':
54
+ self.norm = nn.GroupNorm(n_chans // 16, n_chans)
55
+ elif self.norm == 'ln':
56
+ self.norm = LayerNorm(n_chans // 16, n_chans)
57
+ elif self.norm == 'wn':
58
+ self.conv = torch.nn.utils.weight_norm(self.conv.conv)
59
+ self.dropout = nn.Dropout(dropout)
60
+ self.relu = nn.ReLU()
61
+
62
+ def forward(self, x):
63
+ """
64
+
65
+ :param x: [B, C, T]
66
+ :return: [B, C, T]
67
+ """
68
+ x = self.conv(x)
69
+ if not isinstance(self.norm, str):
70
+ if self.norm == 'none':
71
+ pass
72
+ elif self.norm == 'ln':
73
+ x = self.norm(x.transpose(1, 2)).transpose(1, 2)
74
+ else:
75
+ x = self.norm(x)
76
+ x = self.relu(x)
77
+ x = self.dropout(x)
78
+ return x
79
+
80
+
81
+ class ConvStacks(nn.Module):
82
+ def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
83
+ dropout=0, strides=None, res=True):
84
+ super().__init__()
85
+ self.conv = torch.nn.ModuleList()
86
+ self.kernel_size = kernel_size
87
+ self.res = res
88
+ self.in_proj = Linear(idim, n_chans)
89
+ if strides is None:
90
+ strides = [1] * n_layers
91
+ else:
92
+ assert len(strides) == n_layers
93
+ for idx in range(n_layers):
94
+ self.conv.append(ConvBlock(
95
+ n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
96
+ self.out_proj = Linear(n_chans, odim)
97
+
98
+ def forward(self, x, return_hiddens=False):
99
+ """
100
+
101
+ :param x: [B, T, H]
102
+ :return: [B, T, H]
103
+ """
104
+ x = self.in_proj(x)
105
+ x = x.transpose(1, -1) # (B, idim, Tmax)
106
+ hiddens = []
107
+ for f in self.conv:
108
+ x_ = f(x)
109
+ x = x + x_ if self.res else x_ # (B, C, Tmax)
110
+ hiddens.append(x)
111
+ x = x.transpose(1, -1)
112
+ x = self.out_proj(x) # (B, Tmax, H)
113
+ if return_hiddens:
114
+ hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
115
+ return x, hiddens
116
+ return x
117
+
118
+
119
+ class PitchExtractor(nn.Module):
120
+ def __init__(self, n_mel_bins=80, conv_layers=2):
121
+ super().__init__()
122
+ self.hidden_size = hparams['hidden_size']
123
+ self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
124
+ self.conv_layers = conv_layers
125
+
126
+ self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1])
127
+ if self.conv_layers > 0:
128
+ self.mel_encoder = ConvStacks(
129
+ idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers)
130
+ self.pitch_predictor = PitchPredictor(
131
+ self.hidden_size, n_chans=self.predictor_hidden,
132
+ n_layers=5, dropout_rate=0.1, odim=2,
133
+ padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
134
+
135
+ def forward(self, mel_input=None):
136
+ ret = {}
137
+ mel_hidden = self.mel_prenet(mel_input)[1]
138
+ if self.conv_layers > 0:
139
+ mel_hidden = self.mel_encoder(mel_hidden)
140
+
141
+ ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden)
142
+
143
+ pitch_padding = mel_input.abs().sum(-1) == 0
144
+ use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv']
145
+
146
+ ret['f0_denorm_pred'] = denorm_f0(
147
+ pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
148
+ hparams, pitch_padding=pitch_padding)
149
+ return ret
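Illustrative only (not part of the committed file): a shape-flow sketch of the mel prenet and convolutional stack that `PitchExtractor` builds internally, instantiated here with explicit sizes instead of `hparams` and assuming the repository root is on the Python path.

```python
import torch
from modules.fastspeech.pe import Prenet, ConvStacks

mel = torch.randn(2, 200, 80)                        # [B, T, n_mels]
prenet = Prenet(in_dim=80, out_dim=256, strides=[1, 1, 1])
hiddens, h = prenet(mel)                             # [L, B, T, H], [B, T, H]
encoder = ConvStacks(idim=256, n_chans=256, odim=256, n_layers=2)
h = encoder(h)                                       # [B, T, H]
print(hiddens.shape, h.shape)                        # torch.Size([3, 2, 200, 256]) torch.Size([2, 200, 256])
```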
modules/fastspeech/tts_modules.py ADDED
@@ -0,0 +1,357 @@
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ from modules.commons.espnet_positional_embedding import RelPositionalEncoding
9
+ from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
10
+ from utils.hparams import hparams
11
+
12
+ DEFAULT_MAX_SOURCE_POSITIONS = 2000
13
+ DEFAULT_MAX_TARGET_POSITIONS = 2000
14
+
15
+
16
+ class TransformerEncoderLayer(nn.Module):
17
+ def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
18
+ super().__init__()
19
+ self.hidden_size = hidden_size
20
+ self.dropout = dropout
21
+ self.num_heads = num_heads
22
+ self.op = EncSALayer(
23
+ hidden_size, num_heads, dropout=dropout,
24
+ attention_dropout=0.0, relu_dropout=dropout,
25
+ kernel_size=kernel_size
26
+ if kernel_size is not None else hparams['enc_ffn_kernel_size'],
27
+ padding=hparams['ffn_padding'],
28
+ norm=norm, act=hparams['ffn_act'])
29
+
30
+ def forward(self, x, **kwargs):
31
+ return self.op(x, **kwargs)
32
+
33
+
34
+ ######################
35
+ # fastspeech modules
36
+ ######################
37
+ class LayerNorm(torch.nn.LayerNorm):
38
+ """Layer normalization module.
39
+ :param int nout: output dim size
40
+ :param int dim: dimension to be normalized
41
+ """
42
+
43
+ def __init__(self, nout, dim=-1):
44
+ """Construct an LayerNorm object."""
45
+ super(LayerNorm, self).__init__(nout, eps=1e-12)
46
+ self.dim = dim
47
+
48
+ def forward(self, x):
49
+ """Apply layer normalization.
50
+ :param torch.Tensor x: input tensor
51
+ :return: layer normalized tensor
52
+ :rtype torch.Tensor
53
+ """
54
+ if self.dim == -1:
55
+ return super(LayerNorm, self).forward(x)
56
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
57
+
58
+
59
+ class DurationPredictor(torch.nn.Module):
60
+ """Duration predictor module.
61
+ This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
62
+ The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
63
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
64
+ https://arxiv.org/pdf/1905.09263.pdf
65
+ Note:
66
+ The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`,
67
+ the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
68
+ """
69
+
70
+ def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
71
+ """Initilize duration predictor module.
72
+ Args:
73
+ idim (int): Input dimension.
74
+ n_layers (int, optional): Number of convolutional layers.
75
+ n_chans (int, optional): Number of channels of convolutional layers.
76
+ kernel_size (int, optional): Kernel size of convolutional layers.
77
+ dropout_rate (float, optional): Dropout rate.
78
+ offset (float, optional): Offset value to avoid nan in log domain.
79
+ """
80
+ super(DurationPredictor, self).__init__()
81
+ self.offset = offset
82
+ self.conv = torch.nn.ModuleList()
83
+ self.kernel_size = kernel_size
84
+ self.padding = padding
85
+ for idx in range(n_layers):
86
+ in_chans = idim if idx == 0 else n_chans
87
+ self.conv += [torch.nn.Sequential(
88
+ torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
89
+ if padding == 'SAME'
90
+ else (kernel_size - 1, 0), 0),
91
+ torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
92
+ torch.nn.ReLU(),
93
+ LayerNorm(n_chans, dim=1),
94
+ torch.nn.Dropout(dropout_rate)
95
+ )]
96
+ if hparams['dur_loss'] in ['mse', 'huber']:
97
+ odims = 1
98
+ elif hparams['dur_loss'] == 'mog':
99
+ odims = 15
100
+ elif hparams['dur_loss'] == 'crf':
101
+ odims = 32
102
+ from torchcrf import CRF
103
+ self.crf = CRF(odims, batch_first=True)
104
+ self.linear = torch.nn.Linear(n_chans, odims)
105
+
106
+ def _forward(self, xs, x_masks=None, is_inference=False):
107
+ xs = xs.transpose(1, -1) # (B, idim, Tmax)
108
+ for f in self.conv:
109
+ xs = f(xs) # (B, C, Tmax)
110
+ if x_masks is not None:
111
+ xs = xs * (1 - x_masks.float())[:, None, :]
112
+
113
+ xs = self.linear(xs.transpose(1, -1)) # [B, T, C]
114
+ xs = xs * (1 - x_masks.float())[:, :, None] # (B, T, C)
115
+ if is_inference:
116
+ return self.out2dur(xs), xs
117
+ else:
118
+ if hparams['dur_loss'] in ['mse']:
119
+ xs = xs.squeeze(-1) # (B, Tmax)
120
+ return xs
121
+
122
+ def out2dur(self, xs):
123
+ if hparams['dur_loss'] in ['mse']:
124
+ # NOTE: calculate in log domain
125
+ xs = xs.squeeze(-1) # (B, Tmax)
126
+ dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value
127
+ elif hparams['dur_loss'] == 'mog':
128
+ return NotImplementedError
129
+ elif hparams['dur_loss'] == 'crf':
130
+ dur = torch.LongTensor(self.crf.decode(xs)).cuda()
131
+ return dur
132
+
133
+ def forward(self, xs, x_masks=None):
134
+ """Calculate forward propagation.
135
+ Args:
136
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
137
+ x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
138
+ Returns:
139
+ Tensor: Batch of predicted durations in log domain (B, Tmax).
140
+ """
141
+ return self._forward(xs, x_masks, False)
142
+
143
+ def inference(self, xs, x_masks=None):
144
+ """Inference duration.
145
+ Args:
146
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
147
+ x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
148
+ Returns:
149
+ LongTensor: Batch of predicted durations in linear domain (B, Tmax).
150
+ """
151
+ return self._forward(xs, x_masks, True)
152
+
153
+
154
+ class LengthRegulator(torch.nn.Module):
155
+ def __init__(self, pad_value=0.0):
156
+ super(LengthRegulator, self).__init__()
157
+ self.pad_value = pad_value
158
+
159
+ def forward(self, dur, dur_padding=None, alpha=1.0):
160
+ """
161
+ Example (no batch dim version):
162
+ 1. dur = [2,2,3]
163
+ 2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
164
+ 3. token_mask = [[1,1,0,0,0,0,0],
165
+ [0,0,1,1,0,0,0],
166
+ [0,0,0,0,1,1,1]]
167
+ 4. token_idx * token_mask = [[1,1,0,0,0,0,0],
168
+ [0,0,2,2,0,0,0],
169
+ [0,0,0,0,3,3,3]]
170
+ 5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
171
+
172
+ :param dur: Batch of durations of each frame (B, T_txt)
173
+ :param dur_padding: Batch of padding of each frame (B, T_txt)
174
+ :param alpha: duration rescale coefficient
175
+ :return:
176
+ mel2ph (B, T_speech)
177
+ """
178
+ assert alpha > 0
179
+ dur = torch.round(dur.float() * alpha).long()
180
+ if dur_padding is not None:
181
+ dur = dur * (1 - dur_padding.long())
182
+ token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
183
+ dur_cumsum = torch.cumsum(dur, 1)
184
+ dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
185
+
186
+ pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
187
+ token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
188
+ mel2ph = (token_idx * token_mask.long()).sum(1)
189
+ return mel2ph
190
+
191
+
192
+ class PitchPredictor(torch.nn.Module):
193
+ def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
194
+ dropout_rate=0.1, padding='SAME'):
195
+ """Initilize pitch predictor module.
196
+ Args:
197
+ idim (int): Input dimension.
198
+ n_layers (int, optional): Number of convolutional layers.
199
+ n_chans (int, optional): Number of channels of convolutional layers.
200
+ kernel_size (int, optional): Kernel size of convolutional layers.
201
+ dropout_rate (float, optional): Dropout rate.
202
+ """
203
+ super(PitchPredictor, self).__init__()
204
+ self.conv = torch.nn.ModuleList()
205
+ self.kernel_size = kernel_size
206
+ self.padding = padding
207
+ for idx in range(n_layers):
208
+ in_chans = idim if idx == 0 else n_chans
209
+ self.conv += [torch.nn.Sequential(
210
+ torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
211
+ if padding == 'SAME'
212
+ else (kernel_size - 1, 0), 0),
213
+ torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
214
+ torch.nn.ReLU(),
215
+ LayerNorm(n_chans, dim=1),
216
+ torch.nn.Dropout(dropout_rate)
217
+ )]
218
+ self.linear = torch.nn.Linear(n_chans, odim)
219
+ self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
220
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
221
+
222
+ def forward(self, xs):
223
+ """
224
+
225
+ :param xs: [B, T, H]
226
+ :return: [B, T, H]
227
+ """
228
+ positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
229
+ xs = xs + positions
230
+ xs = xs.transpose(1, -1) # (B, idim, Tmax)
231
+ for f in self.conv:
232
+ xs = f(xs) # (B, C, Tmax)
233
+ # NOTE: calculate in log domain
234
+ xs = self.linear(xs.transpose(1, -1)) # (B, Tmax, H)
235
+ return xs
236
+
237
+
238
+ class EnergyPredictor(PitchPredictor):
239
+ pass
240
+
241
+
242
+ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
243
+ B, _ = mel2ph.shape
244
+ dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph))
245
+ dur = dur[:, 1:]
246
+ if max_dur is not None:
247
+ dur = dur.clamp(max=max_dur)
248
+ return dur
249
+
250
+
251
+ class FFTBlocks(nn.Module):
252
+ def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
253
+ use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
254
+ super().__init__()
255
+ self.num_layers = num_layers
256
+ embed_dim = self.hidden_size = hidden_size
257
+ self.dropout = dropout if dropout is not None else hparams['dropout']
258
+ self.use_pos_embed = use_pos_embed
259
+ self.use_last_norm = use_last_norm
260
+ if use_pos_embed:
261
+ self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
262
+ self.padding_idx = 0
263
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
264
+ self.embed_positions = SinusoidalPositionalEmbedding(
265
+ embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
266
+ )
267
+
268
+ self.layers = nn.ModuleList([])
269
+ self.layers.extend([
270
+ TransformerEncoderLayer(self.hidden_size, self.dropout,
271
+ kernel_size=ffn_kernel_size, num_heads=num_heads)
272
+ for _ in range(self.num_layers)
273
+ ])
274
+ if self.use_last_norm:
275
+ if norm == 'ln':
276
+ self.layer_norm = nn.LayerNorm(embed_dim)
277
+ elif norm == 'bn':
278
+ self.layer_norm = BatchNorm1dTBC(embed_dim)
279
+ else:
280
+ self.layer_norm = None
281
+
282
+ def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
283
+ """
284
+ :param x: [B, T, C]
285
+ :param padding_mask: [B, T]
286
+ :return: [B, T, C] or [L, B, T, C]
287
+ """
288
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
289
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
290
+ if self.use_pos_embed:
291
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
292
+ x = x + positions
293
+ x = F.dropout(x, p=self.dropout, training=self.training)
294
+ # B x T x C -> T x B x C
295
+ x = x.transpose(0, 1) * nonpadding_mask_TB
296
+ hiddens = []
297
+ for layer in self.layers:
298
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
299
+ hiddens.append(x)
300
+ if self.use_last_norm:
301
+ x = self.layer_norm(x) * nonpadding_mask_TB
302
+ if return_hiddens:
303
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
304
+ x = x.transpose(1, 2) # [L, B, T, C]
305
+ else:
306
+ x = x.transpose(0, 1) # [B, T, C]
307
+ return x
308
+
309
+
310
+ class FastspeechEncoder(FFTBlocks):
311
+ def __init__(self, embed_tokens, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
312
+ hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
313
+ kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
314
+ num_layers = hparams['dec_layers'] if num_layers is None else num_layers
315
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
316
+ use_pos_embed=False) # use_pos_embed_alpha for compatibility
317
+ self.embed_tokens = embed_tokens
318
+ self.embed_scale = math.sqrt(hidden_size)
319
+ self.padding_idx = 0
320
+ if hparams.get('rel_pos') is not None and hparams['rel_pos']:
321
+ self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
322
+ else:
323
+ self.embed_positions = SinusoidalPositionalEmbedding(
324
+ hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
325
+ )
326
+
327
+ def forward(self, txt_tokens):
328
+ """
329
+
330
+ :param txt_tokens: [B, T]
331
+ :return: {
332
+ 'encoder_out': [T x B x C]
333
+ }
334
+ """
335
+ encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
336
+ x = self.forward_embedding(txt_tokens) # [B, T, H]
337
+ x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
338
+ return x
339
+
340
+ def forward_embedding(self, txt_tokens):
341
+ # embed tokens and positions
342
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
343
+ if hparams['use_pos_embed']:
344
+ positions = self.embed_positions(txt_tokens)
345
+ x = x + positions
346
+ x = F.dropout(x, p=self.dropout, training=self.training)
347
+ return x
348
+
349
+
350
+ class FastspeechDecoder(FFTBlocks):
351
+ def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
352
+ num_heads = hparams['num_heads'] if num_heads is None else num_heads
353
+ hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
354
+ kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size
355
+ num_layers = hparams['dec_layers'] if num_layers is None else num_layers
356
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
357
+