Vladimir Alabov committed
Commit 1ba66c7
1 Parent(s): 03d977c

Refactor #2

Files changed (49)
  1. requirements.txt +1 -0
  2. so_vits_svc_fork/__init__.py +0 -5
  3. so_vits_svc_fork/__main__.py +0 -917
  4. so_vits_svc_fork/cluster/__init__.py +0 -48
  5. so_vits_svc_fork/cluster/train_cluster.py +0 -141
  6. so_vits_svc_fork/dataset.py +0 -87
  7. so_vits_svc_fork/default_gui_presets.json +0 -92
  8. so_vits_svc_fork/f0.py +0 -239
  9. so_vits_svc_fork/gui.py +0 -851
  10. so_vits_svc_fork/hparams.py +0 -38
  11. so_vits_svc_fork/inference/__init__.py +0 -0
  12. so_vits_svc_fork/inference/core.py +0 -692
  13. so_vits_svc_fork/inference/main.py +0 -272
  14. so_vits_svc_fork/logger.py +0 -46
  15. so_vits_svc_fork/modules/__init__.py +0 -0
  16. so_vits_svc_fork/modules/attentions.py +0 -488
  17. so_vits_svc_fork/modules/commons.py +0 -132
  18. so_vits_svc_fork/modules/decoders/__init__.py +0 -0
  19. so_vits_svc_fork/modules/decoders/f0.py +0 -46
  20. so_vits_svc_fork/modules/decoders/hifigan/__init__.py +0 -3
  21. so_vits_svc_fork/modules/decoders/hifigan/_models.py +0 -311
  22. so_vits_svc_fork/modules/decoders/hifigan/_utils.py +0 -15
  23. so_vits_svc_fork/modules/decoders/mb_istft/__init__.py +0 -15
  24. so_vits_svc_fork/modules/decoders/mb_istft/_generators.py +0 -376
  25. so_vits_svc_fork/modules/decoders/mb_istft/_loss.py +0 -11
  26. so_vits_svc_fork/modules/decoders/mb_istft/_pqmf.py +0 -128
  27. so_vits_svc_fork/modules/decoders/mb_istft/_stft.py +0 -244
  28. so_vits_svc_fork/modules/decoders/mb_istft/_stft_loss.py +0 -142
  29. so_vits_svc_fork/modules/descriminators.py +0 -177
  30. so_vits_svc_fork/modules/encoders.py +0 -136
  31. so_vits_svc_fork/modules/flows.py +0 -48
  32. so_vits_svc_fork/modules/losses.py +0 -58
  33. so_vits_svc_fork/modules/mel_processing.py +0 -205
  34. so_vits_svc_fork/modules/modules.py +0 -452
  35. so_vits_svc_fork/modules/synthesizers.py +0 -233
  36. so_vits_svc_fork/preprocessing/__init__.py +0 -0
  37. so_vits_svc_fork/preprocessing/config_templates/quickvc.json +0 -78
  38. so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json +0 -69
  39. so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json +0 -71
  40. so_vits_svc_fork/preprocessing/preprocess_classify.py +0 -95
  41. so_vits_svc_fork/preprocessing/preprocess_flist_config.py +0 -86
  42. so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py +0 -157
  43. so_vits_svc_fork/preprocessing/preprocess_resample.py +0 -144
  44. so_vits_svc_fork/preprocessing/preprocess_speaker_diarization.py +0 -93
  45. so_vits_svc_fork/preprocessing/preprocess_split.py +0 -78
  46. so_vits_svc_fork/preprocessing/preprocess_utils.py +0 -5
  47. so_vits_svc_fork/py.typed +0 -0
  48. so_vits_svc_fork/train.py +0 -571
  49. so_vits_svc_fork/utils.py +0 -478
requirements.txt CHANGED
@@ -28,3 +28,4 @@ faiss-cpu
28
  wheel
29
  ipython
30
  cm_time
31
+ so-vits-svc-fork
so_vits_svc_fork/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- __version__ = "4.1.1"
2
-
3
- from .logger import init_logger
4
-
5
- init_logger()
 
so_vits_svc_fork/__main__.py DELETED
@@ -1,917 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- from logging import getLogger
5
- from multiprocessing import freeze_support
6
- from pathlib import Path
7
- from typing import Literal
8
-
9
- import click
10
- import torch
11
-
12
- from so_vits_svc_fork import __version__
13
- from so_vits_svc_fork.utils import get_optimal_device
14
-
15
- LOG = getLogger(__name__)
16
-
17
- IS_TEST = "test" in Path(__file__).parent.stem
18
- if IS_TEST:
19
- LOG.debug("Test mode is on.")
20
-
21
-
22
- class RichHelpFormatter(click.HelpFormatter):
23
- def __init__(
24
- self,
25
- indent_increment: int = 2,
26
- width: int | None = None,
27
- max_width: int | None = None,
28
- ) -> None:
29
- width = 100
30
- super().__init__(indent_increment, width, max_width)
31
- LOG.info(f"Version: {__version__}")
32
-
33
-
34
- def patch_wrap_text():
35
- orig_wrap_text = click.formatting.wrap_text
36
-
37
- def wrap_text(
38
- text,
39
- width=78,
40
- initial_indent="",
41
- subsequent_indent="",
42
- preserve_paragraphs=False,
43
- ):
44
- return orig_wrap_text(
45
- text.replace("\n", "\n\n"),
46
- width=width,
47
- initial_indent=initial_indent,
48
- subsequent_indent=subsequent_indent,
49
- preserve_paragraphs=True,
50
- ).replace("\n\n", "\n")
51
-
52
- click.formatting.wrap_text = wrap_text
53
-
54
-
55
- patch_wrap_text()
56
-
57
- CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"], show_default=True)
58
- click.Context.formatter_class = RichHelpFormatter
59
-
60
-
61
- @click.group(context_settings=CONTEXT_SETTINGS)
62
- def cli():
63
- """so-vits-svc allows any folder structure for training data.
64
- However, the following folder structure is recommended.\n
65
- When training: dataset_raw/{speaker_name}/**/{wav_name}.{any_format}\n
66
- When inference: configs/44k/config.json, logs/44k/G_XXXX.pth\n
67
- If the folder structure is followed, you DO NOT NEED TO SPECIFY model path, config path, etc.
68
- (The latest model will be automatically loaded.)\n
69
- To train a model, run pre-resample, pre-config, pre-hubert, train.\n
70
- To infer a model, run infer.
71
- """
72
-
73
-
74
- @cli.command()
75
- @click.option(
76
- "-c",
77
- "--config-path",
78
- type=click.Path(exists=True),
79
- help="path to config",
80
- default=Path("./configs/44k/config.json"),
81
- )
82
- @click.option(
83
- "-m",
84
- "--model-path",
85
- type=click.Path(),
86
- help="path to output dir",
87
- default=Path("./logs/44k"),
88
- )
89
- @click.option(
90
- "-t/-nt",
91
- "--tensorboard/--no-tensorboard",
92
- default=False,
93
- type=bool,
94
- help="launch tensorboard",
95
- )
96
- @click.option(
97
- "-r",
98
- "--reset-optimizer",
99
- default=False,
100
- type=bool,
101
- help="reset optimizer",
102
- is_flag=True,
103
- )
104
- def train(
105
- config_path: Path,
106
- model_path: Path,
107
- tensorboard: bool = False,
108
- reset_optimizer: bool = False,
109
- ):
110
- """Train model
111
- If D_0.pth or G_0.pth not found, automatically download from hub."""
112
- from .train import train
113
-
114
- config_path = Path(config_path)
115
- model_path = Path(model_path)
116
-
117
- if tensorboard:
118
- import webbrowser
119
-
120
- from tensorboard import program
121
-
122
- getLogger("tensorboard").setLevel(30)
123
- tb = program.TensorBoard()
124
- tb.configure(argv=[None, "--logdir", model_path.as_posix()])
125
- url = tb.launch()
126
- webbrowser.open(url)
127
-
128
- train(
129
- config_path=config_path, model_path=model_path, reset_optimizer=reset_optimizer
130
- )
131
-
132
-
133
- @cli.command()
134
- def gui():
135
- """Opens GUI
136
- for conversion and realtime inference"""
137
- from .gui import main
138
-
139
- main()
140
-
141
-
142
- @cli.command()
143
- @click.argument(
144
- "input-path",
145
- type=click.Path(exists=True),
146
- )
147
- @click.option(
148
- "-o",
149
- "--output-path",
150
- type=click.Path(),
151
- help="path to output dir",
152
- )
153
- @click.option("-s", "--speaker", type=str, default=None, help="speaker name")
154
- @click.option(
155
- "-m",
156
- "--model-path",
157
- type=click.Path(exists=True),
158
- default=Path("./logs/44k/"),
159
- help="path to model",
160
- )
161
- @click.option(
162
- "-c",
163
- "--config-path",
164
- type=click.Path(exists=True),
165
- default=Path("./configs/44k/config.json"),
166
- help="path to config",
167
- )
168
- @click.option(
169
- "-k",
170
- "--cluster-model-path",
171
- type=click.Path(exists=True),
172
- default=None,
173
- help="path to cluster model",
174
- )
175
- @click.option(
176
- "-re",
177
- "--recursive",
178
- type=bool,
179
- default=False,
180
- help="Search recursively",
181
- is_flag=True,
182
- )
183
- @click.option("-t", "--transpose", type=int, default=0, help="transpose")
184
- @click.option(
185
- "-db", "--db-thresh", type=int, default=-20, help="threshold (DB) (RELATIVE)"
186
- )
187
- @click.option(
188
- "-fm",
189
- "--f0-method",
190
- type=click.Choice(["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"]),
191
- default="dio",
192
- help="f0 prediction method",
193
- )
194
- @click.option(
195
- "-a/-na",
196
- "--auto-predict-f0/--no-auto-predict-f0",
197
- type=bool,
198
- default=True,
199
- help="auto predict f0",
200
- )
201
- @click.option(
202
- "-r", "--cluster-infer-ratio", type=float, default=0, help="cluster infer ratio"
203
- )
204
- @click.option("-n", "--noise-scale", type=float, default=0.4, help="noise scale")
205
- @click.option("-p", "--pad-seconds", type=float, default=0.5, help="pad seconds")
206
- @click.option(
207
- "-d",
208
- "--device",
209
- type=str,
210
- default=get_optimal_device(),
211
- help="device",
212
- )
213
- @click.option("-ch", "--chunk-seconds", type=float, default=0.5, help="chunk seconds")
214
- @click.option(
215
- "-ab/-nab",
216
- "--absolute-thresh/--no-absolute-thresh",
217
- type=bool,
218
- default=False,
219
- help="absolute thresh",
220
- )
221
- @click.option(
222
- "-mc",
223
- "--max-chunk-seconds",
224
- type=float,
225
- default=40,
226
- help="maximum allowed single chunk length, set lower if you get out of memory (0 to disable)",
227
- )
228
- def infer(
229
- # paths
230
- input_path: Path,
231
- output_path: Path,
232
- model_path: Path,
233
- config_path: Path,
234
- recursive: bool,
235
- # svc config
236
- speaker: str,
237
- cluster_model_path: Path | None = None,
238
- transpose: int = 0,
239
- auto_predict_f0: bool = False,
240
- cluster_infer_ratio: float = 0,
241
- noise_scale: float = 0.4,
242
- f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
243
- # slice config
244
- db_thresh: int = -40,
245
- pad_seconds: float = 0.5,
246
- chunk_seconds: float = 0.5,
247
- absolute_thresh: bool = False,
248
- max_chunk_seconds: float = 40,
249
- device: str | torch.device = get_optimal_device(),
250
- ):
251
- """Inference"""
252
- from so_vits_svc_fork.inference.main import infer
253
-
254
- if not auto_predict_f0:
255
- LOG.warning(
256
- f"auto_predict_f0 = False, transpose = {transpose}. If you want to change the pitch, please set transpose."
257
- "Generally transpose = 0 does not work because your voice pitch and target voice pitch are different."
258
- )
259
-
260
- input_path = Path(input_path)
261
- if output_path is None:
262
- output_path = input_path.parent / f"{input_path.stem}.out{input_path.suffix}"
263
- output_path = Path(output_path)
264
- if input_path.is_dir() and not recursive:
265
- raise ValueError(
266
- "input_path is a directory. Use 0re or --recursive to infer recursively."
267
- )
268
- model_path = Path(model_path)
269
- if model_path.is_dir():
270
- model_path = list(
271
- sorted(model_path.glob("G_*.pth"), key=lambda x: x.stat().st_mtime)
272
- )[-1]
273
- LOG.info(f"Since model_path is a directory, use {model_path}")
274
- config_path = Path(config_path)
275
- if cluster_model_path is not None:
276
- cluster_model_path = Path(cluster_model_path)
277
- infer(
278
- # paths
279
- input_path=input_path,
280
- output_path=output_path,
281
- model_path=model_path,
282
- config_path=config_path,
283
- recursive=recursive,
284
- # svc config
285
- speaker=speaker,
286
- cluster_model_path=cluster_model_path,
287
- transpose=transpose,
288
- auto_predict_f0=auto_predict_f0,
289
- cluster_infer_ratio=cluster_infer_ratio,
290
- noise_scale=noise_scale,
291
- f0_method=f0_method,
292
- # slice config
293
- db_thresh=db_thresh,
294
- pad_seconds=pad_seconds,
295
- chunk_seconds=chunk_seconds,
296
- absolute_thresh=absolute_thresh,
297
- max_chunk_seconds=max_chunk_seconds,
298
- device=device,
299
- )
300
-
301
-
302
- @cli.command()
303
- @click.option(
304
- "-m",
305
- "--model-path",
306
- type=click.Path(exists=True),
307
- default=Path("./logs/44k/"),
308
- help="path to model",
309
- )
310
- @click.option(
311
- "-c",
312
- "--config-path",
313
- type=click.Path(exists=True),
314
- default=Path("./configs/44k/config.json"),
315
- help="path to config",
316
- )
317
- @click.option(
318
- "-k",
319
- "--cluster-model-path",
320
- type=click.Path(exists=True),
321
- default=None,
322
- help="path to cluster model",
323
- )
324
- @click.option("-t", "--transpose", type=int, default=12, help="transpose")
325
- @click.option(
326
- "-a/-na",
327
- "--auto-predict-f0/--no-auto-predict-f0",
328
- type=bool,
329
- default=True,
330
- help="auto predict f0 (not recommended for realtime since voice pitch will not be stable)",
331
- )
332
- @click.option(
333
- "-r", "--cluster-infer-ratio", type=float, default=0, help="cluster infer ratio"
334
- )
335
- @click.option("-n", "--noise-scale", type=float, default=0.4, help="noise scale")
336
- @click.option(
337
- "-db", "--db-thresh", type=int, default=-30, help="threshold (DB) (ABSOLUTE)"
338
- )
339
- @click.option(
340
- "-fm",
341
- "--f0-method",
342
- type=click.Choice(["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"]),
343
- default="dio",
344
- help="f0 prediction method",
345
- )
346
- @click.option("-p", "--pad-seconds", type=float, default=0.02, help="pad seconds")
347
- @click.option("-ch", "--chunk-seconds", type=float, default=0.5, help="chunk seconds")
348
- @click.option(
349
- "-cr",
350
- "--crossfade-seconds",
351
- type=float,
352
- default=0.01,
353
- help="crossfade seconds",
354
- )
355
- @click.option(
356
- "-ab",
357
- "--additional-infer-before-seconds",
358
- type=float,
359
- default=0.2,
360
- help="additional infer before seconds",
361
- )
362
- @click.option(
363
- "-aa",
364
- "--additional-infer-after-seconds",
365
- type=float,
366
- default=0.1,
367
- help="additional infer after seconds",
368
- )
369
- @click.option("-b", "--block-seconds", type=float, default=0.5, help="block seconds")
370
- @click.option(
371
- "-d",
372
- "--device",
373
- type=str,
374
- default=get_optimal_device(),
375
- help="device",
376
- )
377
- @click.option("-s", "--speaker", type=str, default=None, help="speaker name")
378
- @click.option("-v", "--version", type=int, default=2, help="version")
379
- @click.option("-i", "--input-device", type=int, default=None, help="input device")
380
- @click.option("-o", "--output-device", type=int, default=None, help="output device")
381
- @click.option(
382
- "-po",
383
- "--passthrough-original",
384
- type=bool,
385
- default=False,
386
- is_flag=True,
387
- help="passthrough original (for latency check)",
388
- )
389
- def vc(
390
- # paths
391
- model_path: Path,
392
- config_path: Path,
393
- # svc config
394
- speaker: str,
395
- cluster_model_path: Path | None,
396
- transpose: int,
397
- auto_predict_f0: bool,
398
- cluster_infer_ratio: float,
399
- noise_scale: float,
400
- f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
401
- # slice config
402
- db_thresh: int,
403
- pad_seconds: float,
404
- chunk_seconds: float,
405
- # realtime config
406
- crossfade_seconds: float,
407
- additional_infer_before_seconds: float,
408
- additional_infer_after_seconds: float,
409
- block_seconds: float,
410
- version: int,
411
- input_device: int | str | None,
412
- output_device: int | str | None,
413
- device: torch.device,
414
- passthrough_original: bool = False,
415
- ) -> None:
416
- """Realtime inference from microphone"""
417
- from so_vits_svc_fork.inference.main import realtime
418
-
419
- if auto_predict_f0:
420
- LOG.warning(
421
- "auto_predict_f0 = True in realtime inference will cause unstable voice pitch, use with caution"
422
- )
423
- else:
424
- LOG.warning(
425
- f"auto_predict_f0 = False, transpose = {transpose}. If you want to change the pitch, please change the transpose value."
426
- "Generally transpose = 0 does not work because your voice pitch and target voice pitch are different."
427
- )
428
- model_path = Path(model_path)
429
- config_path = Path(config_path)
430
- if cluster_model_path is not None:
431
- cluster_model_path = Path(cluster_model_path)
432
- if model_path.is_dir():
433
- model_path = list(
434
- sorted(model_path.glob("G_*.pth"), key=lambda x: x.stat().st_mtime)
435
- )[-1]
436
- LOG.info(f"Since model_path is a directory, use {model_path}")
437
-
438
- realtime(
439
- # paths
440
- model_path=model_path,
441
- config_path=config_path,
442
- # svc config
443
- speaker=speaker,
444
- cluster_model_path=cluster_model_path,
445
- transpose=transpose,
446
- auto_predict_f0=auto_predict_f0,
447
- cluster_infer_ratio=cluster_infer_ratio,
448
- noise_scale=noise_scale,
449
- f0_method=f0_method,
450
- # slice config
451
- db_thresh=db_thresh,
452
- pad_seconds=pad_seconds,
453
- chunk_seconds=chunk_seconds,
454
- # realtime config
455
- crossfade_seconds=crossfade_seconds,
456
- additional_infer_before_seconds=additional_infer_before_seconds,
457
- additional_infer_after_seconds=additional_infer_after_seconds,
458
- block_seconds=block_seconds,
459
- version=version,
460
- input_device=input_device,
461
- output_device=output_device,
462
- device=device,
463
- passthrough_original=passthrough_original,
464
- )
465
-
466
-
467
- @cli.command()
468
- @click.option(
469
- "-i",
470
- "--input-dir",
471
- type=click.Path(exists=True),
472
- default=Path("./dataset_raw"),
473
- help="path to source dir",
474
- )
475
- @click.option(
476
- "-o",
477
- "--output-dir",
478
- type=click.Path(),
479
- default=Path("./dataset/44k"),
480
- help="path to output dir",
481
- )
482
- @click.option("-s", "--sampling-rate", type=int, default=44100, help="sampling rate")
483
- @click.option(
484
- "-n",
485
- "--n-jobs",
486
- type=int,
487
- default=-1,
488
- help="number of jobs (optimal value may depend on your RAM capacity and audio duration per file)",
489
- )
490
- @click.option("-d", "--top-db", type=float, default=30, help="top db")
491
- @click.option("-f", "--frame-seconds", type=float, default=1, help="frame seconds")
492
- @click.option(
493
- "-ho", "-hop", "--hop-seconds", type=float, default=0.3, help="hop seconds"
494
- )
495
- def pre_resample(
496
- input_dir: Path,
497
- output_dir: Path,
498
- sampling_rate: int,
499
- n_jobs: int,
500
- top_db: int,
501
- frame_seconds: float,
502
- hop_seconds: float,
503
- ) -> None:
504
- """Preprocessing part 1: resample"""
505
- from so_vits_svc_fork.preprocessing.preprocess_resample import preprocess_resample
506
-
507
- input_dir = Path(input_dir)
508
- output_dir = Path(output_dir)
509
- preprocess_resample(
510
- input_dir=input_dir,
511
- output_dir=output_dir,
512
- sampling_rate=sampling_rate,
513
- n_jobs=n_jobs,
514
- top_db=top_db,
515
- frame_seconds=frame_seconds,
516
- hop_seconds=hop_seconds,
517
- )
518
-
519
-
520
- from so_vits_svc_fork.preprocessing.preprocess_flist_config import CONFIG_TEMPLATE_DIR
521
-
522
-
523
- @cli.command()
524
- @click.option(
525
- "-i",
526
- "--input-dir",
527
- type=click.Path(exists=True),
528
- default=Path("./dataset/44k"),
529
- help="path to source dir",
530
- )
531
- @click.option(
532
- "-f",
533
- "--filelist-path",
534
- type=click.Path(),
535
- default=Path("./filelists/44k"),
536
- help="path to filelist dir",
537
- )
538
- @click.option(
539
- "-c",
540
- "--config-path",
541
- type=click.Path(),
542
- default=Path("./configs/44k/config.json"),
543
- help="path to config",
544
- )
545
- @click.option(
546
- "-t",
547
- "--config-type",
548
- type=click.Choice([x.stem for x in CONFIG_TEMPLATE_DIR.rglob("*.json")]),
549
- default="so-vits-svc-4.0v1",
550
- help="config type",
551
- )
552
- def pre_config(
553
- input_dir: Path,
554
- filelist_path: Path,
555
- config_path: Path,
556
- config_type: str,
557
- ):
558
- """Preprocessing part 2: config"""
559
- from so_vits_svc_fork.preprocessing.preprocess_flist_config import preprocess_config
560
-
561
- input_dir = Path(input_dir)
562
- filelist_path = Path(filelist_path)
563
- config_path = Path(config_path)
564
- preprocess_config(
565
- input_dir=input_dir,
566
- train_list_path=filelist_path / "train.txt",
567
- val_list_path=filelist_path / "val.txt",
568
- test_list_path=filelist_path / "test.txt",
569
- config_path=config_path,
570
- config_name=config_type,
571
- )
572
-
573
-
574
- @cli.command()
575
- @click.option(
576
- "-i",
577
- "--input-dir",
578
- type=click.Path(exists=True),
579
- default=Path("./dataset/44k"),
580
- help="path to source dir",
581
- )
582
- @click.option(
583
- "-c",
584
- "--config-path",
585
- type=click.Path(exists=True),
586
- help="path to config",
587
- default=Path("./configs/44k/config.json"),
588
- )
589
- @click.option(
590
- "-n",
591
- "--n-jobs",
592
- type=int,
593
- default=None,
594
- help="number of jobs (optimal value may depend on your VRAM capacity and audio duration per file)",
595
- )
596
- @click.option(
597
- "-f/-nf",
598
- "--force-rebuild/--no-force-rebuild",
599
- type=bool,
600
- default=True,
601
- help="force rebuild existing preprocessed files",
602
- )
603
- @click.option(
604
- "-fm",
605
- "--f0-method",
606
- type=click.Choice(["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"]),
607
- default="dio",
608
- )
609
- def pre_hubert(
610
- input_dir: Path,
611
- config_path: Path,
612
- n_jobs: bool,
613
- force_rebuild: bool,
614
- f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
615
- ) -> None:
616
- """Preprocessing part 3: hubert
617
- If the HuBERT model is not found, it will be downloaded automatically."""
618
- from so_vits_svc_fork.preprocessing.preprocess_hubert_f0 import preprocess_hubert_f0
619
-
620
- input_dir = Path(input_dir)
621
- config_path = Path(config_path)
622
- preprocess_hubert_f0(
623
- input_dir=input_dir,
624
- config_path=config_path,
625
- n_jobs=n_jobs,
626
- force_rebuild=force_rebuild,
627
- f0_method=f0_method,
628
- )
629
-
630
-
631
- @cli.command()
632
- @click.option(
633
- "-i",
634
- "--input-dir",
635
- type=click.Path(exists=True),
636
- default=Path("./dataset_raw_raw/"),
637
- help="path to source dir",
638
- )
639
- @click.option(
640
- "-o",
641
- "--output-dir",
642
- type=click.Path(),
643
- default=Path("./dataset_raw/"),
644
- help="path to output dir",
645
- )
646
- @click.option(
647
- "-n",
648
- "--n-jobs",
649
- type=int,
650
- default=-1,
651
- help="number of jobs (optimal value may depend on your VRAM capacity and audio duration per file)",
652
- )
653
- @click.option("-min", "--min-speakers", type=int, default=2, help="min speakers")
654
- @click.option("-max", "--max-speakers", type=int, default=2, help="max speakers")
655
- @click.option(
656
- "-t", "--huggingface-token", type=str, default=None, help="huggingface token"
657
- )
658
- @click.option("-s", "--sr", type=int, default=44100, help="sampling rate")
659
- def pre_sd(
660
- input_dir: Path | str,
661
- output_dir: Path | str,
662
- min_speakers: int,
663
- max_speakers: int,
664
- huggingface_token: str | None,
665
- n_jobs: int,
666
- sr: int,
667
- ):
668
- """Speech diarization using pyannote.audio"""
669
- if huggingface_token is None:
670
- huggingface_token = os.environ.get("HUGGINGFACE_TOKEN", None)
671
- if huggingface_token is None:
672
- huggingface_token = click.prompt(
673
- "Please enter your HuggingFace token", hide_input=True
674
- )
675
- if os.environ.get("HUGGINGFACE_TOKEN", None) is None:
676
- LOG.info("You can also set the HUGGINGFACE_TOKEN environment variable.")
677
- assert huggingface_token is not None
678
- huggingface_token = huggingface_token.rstrip(" \n\r\t\0")
679
- if len(huggingface_token) <= 1:
680
- raise ValueError("HuggingFace token is empty: " + huggingface_token)
681
-
682
- if max_speakers == 1:
683
- LOG.warning("Consider using pre-split if max_speakers == 1")
684
- from so_vits_svc_fork.preprocessing.preprocess_speaker_diarization import (
685
- preprocess_speaker_diarization,
686
- )
687
-
688
- preprocess_speaker_diarization(
689
- input_dir=input_dir,
690
- output_dir=output_dir,
691
- min_speakers=min_speakers,
692
- max_speakers=max_speakers,
693
- huggingface_token=huggingface_token,
694
- n_jobs=n_jobs,
695
- sr=sr,
696
- )
697
-
698
-
699
- @cli.command()
700
- @click.option(
701
- "-i",
702
- "--input-dir",
703
- type=click.Path(exists=True),
704
- default=Path("./dataset_raw_raw/"),
705
- help="path to source dir",
706
- )
707
- @click.option(
708
- "-o",
709
- "--output-dir",
710
- type=click.Path(),
711
- default=Path("./dataset_raw/"),
712
- help="path to output dir",
713
- )
714
- @click.option(
715
- "-n",
716
- "--n-jobs",
717
- type=int,
718
- default=-1,
719
- help="number of jobs (optimal value may depend on your RAM capacity and audio duration per file)",
720
- )
721
- @click.option(
722
- "-l",
723
- "--max-length",
724
- type=float,
725
- default=10,
726
- help="max length of each split in seconds",
727
- )
728
- @click.option("-d", "--top-db", type=float, default=30, help="top db")
729
- @click.option("-f", "--frame-seconds", type=float, default=1, help="frame seconds")
730
- @click.option(
731
- "-ho", "-hop", "--hop-seconds", type=float, default=0.3, help="hop seconds"
732
- )
733
- @click.option("-s", "--sr", type=int, default=44100, help="sample rate")
734
- def pre_split(
735
- input_dir: Path | str,
736
- output_dir: Path | str,
737
- max_length: float,
738
- top_db: int,
739
- frame_seconds: float,
740
- hop_seconds: float,
741
- n_jobs: int,
742
- sr: int,
743
- ):
744
- """Split audio files into multiple files"""
745
- from so_vits_svc_fork.preprocessing.preprocess_split import preprocess_split
746
-
747
- preprocess_split(
748
- input_dir=input_dir,
749
- output_dir=output_dir,
750
- max_length=max_length,
751
- top_db=top_db,
752
- frame_seconds=frame_seconds,
753
- hop_seconds=hop_seconds,
754
- n_jobs=n_jobs,
755
- sr=sr,
756
- )
757
-
758
-
759
- @cli.command()
760
- @click.option(
761
- "-i",
762
- "--input-dir",
763
- type=click.Path(exists=True),
764
- required=True,
765
- help="path to source dir",
766
- )
767
- @click.option(
768
- "-o",
769
- "--output-dir",
770
- type=click.Path(),
771
- default=None,
772
- help="path to output dir",
773
- )
774
- @click.option(
775
- "-c/-nc",
776
- "--create-new/--no-create-new",
777
- type=bool,
778
- default=True,
779
- help="create a new folder for the speaker if not exist",
780
- )
781
- def pre_classify(
782
- input_dir: Path | str,
783
- output_dir: Path | str | None,
784
- create_new: bool,
785
- ) -> None:
786
- """Classify multiple audio files into multiple files"""
787
- from so_vits_svc_fork.preprocessing.preprocess_classify import preprocess_classify
788
-
789
- if output_dir is None:
790
- output_dir = input_dir
791
- preprocess_classify(
792
- input_dir=input_dir,
793
- output_dir=output_dir,
794
- create_new=create_new,
795
- )
796
-
797
-
798
- @cli.command
799
- def clean():
800
- """Clean up files, only useful if you are using the default file structure"""
801
- import shutil
802
-
803
- folders = ["dataset", "filelists", "logs"]
804
- # if pyip.inputYesNo(f"Are you sure you want to delete files in {folders}?") == "yes":
805
- if input("Are you sure you want to delete files in {folders}?") in ["yes", "y"]:
806
- for folder in folders:
807
- if Path(folder).exists():
808
- shutil.rmtree(folder)
809
- LOG.info("Cleaned up files")
810
- else:
811
- LOG.info("Aborted")
812
-
813
-
814
- @cli.command
815
- @click.option(
816
- "-i",
817
- "--input-path",
818
- type=click.Path(exists=True),
819
- help="model path",
820
- default=Path("./logs/44k/"),
821
- )
822
- @click.option(
823
- "-o",
824
- "--output-path",
825
- type=click.Path(),
826
- help="onnx model path to save",
827
- default=None,
828
- )
829
- @click.option(
830
- "-c",
831
- "--config-path",
832
- type=click.Path(),
833
- help="config path",
834
- default=Path("./configs/44k/config.json"),
835
- )
836
- @click.option(
837
- "-d",
838
- "--device",
839
- type=str,
840
- default="cpu",
841
- help="device to use",
842
- )
843
- def onnx(
844
- input_path: Path, output_path: Path, config_path: Path, device: torch.device | str
845
- ) -> None:
846
- """Export model to onnx (currently not working)"""
847
- raise NotImplementedError("ONNX export is not yet supported")
848
- input_path = Path(input_path)
849
- if input_path.is_dir():
850
- input_path = list(input_path.glob("*.pth"))[0]
851
- if output_path is None:
852
- output_path = input_path.with_suffix(".onnx")
853
- output_path = Path(output_path)
854
- if output_path.is_dir():
855
- output_path = output_path / (input_path.stem + ".onnx")
856
- config_path = Path(config_path)
857
- device_ = torch.device(device)
858
- from so_vits_svc_fork.modules.onnx._export import onnx_export
859
-
860
- onnx_export(
861
- input_path=input_path,
862
- output_path=output_path,
863
- config_path=config_path,
864
- device=device_,
865
- )
866
-
867
-
868
- @cli.command
869
- @click.option(
870
- "-i",
871
- "--input-dir",
872
- type=click.Path(exists=True),
873
- help="dataset directory",
874
- default=Path("./dataset/44k"),
875
- )
876
- @click.option(
877
- "-o",
878
- "--output-path",
879
- type=click.Path(),
880
- help="model path to save",
881
- default=Path("./logs/44k/kmeans.pt"),
882
- )
883
- @click.option("-n", "--n-clusters", type=int, help="number of clusters", default=2000)
884
- @click.option(
885
- "-m/-nm", "--minibatch/--no-minibatch", default=True, help="use minibatch k-means"
886
- )
887
- @click.option(
888
- "-b", "--batch-size", type=int, default=4096, help="batch size for minibatch kmeans"
889
- )
890
- @click.option(
891
- "-p/-np", "--partial-fit", default=False, help="use partial fit (only use with -m)"
892
- )
893
- def train_cluster(
894
- input_dir: Path,
895
- output_path: Path,
896
- n_clusters: int,
897
- minibatch: bool,
898
- batch_size: int,
899
- partial_fit: bool,
900
- ) -> None:
901
- """Train k-means clustering"""
902
- from .cluster.train_cluster import main
903
-
904
- main(
905
- input_dir=input_dir,
906
- output_path=output_path,
907
- n_clusters=n_clusters,
908
- verbose=True,
909
- use_minibatch=minibatch,
910
- batch_size=batch_size,
911
- partial_fit=partial_fit,
912
- )
913
-
914
-
915
- if __name__ == "__main__":
916
- freeze_support()
917
- cli()
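The cli() docstring above describes the intended workflow: pre-resample, pre-config, pre-hubert, train, then infer. As a rough, non-authoritative sketch of that sequence against the pre-refactor layout shown above (driven through click's test runner; the paths are the option defaults and the speaker name "speaker0" is a placeholder, with command names assumed to follow click's usual underscore-to-dash naming):

```python
# Illustrative sketch only: drives the click group from the (now removed)
# so_vits_svc_fork/__main__.py in the order its docstring recommends.
from click.testing import CliRunner
from so_vits_svc_fork.__main__ import cli

runner = CliRunner()
for args in (
    ["pre-resample"],              # dataset_raw/ -> dataset/44k/
    ["pre-config"],                # writes configs/44k/config.json
    ["pre-hubert", "-fm", "dio"],  # HuBERT content + f0 preprocessing
    ["train"],                     # reads configs/44k, writes logs/44k
):
    result = runner.invoke(cli, args)
    print(args, "->", result.exit_code)

# Inference on a single file afterwards (speaker name is a placeholder):
runner.invoke(cli, ["infer", "input.wav", "-s", "speaker0"])
```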
 
so_vits_svc_fork/cluster/__init__.py DELETED
@@ -1,48 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from pathlib import Path
4
- from typing import Any
5
-
6
- import torch
7
- from sklearn.cluster import KMeans
8
-
9
-
10
- def get_cluster_model(ckpt_path: Path | str):
11
- with Path(ckpt_path).open("rb") as f:
12
- checkpoint = torch.load(
13
- f, map_location="cpu"
14
- ) # Danger of arbitrary code execution
15
- kmeans_dict = {}
16
- for spk, ckpt in checkpoint.items():
17
- km = KMeans(ckpt["n_features_in_"])
18
- km.__dict__["n_features_in_"] = ckpt["n_features_in_"]
19
- km.__dict__["_n_threads"] = ckpt["_n_threads"]
20
- km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"]
21
- kmeans_dict[spk] = km
22
- return kmeans_dict
23
-
24
-
25
- def check_speaker(model: Any, speaker: Any):
26
- if speaker not in model:
27
- raise ValueError(f"Speaker {speaker} not in {list(model.keys())}")
28
-
29
-
30
- def get_cluster_result(model: Any, x: Any, speaker: Any):
31
- """
32
- x: np.array [t, 256]
33
- return cluster class result
34
- """
35
- check_speaker(model, speaker)
36
- return model[speaker].predict(x)
37
-
38
-
39
- def get_cluster_center_result(model: Any, x: Any, speaker: Any):
40
- """x: np.array [t, 256]"""
41
- check_speaker(model, speaker)
42
- predict = model[speaker].predict(x)
43
- return model[speaker].cluster_centers_[predict]
44
-
45
-
46
- def get_center(model: Any, x: Any, speaker: Any):
47
- check_speaker(model, speaker)
48
- return model[speaker].cluster_centers_[x]
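The helpers above rebuild one sklearn KMeans per speaker from a saved checkpoint and map [t, 256] feature arrays to cluster classes or centers. A minimal usage sketch against the pre-refactor layout (the checkpoint path and speaker name are assumptions, and random features stand in for real HuBERT content):

```python
# Hedged sketch of the cluster helpers removed above; inputs are placeholders.
import numpy as np
from so_vits_svc_fork.cluster import get_cluster_model, get_cluster_center_result

model = get_cluster_model("logs/44k/kmeans.pt")          # dict of per-speaker KMeans
features = np.random.rand(100, 256).astype(np.float32)   # x: np.array [t, 256]
centers = get_cluster_center_result(model, features, speaker="speaker0")
print(centers.shape)                                      # (100, 256)
```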
 
so_vits_svc_fork/cluster/train_cluster.py DELETED
@@ -1,141 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
- from logging import getLogger
5
- from pathlib import Path
6
- from typing import Any
7
-
8
- import numpy as np
9
- import torch
10
- from cm_time import timer
11
- from joblib import Parallel, delayed
12
- from sklearn.cluster import KMeans, MiniBatchKMeans
13
- from tqdm_joblib import tqdm_joblib
14
-
15
- LOG = getLogger(__name__)
16
-
17
-
18
- def train_cluster(
19
- input_dir: Path | str,
20
- n_clusters: int,
21
- use_minibatch: bool = True,
22
- batch_size: int = 4096,
23
- partial_fit: bool = False,
24
- verbose: bool = False,
25
- ) -> dict:
26
- input_dir = Path(input_dir)
27
- if not partial_fit:
28
- LOG.info(f"Loading features from {input_dir}")
29
- features = []
30
- for path in input_dir.rglob("*.data.pt"):
31
- with path.open("rb") as f:
32
- features.append(
33
- torch.load(f, weights_only=True)["content"].squeeze(0).numpy().T
34
- )
35
- if not features:
36
- raise ValueError(f"No features found in {input_dir}")
37
- features = np.concatenate(features, axis=0).astype(np.float32)
38
- if features.shape[0] < n_clusters:
39
- raise ValueError(
40
- "Too few HuBERT features to cluster. Consider using a smaller number of clusters."
41
- )
42
- LOG.info(
43
- f"shape: {features.shape}, size: {features.nbytes/1024**2:.2f} MB, dtype: {features.dtype}"
44
- )
45
- with timer() as t:
46
- if use_minibatch:
47
- kmeans = MiniBatchKMeans(
48
- n_clusters=n_clusters,
49
- verbose=verbose,
50
- batch_size=batch_size,
51
- max_iter=80,
52
- n_init="auto",
53
- ).fit(features)
54
- else:
55
- kmeans = KMeans(
56
- n_clusters=n_clusters, verbose=verbose, n_init="auto"
57
- ).fit(features)
58
- LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
59
-
60
- x = {
61
- "n_features_in_": kmeans.n_features_in_,
62
- "_n_threads": kmeans._n_threads,
63
- "cluster_centers_": kmeans.cluster_centers_,
64
- }
65
- return x
66
- else:
67
- # minibatch partial fit
68
- paths = list(input_dir.rglob("*.data.pt"))
69
- if len(paths) == 0:
70
- raise ValueError(f"No features found in {input_dir}")
71
- LOG.info(f"Found {len(paths)} features in {input_dir}")
72
- n_batches = math.ceil(len(paths) / batch_size)
73
- LOG.info(f"Splitting into {n_batches} batches")
74
- with timer() as t:
75
- kmeans = MiniBatchKMeans(
76
- n_clusters=n_clusters,
77
- verbose=verbose,
78
- batch_size=batch_size,
79
- max_iter=80,
80
- n_init="auto",
81
- )
82
- for i in range(0, len(paths), batch_size):
83
- LOG.info(
84
- f"Processing batch {i//batch_size+1}/{n_batches} for speaker {input_dir.stem}"
85
- )
86
- features = []
87
- for path in paths[i : i + batch_size]:
88
- with path.open("rb") as f:
89
- features.append(
90
- torch.load(f, weights_only=True)["content"]
91
- .squeeze(0)
92
- .numpy()
93
- .T
94
- )
95
- features = np.concatenate(features, axis=0).astype(np.float32)
96
- kmeans.partial_fit(features)
97
- LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
98
-
99
- x = {
100
- "n_features_in_": kmeans.n_features_in_,
101
- "_n_threads": kmeans._n_threads,
102
- "cluster_centers_": kmeans.cluster_centers_,
103
- }
104
- return x
105
-
106
-
107
- def main(
108
- input_dir: Path | str,
109
- output_path: Path | str,
110
- n_clusters: int = 10000,
111
- use_minibatch: bool = True,
112
- batch_size: int = 4096,
113
- partial_fit: bool = False,
114
- verbose: bool = False,
115
- ) -> None:
116
- input_dir = Path(input_dir)
117
- output_path = Path(output_path)
118
-
119
- if not (use_minibatch or not partial_fit):
120
- raise ValueError("partial_fit requires use_minibatch")
121
-
122
- def train_cluster_(input_path: Path, **kwargs: Any) -> tuple[str, dict]:
123
- return input_path.stem, train_cluster(input_path, **kwargs)
124
-
125
- with tqdm_joblib(desc="Training clusters", total=len(list(input_dir.iterdir()))):
126
- parallel_result = Parallel(n_jobs=-1)(
127
- delayed(train_cluster_)(
128
- speaker_name,
129
- n_clusters=n_clusters,
130
- use_minibatch=use_minibatch,
131
- batch_size=batch_size,
132
- partial_fit=partial_fit,
133
- verbose=verbose,
134
- )
135
- for speaker_name in input_dir.iterdir()
136
- )
137
- assert parallel_result is not None
138
- checkpoint = dict(parallel_result)
139
- output_path.parent.mkdir(exist_ok=True, parents=True)
140
- with output_path.open("wb") as f:
141
- torch.save(checkpoint, f)
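main() above expects one sub-directory per speaker under input_dir, each containing *.data.pt files from preprocessing, and saves a single checkpoint later consumed by get_cluster_model(). A hedged call sketch; the paths mirror the CLI defaults above and are illustrative only:

```python
# Illustrative sketch; paths are the default-style locations, not verified here.
from so_vits_svc_fork.cluster.train_cluster import main

main(
    input_dir="dataset/44k",            # dataset/44k/<speaker>/*.data.pt
    output_path="logs/44k/kmeans.pt",   # loaded later by get_cluster_model()
    n_clusters=2000,
    use_minibatch=True,
    batch_size=4096,
    partial_fit=False,
)
```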
 
so_vits_svc_fork/dataset.py DELETED
@@ -1,87 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from pathlib import Path
4
- from random import Random
5
- from typing import Sequence
6
-
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- from torch.utils.data import Dataset
11
-
12
- from .hparams import HParams
13
-
14
-
15
- class TextAudioDataset(Dataset):
16
- def __init__(self, hps: HParams, is_validation: bool = False):
17
- self.datapaths = [
18
- Path(x).parent / (Path(x).name + ".data.pt")
19
- for x in Path(
20
- hps.data.validation_files if is_validation else hps.data.training_files
21
- )
22
- .read_text("utf-8")
23
- .splitlines()
24
- ]
25
- self.hps = hps
26
- self.random = Random(hps.train.seed)
27
- self.random.shuffle(self.datapaths)
28
- self.max_spec_len = 800
29
-
30
- def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
31
- with Path(self.datapaths[index]).open("rb") as f:
32
- data = torch.load(f, weights_only=True, map_location="cpu")
33
-
34
- # cut long data randomly
35
- spec_len = data["mel_spec"].shape[1]
36
- hop_len = self.hps.data.hop_length
37
- if spec_len > self.max_spec_len:
38
- start = self.random.randint(0, spec_len - self.max_spec_len)
39
- end = start + self.max_spec_len - 10
40
- for key in data.keys():
41
- if key == "audio":
42
- data[key] = data[key][:, start * hop_len : end * hop_len]
43
- elif key == "spk":
44
- continue
45
- else:
46
- data[key] = data[key][..., start:end]
47
- torch.cuda.empty_cache()
48
- return data
49
-
50
- def __len__(self) -> int:
51
- return len(self.datapaths)
52
-
53
-
54
- def _pad_stack(array: Sequence[torch.Tensor]) -> torch.Tensor:
55
- max_idx = torch.argmax(torch.tensor([x_.shape[-1] for x_ in array]))
56
- max_x = array[max_idx]
57
- x_padded = [
58
- F.pad(x_, (0, max_x.shape[-1] - x_.shape[-1]), mode="constant", value=0)
59
- for x_ in array
60
- ]
61
- return torch.stack(x_padded)
62
-
63
-
64
- class TextAudioCollate(nn.Module):
65
- def forward(
66
- self, batch: Sequence[dict[str, torch.Tensor]]
67
- ) -> tuple[torch.Tensor, ...]:
68
- batch = [b for b in batch if b is not None]
69
- batch = list(sorted(batch, key=lambda x: x["mel_spec"].shape[1], reverse=True))
70
- lengths = torch.tensor([b["mel_spec"].shape[1] for b in batch]).long()
71
- results = {}
72
- for key in batch[0].keys():
73
- if key not in ["spk"]:
74
- results[key] = _pad_stack([b[key] for b in batch]).cpu()
75
- else:
76
- results[key] = torch.tensor([[b[key]] for b in batch]).cpu()
77
-
78
- return (
79
- results["content"],
80
- results["f0"],
81
- results["spec"],
82
- results["mel_spec"],
83
- results["audio"],
84
- results["spk"],
85
- lengths,
86
- results["uv"],
87
- )
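TextAudioCollate pads every tensor in a batch to the longest item along the last axis before stacking. A small sketch of the underlying _pad_stack helper from the removed dataset.py (shapes are arbitrary examples, and _pad_stack is a private helper used here only for illustration):

```python
# Minimal sketch of the padding helper from the removed dataset.py.
import torch
from so_vits_svc_fork.dataset import _pad_stack

a = torch.ones(2, 3)
b = torch.ones(2, 5)
padded = _pad_stack([a, b])
print(padded.shape)   # torch.Size([2, 2, 5]); the shorter tensor is zero-padded
```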
 
so_vits_svc_fork/default_gui_presets.json DELETED
@@ -1,92 +0,0 @@
1
- {
2
- "Default VC (GPU, GTX 1060)": {
3
- "silence_threshold": -35.0,
4
- "transpose": 12.0,
5
- "auto_predict_f0": false,
6
- "f0_method": "dio",
7
- "cluster_infer_ratio": 0.0,
8
- "noise_scale": 0.4,
9
- "pad_seconds": 0.1,
10
- "chunk_seconds": 0.5,
11
- "absolute_thresh": true,
12
- "max_chunk_seconds": 40,
13
- "crossfade_seconds": 0.05,
14
- "block_seconds": 0.35,
15
- "additional_infer_before_seconds": 0.15,
16
- "additional_infer_after_seconds": 0.1,
17
- "realtime_algorithm": "1 (Divide constantly)",
18
- "passthrough_original": false,
19
- "use_gpu": true
20
- },
21
- "Default VC (CPU)": {
22
- "silence_threshold": -35.0,
23
- "transpose": 12.0,
24
- "auto_predict_f0": false,
25
- "f0_method": "dio",
26
- "cluster_infer_ratio": 0.0,
27
- "noise_scale": 0.4,
28
- "pad_seconds": 0.1,
29
- "chunk_seconds": 0.5,
30
- "absolute_thresh": true,
31
- "max_chunk_seconds": 40,
32
- "crossfade_seconds": 0.05,
33
- "block_seconds": 1.5,
34
- "additional_infer_before_seconds": 0.01,
35
- "additional_infer_after_seconds": 0.01,
36
- "realtime_algorithm": "1 (Divide constantly)",
37
- "passthrough_original": false,
38
- "use_gpu": false
39
- },
40
- "Default VC (Mobile CPU)": {
41
- "silence_threshold": -35.0,
42
- "transpose": 12.0,
43
- "auto_predict_f0": false,
44
- "f0_method": "dio",
45
- "cluster_infer_ratio": 0.0,
46
- "noise_scale": 0.4,
47
- "pad_seconds": 0.1,
48
- "chunk_seconds": 0.5,
49
- "absolute_thresh": true,
50
- "max_chunk_seconds": 40,
51
- "crossfade_seconds": 0.05,
52
- "block_seconds": 2.5,
53
- "additional_infer_before_seconds": 0.01,
54
- "additional_infer_after_seconds": 0.01,
55
- "realtime_algorithm": "1 (Divide constantly)",
56
- "passthrough_original": false,
57
- "use_gpu": false
58
- },
59
- "Default VC (Crooning)": {
60
- "silence_threshold": -35.0,
61
- "transpose": 12.0,
62
- "auto_predict_f0": false,
63
- "f0_method": "dio",
64
- "cluster_infer_ratio": 0.0,
65
- "noise_scale": 0.4,
66
- "pad_seconds": 0.1,
67
- "chunk_seconds": 0.5,
68
- "absolute_thresh": true,
69
- "max_chunk_seconds": 40,
70
- "crossfade_seconds": 0.04,
71
- "block_seconds": 0.15,
72
- "additional_infer_before_seconds": 0.05,
73
- "additional_infer_after_seconds": 0.05,
74
- "realtime_algorithm": "1 (Divide constantly)",
75
- "passthrough_original": false,
76
- "use_gpu": true
77
- },
78
- "Default File": {
79
- "silence_threshold": -35.0,
80
- "transpose": 0.0,
81
- "auto_predict_f0": true,
82
- "f0_method": "crepe",
83
- "cluster_infer_ratio": 0.0,
84
- "noise_scale": 0.4,
85
- "pad_seconds": 0.1,
86
- "chunk_seconds": 0.5,
87
- "absolute_thresh": true,
88
- "max_chunk_seconds": 40,
89
- "auto_play": true,
90
- "passthrough_original": false
91
- }
92
- }
 
so_vits_svc_fork/f0.py DELETED
@@ -1,239 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from logging import getLogger
4
- from typing import Any, Literal
5
-
6
- import numpy as np
7
- import torch
8
- import torchcrepe
9
- from cm_time import timer
10
- from numpy import dtype, float32, ndarray
11
- from torch import FloatTensor, Tensor
12
-
13
- from so_vits_svc_fork.utils import get_optimal_device
14
-
15
- LOG = getLogger(__name__)
16
-
17
-
18
- def normalize_f0(
19
- f0: FloatTensor, x_mask: FloatTensor, uv: FloatTensor, random_scale=True
20
- ) -> FloatTensor:
21
- # calculate means based on x_mask
22
- uv_sum = torch.sum(uv, dim=1, keepdim=True)
23
- uv_sum[uv_sum == 0] = 9999
24
- means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum
25
-
26
- if random_scale:
27
- factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
28
- else:
29
- factor = torch.ones(f0.shape[0], 1).to(f0.device)
30
- # normalize f0 based on means and factor
31
- f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
32
- if torch.isnan(f0_norm).any():
33
- exit(0)
34
- return f0_norm * x_mask
35
-
36
-
37
- def interpolate_f0(
38
- f0: ndarray[Any, dtype[float32]]
39
- ) -> tuple[ndarray[Any, dtype[float32]], ndarray[Any, dtype[float32]]]:
40
- data = np.reshape(f0, (f0.size, 1))
41
-
42
- vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
43
- vuv_vector[data > 0.0] = 1.0
44
- vuv_vector[data <= 0.0] = 0.0
45
-
46
- ip_data = data
47
-
48
- frame_number = data.size
49
- last_value = 0.0
50
- for i in range(frame_number):
51
- if data[i] <= 0.0:
52
- j = i + 1
53
- for j in range(i + 1, frame_number):
54
- if data[j] > 0.0:
55
- break
56
- if j < frame_number - 1:
57
- if last_value > 0.0:
58
- step = (data[j] - data[i - 1]) / float(j - i)
59
- for k in range(i, j):
60
- ip_data[k] = data[i - 1] + step * (k - i + 1)
61
- else:
62
- for k in range(i, j):
63
- ip_data[k] = data[j]
64
- else:
65
- for k in range(i, frame_number):
66
- ip_data[k] = last_value
67
- else:
68
- ip_data[i] = data[i]
69
- last_value = data[i]
70
-
71
- return ip_data[:, 0], vuv_vector[:, 0]
72
-
73
-
74
- def compute_f0_parselmouth(
75
- wav_numpy: ndarray[Any, dtype[float32]],
76
- p_len: None | int = None,
77
- sampling_rate: int = 44100,
78
- hop_length: int = 512,
79
- ):
80
- import parselmouth
81
-
82
- x = wav_numpy
83
- if p_len is None:
84
- p_len = x.shape[0] // hop_length
85
- else:
86
- assert abs(p_len - x.shape[0] // hop_length) < 4, "pad length error"
87
- time_step = hop_length / sampling_rate * 1000
88
- f0_min = 50
89
- f0_max = 1100
90
- f0 = (
91
- parselmouth.Sound(x, sampling_rate)
92
- .to_pitch_ac(
93
- time_step=time_step / 1000,
94
- voicing_threshold=0.6,
95
- pitch_floor=f0_min,
96
- pitch_ceiling=f0_max,
97
- )
98
- .selected_array["frequency"]
99
- )
100
-
101
- pad_size = (p_len - len(f0) + 1) // 2
102
- if pad_size > 0 or p_len - len(f0) - pad_size > 0:
103
- f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
104
- return f0
105
-
106
-
107
- def _resize_f0(
108
- x: ndarray[Any, dtype[float32]], target_len: int
109
- ) -> ndarray[Any, dtype[float32]]:
110
- source = np.array(x)
111
- source[source < 0.001] = np.nan
112
- target = np.interp(
113
- np.arange(0, len(source) * target_len, len(source)) / target_len,
114
- np.arange(0, len(source)),
115
- source,
116
- )
117
- res = np.nan_to_num(target)
118
- return res
119
-
120
-
121
- def compute_f0_pyworld(
122
- wav_numpy: ndarray[Any, dtype[float32]],
123
- p_len: None | int = None,
124
- sampling_rate: int = 44100,
125
- hop_length: int = 512,
126
- type_: Literal["dio", "harvest"] = "dio",
127
- ):
128
- import pyworld
129
-
130
- if p_len is None:
131
- p_len = wav_numpy.shape[0] // hop_length
132
- if type_ == "dio":
133
- f0, t = pyworld.dio(
134
- wav_numpy.astype(np.double),
135
- fs=sampling_rate,
136
- f0_ceil=f0_max,
137
- f0_floor=f0_min,
138
- frame_period=1000 * hop_length / sampling_rate,
139
- )
140
- elif type_ == "harvest":
141
- f0, t = pyworld.harvest(
142
- wav_numpy.astype(np.double),
143
- fs=sampling_rate,
144
- f0_ceil=f0_max,
145
- f0_floor=f0_min,
146
- frame_period=1000 * hop_length / sampling_rate,
147
- )
148
- f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
149
- for index, pitch in enumerate(f0):
150
- f0[index] = round(pitch, 1)
151
- return _resize_f0(f0, p_len)
152
-
153
-
154
- def compute_f0_crepe(
155
- wav_numpy: ndarray[Any, dtype[float32]],
156
- p_len: None | int = None,
157
- sampling_rate: int = 44100,
158
- hop_length: int = 512,
159
- device: str | torch.device = get_optimal_device(),
160
- model: Literal["full", "tiny"] = "full",
161
- ):
162
- audio = torch.from_numpy(wav_numpy).to(device, copy=True)
163
- audio = torch.unsqueeze(audio, dim=0)
164
-
165
- if audio.ndim == 2 and audio.shape[0] > 1:
166
- audio = torch.mean(audio, dim=0, keepdim=True).detach()
167
- # (T) -> (1, T)
168
- audio = audio.detach()
169
-
170
- pitch: Tensor = torchcrepe.predict(
171
- audio,
172
- sampling_rate,
173
- hop_length,
174
- f0_min,
175
- f0_max,
176
- model,
177
- batch_size=hop_length * 2,
178
- device=device,
179
- pad=True,
180
- )
181
-
182
- f0 = pitch.squeeze(0).cpu().float().numpy()
183
- p_len = p_len or wav_numpy.shape[0] // hop_length
184
- f0 = _resize_f0(f0, p_len)
185
- return f0
186
-
187
-
188
- def compute_f0(
189
- wav_numpy: ndarray[Any, dtype[float32]],
190
- p_len: None | int = None,
191
- sampling_rate: int = 44100,
192
- hop_length: int = 512,
193
- method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
194
- **kwargs,
195
- ):
196
- with timer() as t:
197
- wav_numpy = wav_numpy.astype(np.float32)
198
- wav_numpy /= np.quantile(np.abs(wav_numpy), 0.999)
199
- if method in ["dio", "harvest"]:
200
- f0 = compute_f0_pyworld(wav_numpy, p_len, sampling_rate, hop_length, method)
201
- elif method == "crepe":
202
- f0 = compute_f0_crepe(wav_numpy, p_len, sampling_rate, hop_length, **kwargs)
203
- elif method == "crepe-tiny":
204
- f0 = compute_f0_crepe(
205
- wav_numpy, p_len, sampling_rate, hop_length, model="tiny", **kwargs
206
- )
207
- elif method == "parselmouth":
208
- f0 = compute_f0_parselmouth(wav_numpy, p_len, sampling_rate, hop_length)
209
- else:
210
- raise ValueError(
211
- "type must be dio, crepe, crepe-tiny, harvest or parselmouth"
212
- )
213
- rtf = t.elapsed / (len(wav_numpy) / sampling_rate)
214
- LOG.info(f"F0 inference time: {t.elapsed:.3f}s, RTF: {rtf:.3f}")
215
- return f0
216
-
217
-
218
- def f0_to_coarse(f0: torch.Tensor | float):
219
- is_torch = isinstance(f0, torch.Tensor)
220
- f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
221
- f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (
222
- f0_mel_max - f0_mel_min
223
- ) + 1
224
-
225
- f0_mel[f0_mel <= 1] = 1
226
- f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
227
- f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int)
228
- assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (
229
- f0_coarse.max(),
230
- f0_coarse.min(),
231
- )
232
- return f0_coarse
233
-
234
-
235
- f0_bin = 256
236
- f0_max = 1100.0
237
- f0_min = 50.0
238
- f0_mel_min = 1127 * np.log(1 + f0_min / 700)
239
- f0_mel_max = 1127 * np.log(1 + f0_max / 700)
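interpolate_f0() fills unvoiced (zero) frames and returns a voiced/unvoiced mask, and f0_to_coarse() quantizes Hz values into the 1-255 coarse bins defined by f0_bin / f0_mel_min / f0_mel_max above. A toy sketch against the pre-refactor module layout (the contour values are made up):

```python
# Toy sketch of the f0 helpers removed above; the contour is synthetic.
import numpy as np
import torch
from so_vits_svc_fork.f0 import interpolate_f0, f0_to_coarse

f0 = np.array([0, 0, 220.0, 230.0, 0, 0, 240.0, 0], dtype=np.float32)
f0_interp, vuv = interpolate_f0(f0)    # gaps filled, vuv marks voiced frames
coarse = f0_to_coarse(torch.from_numpy(f0_interp))
print(f0_interp, vuv, coarse)
```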
 
so_vits_svc_fork/gui.py DELETED
@@ -1,851 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import multiprocessing
5
- import os
6
- from copy import copy
7
- from logging import getLogger
8
- from pathlib import Path
9
-
10
- import PySimpleGUI as sg
11
- import sounddevice as sd
12
- import soundfile as sf
13
- import torch
14
- from pebble import ProcessFuture, ProcessPool
15
-
16
- from . import __version__
17
- from .utils import get_optimal_device
18
-
19
- GUI_DEFAULT_PRESETS_PATH = Path(__file__).parent / "default_gui_presets.json"
20
- GUI_PRESETS_PATH = Path("./user_gui_presets.json").absolute()
21
-
22
- LOG = getLogger(__name__)
23
-
24
-
25
- def play_audio(path: Path | str):
26
- if isinstance(path, Path):
27
- path = path.as_posix()
28
- data, sr = sf.read(path)
29
- sd.play(data, sr)
30
-
31
-
32
- def load_presets() -> dict:
33
- defaults = json.loads(GUI_DEFAULT_PRESETS_PATH.read_text("utf-8"))
34
- users = (
35
- json.loads(GUI_PRESETS_PATH.read_text("utf-8"))
36
- if GUI_PRESETS_PATH.exists()
37
- else {}
38
- )
39
- # priority: defaults > users
40
- # order: defaults -> users
41
- return {**defaults, **users, **defaults}
42
-
43
-
44
- def add_preset(name: str, preset: dict) -> dict:
45
- presets = load_presets()
46
- presets[name] = preset
47
- with GUI_PRESETS_PATH.open("w") as f:
48
- json.dump(presets, f, indent=2)
49
- return load_presets()
50
-
51
-
52
- def delete_preset(name: str) -> dict:
53
- presets = load_presets()
54
- if name in presets:
55
- del presets[name]
56
- else:
57
- LOG.warning(f"Cannot delete preset {name} because it does not exist.")
58
- with GUI_PRESETS_PATH.open("w") as f:
59
- json.dump(presets, f, indent=2)
60
- return load_presets()
61
-
62
-
63
- def get_output_path(input_path: Path) -> Path:
64
- # Default output path
65
- output_path = input_path.parent / f"{input_path.stem}.out{input_path.suffix}"
66
-
67
- # Increment file number in path if output file already exists
68
- file_num = 1
69
- while output_path.exists():
70
- output_path = (
71
- input_path.parent / f"{input_path.stem}.out_{file_num}{input_path.suffix}"
72
- )
73
- file_num += 1
74
- return output_path
75
-
76
-
77
- def get_supported_file_types() -> tuple[tuple[str, str], ...]:
78
- res = tuple(
79
- [
80
- (extension, f".{extension.lower()}")
81
- for extension in sf.available_formats().keys()
82
- ]
83
- )
84
-
85
- # Sort by popularity
86
- common_file_types = ["WAV", "MP3", "FLAC", "OGG", "M4A", "WMA"]
87
- res = sorted(
88
- res,
89
- key=lambda x: common_file_types.index(x[0])
90
- if x[0] in common_file_types
91
- else len(common_file_types),
92
- )
93
- return res
94
-
95
-
96
- def get_supported_file_types_concat() -> tuple[tuple[str, str], ...]:
97
- return (("Audio", " ".join(sf.available_formats().keys())),)
98
-
99
-
100
- def validate_output_file_type(output_path: Path) -> bool:
101
- supported_file_types = sorted(
102
- [f".{extension.lower()}" for extension in sf.available_formats().keys()]
103
- )
104
- if not output_path.suffix:
105
- sg.popup_ok(
106
- "Error: Output path missing file type extension, enter "
107
- + "one of the following manually:\n\n"
108
- + "\n".join(supported_file_types)
109
- )
110
- return False
111
- if output_path.suffix.lower() not in supported_file_types:
112
- sg.popup_ok(
113
- f"Error: {output_path.suffix.lower()} is not a supported "
114
- + "extension; use one of the following:\n\n"
115
- + "\n".join(supported_file_types)
116
- )
117
- return False
118
- return True
119
-
120
-
121
- def get_devices(
122
- update: bool = True,
123
- ) -> tuple[list[str], list[str], list[int], list[int]]:
124
- if update:
125
- sd._terminate()
126
- sd._initialize()
127
- devices = sd.query_devices()
128
- hostapis = sd.query_hostapis()
129
- for hostapi in hostapis:
130
- for device_idx in hostapi["devices"]:
131
- devices[device_idx]["hostapi_name"] = hostapi["name"]
132
- input_devices = [
133
- f"{d['name']} ({d['hostapi_name']})"
134
- for d in devices
135
- if d["max_input_channels"] > 0
136
- ]
137
- output_devices = [
138
- f"{d['name']} ({d['hostapi_name']})"
139
- for d in devices
140
- if d["max_output_channels"] > 0
141
- ]
142
- input_devices_indices = [d["index"] for d in devices if d["max_input_channels"] > 0]
143
- output_devices_indices = [
144
- d["index"] for d in devices if d["max_output_channels"] > 0
145
- ]
146
- return input_devices, output_devices, input_devices_indices, output_devices_indices
147
-
148
-
149
- def after_inference(window: sg.Window, path: Path, auto_play: bool, output_path: Path):
150
- try:
151
- LOG.info(f"Finished inference for {path.stem}{path.suffix}")
152
- window["infer"].update(disabled=False)
153
-
154
- if auto_play:
155
- play_audio(output_path)
156
- except Exception as e:
157
- LOG.exception(e)
158
-
159
-
160
- def main():
161
- LOG.info(f"version: {__version__}")
162
-
163
- # sg.theme("Dark")
164
- sg.theme_add_new(
165
- "Very Dark",
166
- {
167
- "BACKGROUND": "#111111",
168
- "TEXT": "#FFFFFF",
169
- "INPUT": "#444444",
170
- "TEXT_INPUT": "#FFFFFF",
171
- "SCROLL": "#333333",
172
- "BUTTON": ("white", "#112233"),
173
- "PROGRESS": ("#111111", "#333333"),
174
- "BORDER": 2,
175
- "SLIDER_DEPTH": 2,
176
- "PROGRESS_DEPTH": 2,
177
- },
178
- )
179
- sg.theme("Very Dark")
180
-
181
- model_candidates = list(sorted(Path("./logs/44k/").glob("G_*.pth")))
182
-
183
- frame_contents = {
184
- "Paths": [
185
- [
186
- sg.Text("Model path"),
187
- sg.Push(),
188
- sg.InputText(
189
- key="model_path",
190
- default_text=model_candidates[-1].absolute().as_posix()
191
- if model_candidates
192
- else "",
193
- enable_events=True,
194
- ),
195
- sg.FileBrowse(
196
- initial_folder=Path("./logs/44k/").absolute().as_posix()
197
- if Path("./logs/44k/").exists()
198
- else Path(".").absolute().as_posix(),
199
- key="model_path_browse",
200
- file_types=(
201
- ("PyTorch", "G_*.pth G_*.pt"),
202
- ("Pytorch", "*.pth *.pt"),
203
- ),
204
- ),
205
- ],
206
- [
207
- sg.Text("Config path"),
208
- sg.Push(),
209
- sg.InputText(
210
- key="config_path",
211
- default_text=Path("./configs/44k/config.json").absolute().as_posix()
212
- if Path("./configs/44k/config.json").exists()
213
- else "",
214
- enable_events=True,
215
- ),
216
- sg.FileBrowse(
217
- initial_folder=Path("./configs/44k/").as_posix()
218
- if Path("./configs/44k/").exists()
219
- else Path(".").absolute().as_posix(),
220
- key="config_path_browse",
221
- file_types=(("JSON", "*.json"),),
222
- ),
223
- ],
224
- [
225
- sg.Text("Cluster model path (Optional)"),
226
- sg.Push(),
227
- sg.InputText(
228
- key="cluster_model_path",
229
- default_text=Path("./logs/44k/kmeans.pt").absolute().as_posix()
230
- if Path("./logs/44k/kmeans.pt").exists()
231
- else "",
232
- enable_events=True,
233
- ),
234
- sg.FileBrowse(
235
- initial_folder="./logs/44k/"
236
- if Path("./logs/44k/").exists()
237
- else ".",
238
- key="cluster_model_path_browse",
239
- file_types=(("PyTorch", "*.pt"), ("Pickle", "*.pt *.pth *.pkl")),
240
- ),
241
- ],
242
- ],
243
- "Common": [
244
- [
245
- sg.Text("Speaker"),
246
- sg.Push(),
247
- sg.Combo(values=[], key="speaker", size=(20, 1)),
248
- ],
249
- [
250
- sg.Text("Silence threshold"),
251
- sg.Push(),
252
- sg.Slider(
253
- range=(-60.0, 0),
254
- orientation="h",
255
- key="silence_threshold",
256
- resolution=0.1,
257
- ),
258
- ],
259
- [
260
- sg.Text(
261
- "Pitch (12 = 1 octave)\n"
262
- "ADJUST THIS based on your voice\n"
263
- "when Auto predict F0 is turned off.",
264
- size=(None, 4),
265
- ),
266
- sg.Push(),
267
- sg.Slider(
268
- range=(-36, 36),
269
- orientation="h",
270
- key="transpose",
271
- tick_interval=12,
272
- ),
273
- ],
274
- [
275
- sg.Checkbox(
276
- key="auto_predict_f0",
277
- text="Auto predict F0 (Pitch may become unstable when turned on in real-time inference.)",
278
- )
279
- ],
280
- [
281
- sg.Text("F0 prediction method"),
282
- sg.Push(),
283
- sg.Combo(
284
- ["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
285
- key="f0_method",
286
- ),
287
- ],
288
- [
289
- sg.Text("Cluster infer ratio"),
290
- sg.Push(),
291
- sg.Slider(
292
- range=(0, 1.0),
293
- orientation="h",
294
- key="cluster_infer_ratio",
295
- resolution=0.01,
296
- ),
297
- ],
298
- [
299
- sg.Text("Noise scale"),
300
- sg.Push(),
301
- sg.Slider(
302
- range=(0.0, 1.0),
303
- orientation="h",
304
- key="noise_scale",
305
- resolution=0.01,
306
- ),
307
- ],
308
- [
309
- sg.Text("Pad seconds"),
310
- sg.Push(),
311
- sg.Slider(
312
- range=(0.0, 1.0),
313
- orientation="h",
314
- key="pad_seconds",
315
- resolution=0.01,
316
- ),
317
- ],
318
- [
319
- sg.Text("Chunk seconds"),
320
- sg.Push(),
321
- sg.Slider(
322
- range=(0.0, 3.0),
323
- orientation="h",
324
- key="chunk_seconds",
325
- resolution=0.01,
326
- ),
327
- ],
328
- [
329
- sg.Text("Max chunk seconds (set lower if Out Of Memory, 0 to disable)"),
330
- sg.Push(),
331
- sg.Slider(
332
- range=(0.0, 240.0),
333
- orientation="h",
334
- key="max_chunk_seconds",
335
- resolution=1.0,
336
- ),
337
- ],
338
- [
339
- sg.Checkbox(
340
- key="absolute_thresh",
341
- text="Absolute threshold (ignored (True) in realtime inference)",
342
- )
343
- ],
344
- ],
345
- "File": [
346
- [
347
- sg.Text("Input audio path"),
348
- sg.Push(),
349
- sg.InputText(key="input_path", enable_events=True),
350
- sg.FileBrowse(
351
- initial_folder=".",
352
- key="input_path_browse",
353
- file_types=get_supported_file_types_concat(),
354
- ),
355
- sg.FolderBrowse(
356
- button_text="Browse(Folder)",
357
- initial_folder=".",
358
- key="input_path_folder_browse",
359
- target="input_path",
360
- ),
361
- sg.Button("Play", key="play_input"),
362
- ],
363
- [
364
- sg.Text("Output audio path"),
365
- sg.Push(),
366
- sg.InputText(key="output_path"),
367
- sg.FileSaveAs(
368
- initial_folder=".",
369
- key="output_path_browse",
370
- file_types=get_supported_file_types(),
371
- ),
372
- ],
373
- [sg.Checkbox(key="auto_play", text="Auto play", default=True)],
374
- ],
375
- "Realtime": [
376
- [
377
- sg.Text("Crossfade seconds"),
378
- sg.Push(),
379
- sg.Slider(
380
- range=(0, 0.6),
381
- orientation="h",
382
- key="crossfade_seconds",
383
- resolution=0.001,
384
- ),
385
- ],
386
- [
387
- sg.Text(
388
- "Block seconds", # \n(big -> more robust, slower, (the same) latency)"
389
- tooltip="Big -> more robust, slower, (the same) latency",
390
- ),
391
- sg.Push(),
392
- sg.Slider(
393
- range=(0, 3.0),
394
- orientation="h",
395
- key="block_seconds",
396
- resolution=0.001,
397
- ),
398
- ],
399
- [
400
- sg.Text(
401
- "Additional Infer seconds (before)", # \n(big -> more robust, slower)"
402
- tooltip="Big -> more robust, slower, additional latency",
403
- ),
404
- sg.Push(),
405
- sg.Slider(
406
- range=(0, 2.0),
407
- orientation="h",
408
- key="additional_infer_before_seconds",
409
- resolution=0.001,
410
- ),
411
- ],
412
- [
413
- sg.Text(
414
- "Additional Infer seconds (after)", # \n(big -> more robust, slower, additional latency)"
415
- tooltip="Big -> more robust, slower, additional latency",
416
- ),
417
- sg.Push(),
418
- sg.Slider(
419
- range=(0, 2.0),
420
- orientation="h",
421
- key="additional_infer_after_seconds",
422
- resolution=0.001,
423
- ),
424
- ],
425
- [
426
- sg.Text("Realtime algorithm"),
427
- sg.Push(),
428
- sg.Combo(
429
- ["2 (Divide by speech)", "1 (Divide constantly)"],
430
- default_value="1 (Divide constantly)",
431
- key="realtime_algorithm",
432
- ),
433
- ],
434
- [
435
- sg.Text("Input device"),
436
- sg.Push(),
437
- sg.Combo(
438
- key="input_device",
439
- values=[],
440
- size=(60, 1),
441
- ),
442
- ],
443
- [
444
- sg.Text("Output device"),
445
- sg.Push(),
446
- sg.Combo(
447
- key="output_device",
448
- values=[],
449
- size=(60, 1),
450
- ),
451
- ],
452
- [
453
- sg.Checkbox(
454
- "Passthrough original audio (for latency check)",
455
- key="passthrough_original",
456
- default=False,
457
- ),
458
- sg.Push(),
459
- sg.Button("Refresh devices", key="refresh_devices"),
460
- ],
461
- [
462
- sg.Frame(
463
- "Notes",
464
- [
465
- [
466
- sg.Text(
467
- "In Realtime Inference:\n"
468
- " - Setting F0 prediction method to 'crepe` may cause performance degradation.\n"
469
- " - Auto Predict F0 must be turned off.\n"
470
- "If the audio sounds mumbly and choppy:\n"
471
- " Case: The inference has not been made in time (Increase Block seconds)\n"
472
- " Case: Mic input is low (Decrease Silence threshold)\n"
473
- )
474
- ]
475
- ],
476
- ),
477
- ],
478
- ],
479
- "Presets": [
480
- [
481
- sg.Text("Presets"),
482
- sg.Push(),
483
- sg.Combo(
484
- key="presets",
485
- values=list(load_presets().keys()),
486
- size=(40, 1),
487
- enable_events=True,
488
- ),
489
- sg.Button("Delete preset", key="delete_preset"),
490
- ],
491
- [
492
- sg.Text("Preset name"),
493
- sg.Stretch(),
494
- sg.InputText(key="preset_name", size=(26, 1)),
495
- sg.Button("Add current settings as a preset", key="add_preset"),
496
- ],
497
- ],
498
- }
499
-
500
- # frames
501
- frames = {}
502
- for name, items in frame_contents.items():
503
- frame = sg.Frame(name, items)
504
- frame.expand_x = True
505
- frames[name] = [frame]
506
-
507
- bottoms = [
508
- [
509
- sg.Checkbox(
510
- key="use_gpu",
511
- default=get_optimal_device() != torch.device("cpu"),
512
- text="Use GPU"
513
- + (
514
- " (not available; if your device has GPU, make sure you installed PyTorch with CUDA support)"
515
- if get_optimal_device() == torch.device("cpu")
516
- else ""
517
- ),
518
- disabled=get_optimal_device() == torch.device("cpu"),
519
- )
520
- ],
521
- [
522
- sg.Button("Infer", key="infer"),
523
- sg.Button("(Re)Start Voice Changer", key="start_vc"),
524
- sg.Button("Stop Voice Changer", key="stop_vc"),
525
- sg.Push(),
526
- # sg.Button("ONNX Export", key="onnx_export"),
527
- ],
528
- ]
529
- column1 = sg.Column(
530
- [
531
- frames["Paths"],
532
- frames["Common"],
533
- ],
534
- vertical_alignment="top",
535
- )
536
- column2 = sg.Column(
537
- [
538
- frames["File"],
539
- frames["Realtime"],
540
- frames["Presets"],
541
- ]
542
- + bottoms
543
- )
544
- # columns
545
- layout = [[column1, column2]]
546
- # get screen size
547
- screen_width, screen_height = sg.Window.get_screen_size()
548
- if screen_height < 720:
549
- layout = [
550
- [
551
- sg.Column(
552
- layout,
553
- vertical_alignment="top",
554
- scrollable=False,
555
- expand_x=True,
556
- expand_y=True,
557
- vertical_scroll_only=True,
558
- key="main_column",
559
- )
560
- ]
561
- ]
562
- window = sg.Window(
563
- f"{__name__.split('.')[0].replace('_', '-')} v{__version__}",
564
- layout,
565
- grab_anywhere=True,
566
- finalize=True,
567
- scaling=1,
568
- font=("Yu Gothic UI", 11) if os.name == "nt" else None,
569
- # resizable=True,
570
- # size=(1280, 720),
571
- # Below disables taskbar, which may be not useful for some users
572
- # use_custom_titlebar=True, no_titlebar=False
573
- # Keep on top
574
- # keep_on_top=True
575
- )
576
-
577
- # event, values = window.read(timeout=0.01)
578
- # window["main_column"].Scrollable = True
579
-
580
- # make slider height smaller
581
- try:
582
- for v in window.element_list():
583
- if isinstance(v, sg.Slider):
584
- v.Widget.configure(sliderrelief="flat", width=10, sliderlength=20)
585
- except Exception as e:
586
- LOG.exception(e)
587
-
588
- # for n in ["input_device", "output_device"]:
589
- # window[n].Widget.configure(justify="right")
590
- event, values = window.read(timeout=0.01)
591
-
592
- def update_speaker() -> None:
593
- from . import utils
594
-
595
- config_path = Path(values["config_path"])
596
- if config_path.exists() and config_path.is_file():
597
- hp = utils.get_hparams(values["config_path"])
598
- LOG.debug(f"Loaded config from {values['config_path']}")
599
- window["speaker"].update(
600
- values=list(hp.__dict__["spk"].keys()), set_to_index=0
601
- )
602
-
603
- def update_devices() -> None:
604
- (
605
- input_devices,
606
- output_devices,
607
- input_device_indices,
608
- output_device_indices,
609
- ) = get_devices()
610
- input_device_indices_reversed = {
611
- v: k for k, v in enumerate(input_device_indices)
612
- }
613
- output_device_indices_reversed = {
614
- v: k for k, v in enumerate(output_device_indices)
615
- }
616
- window["input_device"].update(
617
- values=input_devices, value=values["input_device"]
618
- )
619
- window["output_device"].update(
620
- values=output_devices, value=values["output_device"]
621
- )
622
- input_default, output_default = sd.default.device
623
- if values["input_device"] not in input_devices:
624
- window["input_device"].update(
625
- values=input_devices,
626
- set_to_index=input_device_indices_reversed.get(input_default, 0),
627
- )
628
- if values["output_device"] not in output_devices:
629
- window["output_device"].update(
630
- values=output_devices,
631
- set_to_index=output_device_indices_reversed.get(output_default, 0),
632
- )
633
-
634
- PRESET_KEYS = [
635
- key
636
- for key in values.keys()
637
- if not any(exclude in key for exclude in ["preset", "browse"])
638
- ]
639
-
640
- def apply_preset(name: str) -> None:
641
- for key, value in load_presets()[name].items():
642
- if key in PRESET_KEYS:
643
- window[key].update(value)
644
- values[key] = value
645
-
646
- default_name = list(load_presets().keys())[0]
647
- apply_preset(default_name)
648
- window["presets"].update(default_name)
649
- del default_name
650
- update_speaker()
651
- update_devices()
652
- # with ProcessPool(max_workers=1) as pool:
653
- # to support Linux
654
- with ProcessPool(
655
- max_workers=min(2, multiprocessing.cpu_count()),
656
- context=multiprocessing.get_context("spawn"),
657
- ) as pool:
658
- future: None | ProcessFuture = None
659
- infer_futures: set[ProcessFuture] = set()
660
- while True:
661
- event, values = window.read(200)
662
- if event == sg.WIN_CLOSED:
663
- break
664
- if not event == sg.EVENT_TIMEOUT:
665
- LOG.info(f"Event {event}, values {values}")
666
- if event.endswith("_path"):
667
- for name in window.AllKeysDict:
668
- if str(name).endswith("_browse"):
669
- browser = window[name]
670
- if isinstance(browser, sg.Button):
671
- LOG.info(
672
- f"Updating browser {browser} to {Path(values[event]).parent}"
673
- )
674
- browser.InitialFolder = Path(values[event]).parent
675
- browser.update()
676
- else:
677
- LOG.warning(f"Browser {browser} is not a FileBrowse")
678
- window["transpose"].update(
679
- disabled=values["auto_predict_f0"],
680
- visible=not values["auto_predict_f0"],
681
- )
682
-
683
- input_path = Path(values["input_path"])
684
- output_path = Path(values["output_path"])
685
-
686
- if event == "add_preset":
687
- presets = add_preset(
688
- values["preset_name"], {key: values[key] for key in PRESET_KEYS}
689
- )
690
- window["presets"].update(values=list(presets.keys()))
691
- elif event == "delete_preset":
692
- presets = delete_preset(values["presets"])
693
- window["presets"].update(values=list(presets.keys()))
694
- elif event == "presets":
695
- apply_preset(values["presets"])
696
- update_speaker()
697
- elif event == "refresh_devices":
698
- update_devices()
699
- elif event == "config_path":
700
- update_speaker()
701
- elif event == "input_path":
702
- # Don't change the output path if it's already set
703
- # if values["output_path"]:
704
- # continue
705
- # Set a sensible default output path
706
- window.Element("output_path").Update(str(get_output_path(input_path)))
707
- elif event == "infer":
708
- if "Default VC" in values["presets"]:
709
- window["presets"].update(
710
- set_to_index=list(load_presets().keys()).index("Default File")
711
- )
712
- apply_preset("Default File")
713
- if values["input_path"] == "":
714
- LOG.warning("Input path is empty.")
715
- continue
716
- if not input_path.exists():
717
- LOG.warning(f"Input path {input_path} does not exist.")
718
- continue
719
- # if not validate_output_file_type(output_path):
720
- # continue
721
-
722
- try:
723
- from so_vits_svc_fork.inference.main import infer
724
-
725
- LOG.info("Starting inference...")
726
- window["infer"].update(disabled=True)
727
- infer_future = pool.schedule(
728
- infer,
729
- kwargs=dict(
730
- # paths
731
- model_path=Path(values["model_path"]),
732
- output_path=output_path,
733
- input_path=input_path,
734
- config_path=Path(values["config_path"]),
735
- recursive=True,
736
- # svc config
737
- speaker=values["speaker"],
738
- cluster_model_path=Path(values["cluster_model_path"])
739
- if values["cluster_model_path"]
740
- else None,
741
- transpose=values["transpose"],
742
- auto_predict_f0=values["auto_predict_f0"],
743
- cluster_infer_ratio=values["cluster_infer_ratio"],
744
- noise_scale=values["noise_scale"],
745
- f0_method=values["f0_method"],
746
- # slice config
747
- db_thresh=values["silence_threshold"],
748
- pad_seconds=values["pad_seconds"],
749
- chunk_seconds=values["chunk_seconds"],
750
- absolute_thresh=values["absolute_thresh"],
751
- max_chunk_seconds=values["max_chunk_seconds"],
752
- device="cpu"
753
- if not values["use_gpu"]
754
- else get_optimal_device(),
755
- ),
756
- )
757
- infer_future.add_done_callback(
758
- lambda _future: after_inference(
759
- window, input_path, values["auto_play"], output_path
760
- )
761
- )
762
- infer_futures.add(infer_future)
763
- except Exception as e:
764
- LOG.exception(e)
765
- elif event == "play_input":
766
- if Path(values["input_path"]).exists():
767
- pool.schedule(play_audio, args=[Path(values["input_path"])])
768
- elif event == "start_vc":
769
- _, _, input_device_indices, output_device_indices = get_devices(
770
- update=False
771
- )
772
- from so_vits_svc_fork.inference.main import realtime
773
-
774
- if future:
775
- LOG.info("Canceling previous task")
776
- future.cancel()
777
- future = pool.schedule(
778
- realtime,
779
- kwargs=dict(
780
- # paths
781
- model_path=Path(values["model_path"]),
782
- config_path=Path(values["config_path"]),
783
- speaker=values["speaker"],
784
- # svc config
785
- cluster_model_path=Path(values["cluster_model_path"])
786
- if values["cluster_model_path"]
787
- else None,
788
- transpose=values["transpose"],
789
- auto_predict_f0=values["auto_predict_f0"],
790
- cluster_infer_ratio=values["cluster_infer_ratio"],
791
- noise_scale=values["noise_scale"],
792
- f0_method=values["f0_method"],
793
- # slice config
794
- db_thresh=values["silence_threshold"],
795
- pad_seconds=values["pad_seconds"],
796
- chunk_seconds=values["chunk_seconds"],
797
- # realtime config
798
- crossfade_seconds=values["crossfade_seconds"],
799
- additional_infer_before_seconds=values[
800
- "additional_infer_before_seconds"
801
- ],
802
- additional_infer_after_seconds=values[
803
- "additional_infer_after_seconds"
804
- ],
805
- block_seconds=values["block_seconds"],
806
- version=int(values["realtime_algorithm"][0]),
807
- input_device=input_device_indices[
808
- window["input_device"].widget.current()
809
- ],
810
- output_device=output_device_indices[
811
- window["output_device"].widget.current()
812
- ],
813
- device=get_optimal_device() if values["use_gpu"] else "cpu",
814
- passthrough_original=values["passthrough_original"],
815
- ),
816
- )
817
- elif event == "stop_vc":
818
- if future:
819
- future.cancel()
820
- future = None
821
- elif event == "onnx_export":
822
- try:
823
- raise NotImplementedError("ONNX export is not implemented yet.")
824
- from so_vits_svc_fork.modules.onnx._export import onnx_export
825
-
826
- onnx_export(
827
- input_path=Path(values["model_path"]),
828
- output_path=Path(values["model_path"]).with_suffix(".onnx"),
829
- config_path=Path(values["config_path"]),
830
- device="cpu",
831
- )
832
- except Exception as e:
833
- LOG.exception(e)
834
- if future is not None and future.done():
835
- try:
836
- future.result()
837
- except Exception as e:
838
- LOG.error("Error in realtime: ")
839
- LOG.exception(e)
840
- future = None
841
- for future in copy(infer_futures):
842
- if future.done():
843
- try:
844
- future.result()
845
- except Exception as e:
846
- LOG.error("Error in inference: ")
847
- LOG.exception(e)
848
- infer_futures.remove(future)
849
- if future:
850
- future.cancel()
851
- window.close()
 
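For context, a quick sketch of how the get_output_path helper deleted above derives default output names; the input filename below is a hypothetical placeholder:

from pathlib import Path

# Placeholder file name; prints "song.out.wav", or "song.out_1.wav" if that file already exists.
print(get_output_path(Path("song.wav")))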
so_vits_svc_fork/hparams.py DELETED
@@ -1,38 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any
4
-
5
-
6
- class HParams:
7
- def __init__(self, **kwargs: Any) -> None:
8
- for k, v in kwargs.items():
9
- if type(v) == dict:
10
- v = HParams(**v)
11
- self[k] = v
12
-
13
- def keys(self):
14
- return self.__dict__.keys()
15
-
16
- def items(self):
17
- return self.__dict__.items()
18
-
19
- def values(self):
20
- return self.__dict__.values()
21
-
22
- def get(self, key: str, default: Any = None):
23
- return self.__dict__.get(key, default)
24
-
25
- def __len__(self):
26
- return len(self.__dict__)
27
-
28
- def __getitem__(self, key):
29
- return getattr(self, key)
30
-
31
- def __setitem__(self, key, value):
32
- return setattr(self, key, value)
33
-
34
- def __contains__(self, key):
35
- return key in self.__dict__
36
-
37
- def __repr__(self):
38
- return self.__dict__.__repr__()
 
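For context, a minimal usage sketch of the HParams wrapper deleted above; the config keys and values are hypothetical and only illustrate the nested dict-to-attribute conversion:

# Hypothetical config values; nested dicts are wrapped recursively by HParams.
hp = HParams(data={"sampling_rate": 44100, "hop_length": 512}, spk={"speaker0": 0})
assert hp.data.sampling_rate == 44100  # attribute access on a nested section
assert hp["spk"].get("speaker0") == 0  # item access and .get() are also supported
assert "data" in hp and len(hp) == 2   # __contains__ and __len__ delegate to __dict__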
so_vits_svc_fork/inference/__init__.py DELETED
File without changes
so_vits_svc_fork/inference/core.py DELETED
@@ -1,692 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from copy import deepcopy
4
- from logging import getLogger
5
- from pathlib import Path
6
- from typing import Any, Callable, Iterable, Literal
7
-
8
- import attrs
9
- import librosa
10
- import numpy as np
11
- import torch
12
- from cm_time import timer
13
- from numpy import dtype, float32, ndarray
14
-
15
- import so_vits_svc_fork.f0
16
- from so_vits_svc_fork import cluster, utils
17
-
18
- from ..modules.synthesizers import SynthesizerTrn
19
- from ..utils import get_optimal_device
20
-
21
- LOG = getLogger(__name__)
22
-
23
-
24
- def pad_array(array_, target_length: int):
25
- current_length = array_.shape[0]
26
- if current_length >= target_length:
27
- return array_[
28
- (current_length - target_length)
29
- // 2 : (current_length - target_length)
30
- // 2
31
- + target_length,
32
- ...,
33
- ]
34
- else:
35
- pad_width = target_length - current_length
36
- pad_left = pad_width // 2
37
- pad_right = pad_width - pad_left
38
- padded_arr = np.pad(
39
- array_, (pad_left, pad_right), "constant", constant_values=(0, 0)
40
- )
41
- return padded_arr
42
-
43
-
44
- @attrs.frozen(kw_only=True)
45
- class Chunk:
46
- is_speech: bool
47
- audio: ndarray[Any, dtype[float32]]
48
- start: int
49
- end: int
50
-
51
- @property
52
- def duration(self) -> float32:
53
- # return self.end - self.start
54
- return float32(self.audio.shape[0])
55
-
56
- def __repr__(self) -> str:
57
- return f"Chunk(Speech: {self.is_speech}, {self.duration})"
58
-
59
-
60
- def split_silence(
61
- audio: ndarray[Any, dtype[float32]],
62
- top_db: int = 40,
63
- ref: float | Callable[[ndarray[Any, dtype[float32]]], float] = 1,
64
- frame_length: int = 2048,
65
- hop_length: int = 512,
66
- aggregate: Callable[[ndarray[Any, dtype[float32]]], float] = np.mean,
67
- max_chunk_length: int = 0,
68
- ) -> Iterable[Chunk]:
69
- non_silence_indices = librosa.effects.split(
70
- audio,
71
- top_db=top_db,
72
- ref=ref,
73
- frame_length=frame_length,
74
- hop_length=hop_length,
75
- aggregate=aggregate,
76
- )
77
- last_end = 0
78
- for start, end in non_silence_indices:
79
- if start != last_end:
80
- yield Chunk(
81
- is_speech=False, audio=audio[last_end:start], start=last_end, end=start
82
- )
83
- while max_chunk_length > 0 and end - start > max_chunk_length:
84
- yield Chunk(
85
- is_speech=True,
86
- audio=audio[start : start + max_chunk_length],
87
- start=start,
88
- end=start + max_chunk_length,
89
- )
90
- start += max_chunk_length
91
- if end - start > 0:
92
- yield Chunk(is_speech=True, audio=audio[start:end], start=start, end=end)
93
- last_end = end
94
- if last_end != len(audio):
95
- yield Chunk(
96
- is_speech=False, audio=audio[last_end:], start=last_end, end=len(audio)
97
- )
98
-
99
-
100
- class Svc:
101
- def __init__(
102
- self,
103
- *,
104
- net_g_path: Path | str,
105
- config_path: Path | str,
106
- device: torch.device | str | None = None,
107
- cluster_model_path: Path | str | None = None,
108
- half: bool = False,
109
- ):
110
- self.net_g_path = net_g_path
111
- if device is None:
112
- self.device = get_optimal_device()
113
- else:
114
- self.device = torch.device(device)
115
- self.hps = utils.get_hparams(config_path)
116
- self.target_sample = self.hps.data.sampling_rate
117
- self.hop_size = self.hps.data.hop_length
118
- self.spk2id = self.hps.spk
119
- self.hubert_model = utils.get_hubert_model(
120
- self.device, self.hps.data.get("contentvec_final_proj", True)
121
- )
122
- self.dtype = torch.float16 if half else torch.float32
123
- self.contentvec_final_proj = self.hps.data.__dict__.get(
124
- "contentvec_final_proj", True
125
- )
126
- self.load_model()
127
- if cluster_model_path is not None and Path(cluster_model_path).exists():
128
- self.cluster_model = cluster.get_cluster_model(cluster_model_path)
129
-
130
- def load_model(self):
131
- self.net_g = SynthesizerTrn(
132
- self.hps.data.filter_length // 2 + 1,
133
- self.hps.train.segment_size // self.hps.data.hop_length,
134
- **self.hps.model,
135
- )
136
- _ = utils.load_checkpoint(self.net_g_path, self.net_g, None)
137
- _ = self.net_g.eval()
138
- for m in self.net_g.modules():
139
- utils.remove_weight_norm_if_exists(m)
140
- _ = self.net_g.to(self.device, dtype=self.dtype)
141
- self.net_g = self.net_g
142
-
143
- def get_unit_f0(
144
- self,
145
- audio: ndarray[Any, dtype[float32]],
146
- tran: int,
147
- cluster_infer_ratio: float,
148
- speaker: int | str,
149
- f0_method: Literal[
150
- "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
151
- ] = "dio",
152
- ):
153
- f0 = so_vits_svc_fork.f0.compute_f0(
154
- audio,
155
- sampling_rate=self.target_sample,
156
- hop_length=self.hop_size,
157
- method=f0_method,
158
- )
159
- f0, uv = so_vits_svc_fork.f0.interpolate_f0(f0)
160
- f0 = torch.as_tensor(f0, dtype=self.dtype, device=self.device)
161
- uv = torch.as_tensor(uv, dtype=self.dtype, device=self.device)
162
- f0 = f0 * 2 ** (tran / 12)
163
- f0 = f0.unsqueeze(0)
164
- uv = uv.unsqueeze(0)
165
-
166
- c = utils.get_content(
167
- self.hubert_model,
168
- audio,
169
- self.device,
170
- self.target_sample,
171
- self.contentvec_final_proj,
172
- ).to(self.dtype)
173
- c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
174
-
175
- if cluster_infer_ratio != 0:
176
- cluster_c = cluster.get_cluster_center_result(
177
- self.cluster_model, c.cpu().numpy().T, speaker
178
- ).T
179
- cluster_c = torch.FloatTensor(cluster_c).to(self.device)
180
- c = cluster_infer_ratio * cluster_c + (1 - cluster_infer_ratio) * c
181
-
182
- c = c.unsqueeze(0)
183
- return c, f0, uv
184
-
185
- def infer(
186
- self,
187
- speaker: int | str,
188
- transpose: int,
189
- audio: ndarray[Any, dtype[float32]],
190
- cluster_infer_ratio: float = 0,
191
- auto_predict_f0: bool = False,
192
- noise_scale: float = 0.4,
193
- f0_method: Literal[
194
- "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
195
- ] = "dio",
196
- ) -> tuple[torch.Tensor, int]:
197
- audio = audio.astype(np.float32)
198
- # get speaker id
199
- if isinstance(speaker, int):
200
- if len(self.spk2id.__dict__) >= speaker:
201
- speaker_id = speaker
202
- else:
203
- raise ValueError(
204
- f"Speaker id {speaker} >= number of speakers {len(self.spk2id.__dict__)}"
205
- )
206
- else:
207
- if speaker in self.spk2id.__dict__:
208
- speaker_id = self.spk2id.__dict__[speaker]
209
- else:
210
- LOG.warning(f"Speaker {speaker} is not found. Use speaker 0 instead.")
211
- speaker_id = 0
212
- speaker_candidates = list(
213
- filter(lambda x: x[1] == speaker_id, self.spk2id.__dict__.items())
214
- )
215
- if len(speaker_candidates) > 1:
216
- raise ValueError(
217
- f"Speaker_id {speaker_id} is not unique. Candidates: {speaker_candidates}"
218
- )
219
- elif len(speaker_candidates) == 0:
220
- raise ValueError(f"Speaker_id {speaker_id} is not found.")
221
- speaker = speaker_candidates[0][0]
222
- sid = torch.LongTensor([int(speaker_id)]).to(self.device).unsqueeze(0)
223
-
224
- # get unit f0
225
- c, f0, uv = self.get_unit_f0(
226
- audio, transpose, cluster_infer_ratio, speaker, f0_method
227
- )
228
-
229
- # inference
230
- with torch.no_grad():
231
- with timer() as t:
232
- audio = self.net_g.infer(
233
- c,
234
- f0=f0,
235
- g=sid,
236
- uv=uv,
237
- predict_f0=auto_predict_f0,
238
- noice_scale=noise_scale,
239
- )[0, 0].data.float()
240
- audio_duration = audio.shape[-1] / self.target_sample
241
- LOG.info(
242
- f"Inference time: {t.elapsed:.2f}s, RTF: {t.elapsed / audio_duration:.2f}"
243
- )
244
- torch.cuda.empty_cache()
245
- return audio, audio.shape[-1]
246
-
247
- def infer_silence(
248
- self,
249
- audio: np.ndarray[Any, np.dtype[np.float32]],
250
- *,
251
- # svc config
252
- speaker: int | str,
253
- transpose: int = 0,
254
- auto_predict_f0: bool = False,
255
- cluster_infer_ratio: float = 0,
256
- noise_scale: float = 0.4,
257
- f0_method: Literal[
258
- "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
259
- ] = "dio",
260
- # slice config
261
- db_thresh: int = -40,
262
- pad_seconds: float = 0.5,
263
- chunk_seconds: float = 0.5,
264
- absolute_thresh: bool = False,
265
- max_chunk_seconds: float = 40,
266
- # fade_seconds: float = 0.0,
267
- ) -> np.ndarray[Any, np.dtype[np.float32]]:
268
- sr = self.target_sample
269
- result_audio = np.array([], dtype=np.float32)
270
- chunk_length_min = (
271
- int(
272
- min(
273
- sr / so_vits_svc_fork.f0.f0_min * 20 + 1,
274
- chunk_seconds * sr,
275
- )
276
- )
277
- // 2
278
- )
279
- for chunk in split_silence(
280
- audio,
281
- top_db=-db_thresh,
282
- frame_length=chunk_length_min * 2,
283
- hop_length=chunk_length_min,
284
- ref=1 if absolute_thresh else np.max,
285
- max_chunk_length=int(max_chunk_seconds * sr),
286
- ):
287
- LOG.info(f"Chunk: {chunk}")
288
- if not chunk.is_speech:
289
- audio_chunk_infer = np.zeros_like(chunk.audio)
290
- else:
291
- # pad
292
- pad_len = int(sr * pad_seconds)
293
- audio_chunk_pad = np.concatenate(
294
- [
295
- np.zeros([pad_len], dtype=np.float32),
296
- chunk.audio,
297
- np.zeros([pad_len], dtype=np.float32),
298
- ]
299
- )
300
- audio_chunk_pad_infer_tensor, _ = self.infer(
301
- speaker,
302
- transpose,
303
- audio_chunk_pad,
304
- cluster_infer_ratio=cluster_infer_ratio,
305
- auto_predict_f0=auto_predict_f0,
306
- noise_scale=noise_scale,
307
- f0_method=f0_method,
308
- )
309
- audio_chunk_pad_infer = audio_chunk_pad_infer_tensor.cpu().numpy()
310
- pad_len = int(self.target_sample * pad_seconds)
311
- cut_len_2 = (len(audio_chunk_pad_infer) - len(chunk.audio)) // 2
312
- audio_chunk_infer = audio_chunk_pad_infer[
313
- cut_len_2 : cut_len_2 + len(chunk.audio)
314
- ]
315
-
316
- # add fade
317
- # fade_len = int(self.target_sample * fade_seconds)
318
- # _audio[:fade_len] = _audio[:fade_len] * np.linspace(0, 1, fade_len)
319
- # _audio[-fade_len:] = _audio[-fade_len:] * np.linspace(1, 0, fade_len)
320
-
321
- # empty cache
322
- torch.cuda.empty_cache()
323
- result_audio = np.concatenate([result_audio, audio_chunk_infer])
324
- result_audio = result_audio[: audio.shape[0]]
325
- return result_audio
326
-
327
-
328
- def sola_crossfade(
329
- first: ndarray[Any, dtype[float32]],
330
- second: ndarray[Any, dtype[float32]],
331
- crossfade_len: int,
332
- sola_search_len: int,
333
- ) -> ndarray[Any, dtype[float32]]:
334
- cor_nom = np.convolve(
335
- second[: sola_search_len + crossfade_len],
336
- np.flip(first[-crossfade_len:]),
337
- "valid",
338
- )
339
- cor_den = np.sqrt(
340
- np.convolve(
341
- second[: sola_search_len + crossfade_len] ** 2,
342
- np.ones(crossfade_len),
343
- "valid",
344
- )
345
- + 1e-8
346
- )
347
- sola_shift = np.argmax(cor_nom / cor_den)
348
- LOG.info(f"SOLA shift: {sola_shift}")
349
- second = second[sola_shift : sola_shift + len(second) - sola_search_len]
350
- return np.concatenate(
351
- [
352
- first[:-crossfade_len],
353
- first[-crossfade_len:] * np.linspace(1, 0, crossfade_len)
354
- + second[:crossfade_len] * np.linspace(0, 1, crossfade_len),
355
- second[crossfade_len:],
356
- ]
357
- )
358
-
359
-
360
- class Crossfader:
361
- def __init__(
362
- self,
363
- *,
364
- additional_infer_before_len: int,
365
- additional_infer_after_len: int,
366
- crossfade_len: int,
367
- sola_search_len: int = 384,
368
- ) -> None:
369
- if additional_infer_before_len < 0:
370
- raise ValueError("additional_infer_len must be >= 0")
371
- if crossfade_len < 0:
372
- raise ValueError("crossfade_len must be >= 0")
373
- if additional_infer_after_len < 0:
374
- raise ValueError("additional_infer_len must be >= 0")
375
- if additional_infer_before_len < 0:
376
- raise ValueError("additional_infer_len must be >= 0")
377
- self.additional_infer_before_len = additional_infer_before_len
378
- self.additional_infer_after_len = additional_infer_after_len
379
- self.crossfade_len = crossfade_len
380
- self.sola_search_len = sola_search_len
381
- self.last_input_left = np.zeros(
382
- sola_search_len
383
- + crossfade_len
384
- + additional_infer_before_len
385
- + additional_infer_after_len,
386
- dtype=np.float32,
387
- )
388
- self.last_infered_left = np.zeros(crossfade_len, dtype=np.float32)
389
-
390
- def process(
391
- self, input_audio: ndarray[Any, dtype[float32]], *args, **kwargs: Any
392
- ) -> ndarray[Any, dtype[float32]]:
393
- """
394
- chunks : ■■■■■■□□□□□□
395
- add last input:□■■■■■■
396
- ■□□□□□□
397
- infer :□■■■■■■
398
- ■□□□□□□
399
- crossfade :▲■■■■■
400
- ▲□□□□□
401
- """
402
- # check input
403
- if input_audio.ndim != 1:
404
- raise ValueError("Input audio must be 1-dimensional.")
405
- if (
406
- input_audio.shape[0] + self.additional_infer_before_len
407
- <= self.crossfade_len
408
- ):
409
- raise ValueError(
410
- f"Input audio length ({input_audio.shape[0]}) + additional_infer_len ({self.additional_infer_before_len}) must be greater than crossfade_len ({self.crossfade_len})."
411
- )
412
- input_audio = input_audio.astype(np.float32)
413
- input_audio_len = len(input_audio)
414
-
415
- # concat last input and infer
416
- input_audio_concat = np.concatenate([self.last_input_left, input_audio])
417
- del input_audio
418
- pad_len = 0
419
- if pad_len:
420
- infer_audio_concat = self.infer(
421
- np.pad(input_audio_concat, (pad_len, pad_len), mode="reflect"),
422
- *args,
423
- **kwargs,
424
- )[pad_len:-pad_len]
425
- else:
426
- infer_audio_concat = self.infer(input_audio_concat, *args, **kwargs)
427
-
428
- # debug SOLA (using copy synthesis with a random shift)
429
- """
430
- rs = int(np.random.uniform(-200,200))
431
- LOG.info(f"Debug random shift: {rs}")
432
- infer_audio_concat = np.roll(input_audio_concat, rs)
433
- """
434
-
435
- if len(infer_audio_concat) != len(input_audio_concat):
436
- raise ValueError(
437
- f"Inferred audio length ({len(infer_audio_concat)}) should be equal to input audio length ({len(input_audio_concat)})."
438
- )
439
- infer_audio_to_use = infer_audio_concat[
440
- -(
441
- self.sola_search_len
442
- + self.crossfade_len
443
- + input_audio_len
444
- + self.additional_infer_after_len
445
- ) : -self.additional_infer_after_len
446
- ]
447
- assert (
448
- len(infer_audio_to_use)
449
- == input_audio_len + self.sola_search_len + self.crossfade_len
450
- ), f"{len(infer_audio_to_use)} != {input_audio_len + self.sola_search_len + self.crossfade_len}"
451
- _audio = sola_crossfade(
452
- self.last_infered_left,
453
- infer_audio_to_use,
454
- self.crossfade_len,
455
- self.sola_search_len,
456
- )
457
- result_audio = _audio[: -self.crossfade_len]
458
- assert (
459
- len(result_audio) == input_audio_len
460
- ), f"{len(result_audio)} != {input_audio_len}"
461
-
462
- # update last input and inferred
463
- self.last_input_left = input_audio_concat[
464
- -(
465
- self.sola_search_len
466
- + self.crossfade_len
467
- + self.additional_infer_before_len
468
- + self.additional_infer_after_len
469
- ) :
470
- ]
471
- self.last_infered_left = _audio[-self.crossfade_len :]
472
- return result_audio
473
-
474
- def infer(
475
- self, input_audio: ndarray[Any, dtype[float32]]
476
- ) -> ndarray[Any, dtype[float32]]:
477
- return input_audio
478
-
479
-
480
- class RealtimeVC(Crossfader):
481
- def __init__(
482
- self,
483
- *,
484
- svc_model: Svc,
485
- crossfade_len: int = 3840,
486
- additional_infer_before_len: int = 7680,
487
- additional_infer_after_len: int = 7680,
488
- split: bool = True,
489
- ) -> None:
490
- self.svc_model = svc_model
491
- self.split = split
492
- super().__init__(
493
- crossfade_len=crossfade_len,
494
- additional_infer_before_len=additional_infer_before_len,
495
- additional_infer_after_len=additional_infer_after_len,
496
- )
497
-
498
- def process(
499
- self,
500
- input_audio: ndarray[Any, dtype[float32]],
501
- *args: Any,
502
- **kwargs: Any,
503
- ) -> ndarray[Any, dtype[float32]]:
504
- return super().process(input_audio, *args, **kwargs)
505
-
506
- def infer(
507
- self,
508
- input_audio: np.ndarray[Any, np.dtype[np.float32]],
509
- # svc config
510
- speaker: int | str,
511
- transpose: int,
512
- cluster_infer_ratio: float = 0,
513
- auto_predict_f0: bool = False,
514
- noise_scale: float = 0.4,
515
- f0_method: Literal[
516
- "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
517
- ] = "dio",
518
- # slice config
519
- db_thresh: int = -40,
520
- pad_seconds: float = 0.5,
521
- chunk_seconds: float = 0.5,
522
- ) -> ndarray[Any, dtype[float32]]:
523
- # infer
524
- if self.split:
525
- return self.svc_model.infer_silence(
526
- audio=input_audio,
527
- speaker=speaker,
528
- transpose=transpose,
529
- cluster_infer_ratio=cluster_infer_ratio,
530
- auto_predict_f0=auto_predict_f0,
531
- noise_scale=noise_scale,
532
- f0_method=f0_method,
533
- db_thresh=db_thresh,
534
- pad_seconds=pad_seconds,
535
- chunk_seconds=chunk_seconds,
536
- absolute_thresh=True,
537
- )
538
- else:
539
- rms = np.sqrt(np.mean(input_audio**2))
540
- min_rms = 10 ** (db_thresh / 20)
541
- if rms < min_rms:
542
- LOG.info(f"Skip silence: RMS={rms:.2f} < {min_rms:.2f}")
543
- return np.zeros_like(input_audio)
544
- else:
545
- LOG.info(f"Start inference: RMS={rms:.2f} >= {min_rms:.2f}")
546
- infered_audio_c, _ = self.svc_model.infer(
547
- speaker=speaker,
548
- transpose=transpose,
549
- audio=input_audio,
550
- cluster_infer_ratio=cluster_infer_ratio,
551
- auto_predict_f0=auto_predict_f0,
552
- noise_scale=noise_scale,
553
- f0_method=f0_method,
554
- )
555
- return infered_audio_c.cpu().numpy()
556
-
557
-
558
- class RealtimeVC2:
559
- chunk_store: list[Chunk]
560
-
561
- def __init__(self, svc_model: Svc) -> None:
562
- self.input_audio_store = np.array([], dtype=np.float32)
563
- self.chunk_store = []
564
- self.svc_model = svc_model
565
-
566
- def process(
567
- self,
568
- input_audio: np.ndarray[Any, np.dtype[np.float32]],
569
- # svc config
570
- speaker: int | str,
571
- transpose: int,
572
- cluster_infer_ratio: float = 0,
573
- auto_predict_f0: bool = False,
574
- noise_scale: float = 0.4,
575
- f0_method: Literal[
576
- "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
577
- ] = "dio",
578
- # slice config
579
- db_thresh: int = -40,
580
- chunk_seconds: float = 0.5,
581
- ) -> ndarray[Any, dtype[float32]]:
582
- def infer(audio: ndarray[Any, dtype[float32]]) -> ndarray[Any, dtype[float32]]:
583
- infered_audio_c, _ = self.svc_model.infer(
584
- speaker=speaker,
585
- transpose=transpose,
586
- audio=audio,
587
- cluster_infer_ratio=cluster_infer_ratio,
588
- auto_predict_f0=auto_predict_f0,
589
- noise_scale=noise_scale,
590
- f0_method=f0_method,
591
- )
592
- return infered_audio_c.cpu().numpy()
593
-
594
- self.input_audio_store = np.concatenate([self.input_audio_store, input_audio])
595
- LOG.info(f"input_audio_store: {self.input_audio_store.shape}")
596
- sr = self.svc_model.target_sample
597
- chunk_length_min = (
598
- int(min(sr / so_vits_svc_fork.f0.f0_min * 20 + 1, chunk_seconds * sr)) // 2
599
- )
600
- LOG.info(f"Chunk length min: {chunk_length_min}")
601
- chunk_list = list(
602
- split_silence(
603
- self.input_audio_store,
604
- -db_thresh,
605
- frame_length=chunk_length_min * 2,
606
- hop_length=chunk_length_min,
607
- ref=1, # use absolute threshold
608
- )
609
- )
610
- assert len(chunk_list) > 0
611
- LOG.info(f"Chunk list: {chunk_list}")
612
- # do not infer LAST incomplete is_speech chunk and save to store
613
- if chunk_list[-1].is_speech:
614
- self.input_audio_store = chunk_list.pop().audio
615
- else:
616
- self.input_audio_store = np.array([], dtype=np.float32)
617
-
618
- # infer complete is_speech chunk and save to store
619
- self.chunk_store.extend(
620
- [
621
- attrs.evolve(c, audio=infer(c.audio) if c.is_speech else c.audio)
622
- for c in chunk_list
623
- ]
624
- )
625
-
626
- # calculate lengths and determine compress rate
627
- total_speech_len = sum(
628
- [c.duration if c.is_speech else 0 for c in self.chunk_store]
629
- )
630
- total_silence_len = sum(
631
- [c.duration if not c.is_speech else 0 for c in self.chunk_store]
632
- )
633
- input_audio_len = input_audio.shape[0]
634
- silence_compress_rate = total_silence_len / max(
635
- 1, input_audio_len - total_speech_len
636
- )
637
- LOG.info(
638
- f"Total speech len: {total_speech_len}, silence len: {total_silence_len}, silence compress rate: {silence_compress_rate}"
639
- )
640
-
641
- # generate output audio
642
- output_audio = np.array([], dtype=np.float32)
643
- break_flag = False
644
- LOG.info(f"Chunk store: {self.chunk_store}")
645
- for chunk in deepcopy(self.chunk_store):
646
- compress_rate = 1 if chunk.is_speech else silence_compress_rate
647
- left_len = input_audio_len - output_audio.shape[0]
648
- # calculate chunk duration
649
- chunk_duration_output = int(min(chunk.duration / compress_rate, left_len))
650
- chunk_duration_input = int(min(chunk.duration, left_len * compress_rate))
651
- LOG.info(
652
- f"Chunk duration output: {chunk_duration_output}, input: {chunk_duration_input}, left len: {left_len}"
653
- )
654
-
655
- # remove chunk from store
656
- self.chunk_store.pop(0)
657
- if chunk.duration > chunk_duration_input:
658
- left_chunk = attrs.evolve(
659
- chunk, audio=chunk.audio[chunk_duration_input:]
660
- )
661
- chunk = attrs.evolve(chunk, audio=chunk.audio[:chunk_duration_input])
662
-
663
- self.chunk_store.insert(0, left_chunk)
664
- break_flag = True
665
-
666
- if chunk.is_speech:
667
- # if is_speech, just concat
668
- output_audio = np.concatenate([output_audio, chunk.audio])
669
- else:
670
- # if is_silence, concat with zeros and compress with silence_compress_rate
671
- output_audio = np.concatenate(
672
- [
673
- output_audio,
674
- np.zeros(
675
- chunk_duration_output,
676
- dtype=np.float32,
677
- ),
678
- ]
679
- )
680
-
681
- if break_flag:
682
- break
683
- LOG.info(f"Chunk store: {self.chunk_store}, output_audio: {output_audio.shape}")
684
- # make same length (errors)
685
- output_audio = output_audio[:input_audio_len]
686
- output_audio = np.concatenate(
687
- [
688
- output_audio,
689
- np.zeros(input_audio_len - output_audio.shape[0], dtype=np.float32),
690
- ]
691
- )
692
- return output_audio
 
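For context, a minimal usage sketch of the Svc class deleted above; the model path, config path, input file, and speaker name are placeholders, not values from this commit:

import librosa

svc = Svc(
    net_g_path="logs/44k/G_0.pth",          # placeholder checkpoint path
    config_path="configs/44k/config.json",  # placeholder config path
)
audio, _ = librosa.load("input.wav", sr=svc.target_sample)
out = svc.infer_silence(
    audio.astype("float32"),
    speaker="speaker0",  # must be a key of the config's spk mapping
    transpose=0,
    f0_method="dio",
)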
so_vits_svc_fork/inference/main.py DELETED
@@ -1,272 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from logging import getLogger
4
- from pathlib import Path
5
- from typing import Literal, Sequence
6
-
7
- import librosa
8
- import numpy as np
9
- import soundfile
10
- import torch
11
- from cm_time import timer
12
- from tqdm import tqdm
13
-
14
- from so_vits_svc_fork.inference.core import RealtimeVC, RealtimeVC2, Svc
15
- from so_vits_svc_fork.utils import get_optimal_device
16
-
17
- LOG = getLogger(__name__)
18
-
19
-
20
- def infer(
21
- *,
22
- # paths
23
- input_path: Path | str | Sequence[Path | str],
24
- output_path: Path | str | Sequence[Path | str],
25
- model_path: Path | str,
26
- config_path: Path | str,
27
- recursive: bool = False,
28
- # svc config
29
- speaker: int | str,
30
- cluster_model_path: Path | str | None = None,
31
- transpose: int = 0,
32
- auto_predict_f0: bool = False,
33
- cluster_infer_ratio: float = 0,
34
- noise_scale: float = 0.4,
35
- f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
36
- # slice config
37
- db_thresh: int = -40,
38
- pad_seconds: float = 0.5,
39
- chunk_seconds: float = 0.5,
40
- absolute_thresh: bool = False,
41
- max_chunk_seconds: float = 40,
42
- device: str | torch.device = get_optimal_device(),
43
- ):
44
- if isinstance(input_path, (str, Path)):
45
- input_path = [input_path]
46
- if isinstance(output_path, (str, Path)):
47
- output_path = [output_path]
48
- if len(input_path) != len(output_path):
49
- raise ValueError(
50
- f"input_path and output_path must have same length, but got {len(input_path)} and {len(output_path)}"
51
- )
52
-
53
- model_path = Path(model_path)
54
- config_path = Path(config_path)
55
- output_path = [Path(p) for p in output_path]
56
- input_path = [Path(p) for p in input_path]
57
- output_paths = []
58
- input_paths = []
59
-
60
- for input_path, output_path in zip(input_path, output_path):
61
- if input_path.is_dir():
62
- if not recursive:
63
- raise ValueError(
64
- f"input_path is a directory, but recursive is False: {input_path}"
65
- )
66
- input_paths.extend(list(input_path.rglob("*.*")))
67
- output_paths.extend(
68
- [output_path / p.relative_to(input_path) for p in input_path.rglob("*.*")]
69
- )
70
- continue
71
- input_paths.append(input_path)
72
- output_paths.append(output_path)
73
-
74
- cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
75
- svc_model = Svc(
76
- net_g_path=model_path.as_posix(),
77
- config_path=config_path.as_posix(),
78
- cluster_model_path=cluster_model_path.as_posix()
79
- if cluster_model_path
80
- else None,
81
- device=device,
82
- )
83
-
84
- try:
85
- pbar = tqdm(list(zip(input_paths, output_paths)), disable=len(input_paths) == 1)
86
- for input_path, output_path in pbar:
87
- pbar.set_description(f"{input_path}")
88
- try:
89
- audio, _ = librosa.load(str(input_path), sr=svc_model.target_sample)
90
- except Exception as e:
91
- LOG.error(f"Failed to load {input_path}")
92
- LOG.exception(e)
93
- continue
94
- output_path.parent.mkdir(parents=True, exist_ok=True)
95
- audio = svc_model.infer_silence(
96
- audio.astype(np.float32),
97
- speaker=speaker,
98
- transpose=transpose,
99
- auto_predict_f0=auto_predict_f0,
100
- cluster_infer_ratio=cluster_infer_ratio,
101
- noise_scale=noise_scale,
102
- f0_method=f0_method,
103
- db_thresh=db_thresh,
104
- pad_seconds=pad_seconds,
105
- chunk_seconds=chunk_seconds,
106
- absolute_thresh=absolute_thresh,
107
- max_chunk_seconds=max_chunk_seconds,
108
- )
109
- soundfile.write(str(output_path), audio, svc_model.target_sample)
110
- finally:
111
- del svc_model
112
- torch.cuda.empty_cache()
113
-
114
-
115
- def realtime(
116
- *,
117
- # paths
118
- model_path: Path | str,
119
- config_path: Path | str,
120
- # svc config
121
- speaker: str,
122
- cluster_model_path: Path | str | None = None,
123
- transpose: int = 0,
124
- auto_predict_f0: bool = False,
125
- cluster_infer_ratio: float = 0,
126
- noise_scale: float = 0.4,
127
- f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
128
- # slice config
129
- db_thresh: int = -40,
130
- pad_seconds: float = 0.5,
131
- chunk_seconds: float = 0.5,
132
- # realtime config
133
- crossfade_seconds: float = 0.05,
134
- additional_infer_before_seconds: float = 0.2,
135
- additional_infer_after_seconds: float = 0.1,
136
- block_seconds: float = 0.5,
137
- version: int = 2,
138
- input_device: int | str | None = None,
139
- output_device: int | str | None = None,
140
- device: str | torch.device = get_optimal_device(),
141
- passthrough_original: bool = False,
142
- ):
143
- import sounddevice as sd
144
-
145
- model_path = Path(model_path)
146
- config_path = Path(config_path)
147
- cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
148
- svc_model = Svc(
149
- net_g_path=model_path.as_posix(),
150
- config_path=config_path.as_posix(),
151
- cluster_model_path=cluster_model_path.as_posix()
152
- if cluster_model_path
153
- else None,
154
- device=device,
155
- )
156
-
157
- LOG.info("Creating realtime model...")
158
- if version == 1:
159
- model = RealtimeVC(
160
- svc_model=svc_model,
161
- crossfade_len=int(crossfade_seconds * svc_model.target_sample),
162
- additional_infer_before_len=int(
163
- additional_infer_before_seconds * svc_model.target_sample
164
- ),
165
- additional_infer_after_len=int(
166
- additional_infer_after_seconds * svc_model.target_sample
167
- ),
168
- )
169
- else:
170
- model = RealtimeVC2(
171
- svc_model=svc_model,
172
- )
173
-
174
- # LOG all device info
175
- devices = sd.query_devices()
176
- LOG.info(f"Device: {devices}")
177
- if isinstance(input_device, str):
178
- input_device_candidates = [
179
- i for i, d in enumerate(devices) if d["name"] == input_device
180
- ]
181
- if len(input_device_candidates) == 0:
182
- LOG.warning(f"Input device {input_device} not found, using default")
183
- input_device = None
184
- else:
185
- input_device = input_device_candidates[0]
186
- if isinstance(output_device, str):
187
- output_device_candidates = [
188
- i for i, d in enumerate(devices) if d["name"] == output_device
189
- ]
190
- if len(output_device_candidates) == 0:
191
- LOG.warning(f"Output device {output_device} not found, using default")
192
- output_device = None
193
- else:
194
- output_device = output_device_candidates[0]
195
- if input_device is None or input_device >= len(devices):
196
- input_device = sd.default.device[0]
197
- if output_device is None or output_device >= len(devices):
198
- output_device = sd.default.device[1]
199
- LOG.info(
200
- f"Input Device: {devices[input_device]['name']}, Output Device: {devices[output_device]['name']}"
201
- )
202
-
203
- # the model RTL is somewhat significantly high only in the first inference
204
- # there could be no better way to warm up the model than to do a dummy inference
205
- # (there are not differences in the behavior of the model between the first and the later inferences)
206
- # so we do a dummy inference to warm up the model (1 second of audio)
207
- LOG.info("Warming up the model...")
208
- svc_model.infer(
209
- speaker=speaker,
210
- transpose=transpose,
211
- auto_predict_f0=auto_predict_f0,
212
- cluster_infer_ratio=cluster_infer_ratio,
213
- noise_scale=noise_scale,
214
- f0_method=f0_method,
215
- audio=np.zeros(svc_model.target_sample, dtype=np.float32),
216
- )
217
-
218
- def callback(
219
- indata: np.ndarray,
220
- outdata: np.ndarray,
221
- frames: int,
222
- time: int,
223
- status: sd.CallbackFlags,
224
- ) -> None:
225
- LOG.debug(
226
- f"Frames: {frames}, Status: {status}, Shape: {indata.shape}, Time: {time}"
227
- )
228
-
229
- kwargs = dict(
230
- input_audio=indata.mean(axis=1).astype(np.float32),
231
- # svc config
232
- speaker=speaker,
233
- transpose=transpose,
234
- auto_predict_f0=auto_predict_f0,
235
- cluster_infer_ratio=cluster_infer_ratio,
236
- noise_scale=noise_scale,
237
- f0_method=f0_method,
238
- # slice config
239
- db_thresh=db_thresh,
240
- # pad_seconds=pad_seconds,
241
- chunk_seconds=chunk_seconds,
242
- )
243
- if version == 1:
244
- kwargs["pad_seconds"] = pad_seconds
245
- with timer() as t:
246
- inference = model.process(
247
- **kwargs,
248
- ).reshape(-1, 1)
249
- if passthrough_original:
250
- outdata[:] = (indata + inference) / 2
251
- else:
252
- outdata[:] = inference
253
- rtf = t.elapsed / block_seconds
254
- LOG.info(f"Realtime inference time: {t.elapsed:.3f}s, RTF: {rtf:.3f}")
255
- if rtf > 1:
256
- LOG.warning("RTF is too high, consider increasing block_seconds")
257
-
258
- try:
259
- with sd.Stream(
260
- device=(input_device, output_device),
261
- channels=1,
262
- callback=callback,
263
- samplerate=svc_model.target_sample,
264
- blocksize=int(block_seconds * svc_model.target_sample),
265
- latency="low",
266
- ) as stream:
267
- LOG.info(f"Latency: {stream.latency}")
268
- while True:
269
- sd.sleep(1000)
270
- finally:
271
- # del model, svc_model
272
- torch.cuda.empty_cache()
 
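Likewise, a hypothetical call to the file-based infer() helper deleted above; every path and the speaker name are placeholders:

infer(
    input_path="input.wav",
    output_path="input.out.wav",
    model_path="logs/44k/G_0.pth",
    config_path="configs/44k/config.json",
    speaker="speaker0",
    transpose=0,
    f0_method="dio",
    device="cpu",
)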
so_vits_svc_fork/logger.py DELETED
@@ -1,46 +0,0 @@
1
- import os
2
- import sys
3
- from logging import DEBUG, INFO, StreamHandler, basicConfig, captureWarnings, getLogger
4
- from pathlib import Path
5
-
6
- from rich.logging import RichHandler
7
-
8
- LOGGER_INIT = False
9
-
10
-
11
- def init_logger() -> None:
12
- global LOGGER_INIT
13
- if LOGGER_INIT:
14
- return
15
-
16
- IS_TEST = "test" in Path.cwd().stem
17
- package_name = sys.modules[__name__].__package__
18
- basicConfig(
19
- level=INFO,
20
- format="%(asctime)s %(message)s",
21
- datefmt="[%X]",
22
- handlers=[
23
- StreamHandler() if is_notebook() else RichHandler(),
24
- # FileHandler(f"{package_name}.log"),
25
- ],
26
- )
27
- if IS_TEST:
28
- getLogger(package_name).setLevel(DEBUG)
29
- captureWarnings(True)
30
- LOGGER_INIT = True
31
-
32
-
33
- def is_notebook():
34
- try:
35
- from IPython import get_ipython
36
-
37
- if "IPKernelApp" not in get_ipython().config: # pragma: no cover
38
- raise ImportError("console")
39
- return False
40
- if "VSCODE_PID" in os.environ: # pragma: no cover
41
- raise ImportError("vscode")
42
- return False
43
- except Exception:
44
- return False
45
- else: # pragma: no cover
46
- return True
so_vits_svc_fork/modules/__init__.py DELETED
File without changes
so_vits_svc_fork/modules/attentions.py DELETED
@@ -1,488 +0,0 @@
1
- import math
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from so_vits_svc_fork.modules import commons
8
- from so_vits_svc_fork.modules.modules import LayerNorm
9
-
10
-
11
- class FFT(nn.Module):
12
- def __init__(
13
- self,
14
- hidden_channels,
15
- filter_channels,
16
- n_heads,
17
- n_layers=1,
18
- kernel_size=1,
19
- p_dropout=0.0,
20
- proximal_bias=False,
21
- proximal_init=True,
22
- **kwargs
23
- ):
24
- super().__init__()
25
- self.hidden_channels = hidden_channels
26
- self.filter_channels = filter_channels
27
- self.n_heads = n_heads
28
- self.n_layers = n_layers
29
- self.kernel_size = kernel_size
30
- self.p_dropout = p_dropout
31
- self.proximal_bias = proximal_bias
32
- self.proximal_init = proximal_init
33
-
34
- self.drop = nn.Dropout(p_dropout)
35
- self.self_attn_layers = nn.ModuleList()
36
- self.norm_layers_0 = nn.ModuleList()
37
- self.ffn_layers = nn.ModuleList()
38
- self.norm_layers_1 = nn.ModuleList()
39
- for i in range(self.n_layers):
40
- self.self_attn_layers.append(
41
- MultiHeadAttention(
42
- hidden_channels,
43
- hidden_channels,
44
- n_heads,
45
- p_dropout=p_dropout,
46
- proximal_bias=proximal_bias,
47
- proximal_init=proximal_init,
48
- )
49
- )
50
- self.norm_layers_0.append(LayerNorm(hidden_channels))
51
- self.ffn_layers.append(
52
- FFN(
53
- hidden_channels,
54
- hidden_channels,
55
- filter_channels,
56
- kernel_size,
57
- p_dropout=p_dropout,
58
- causal=True,
59
- )
60
- )
61
- self.norm_layers_1.append(LayerNorm(hidden_channels))
62
-
63
- def forward(self, x, x_mask):
64
- """
65
- x: decoder input
66
- h: encoder output
67
- """
68
- self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
69
- device=x.device, dtype=x.dtype
70
- )
71
- x = x * x_mask
72
- for i in range(self.n_layers):
73
- y = self.self_attn_layers[i](x, x, self_attn_mask)
74
- y = self.drop(y)
75
- x = self.norm_layers_0[i](x + y)
76
-
77
- y = self.ffn_layers[i](x, x_mask)
78
- y = self.drop(y)
79
- x = self.norm_layers_1[i](x + y)
80
- x = x * x_mask
81
- return x
82
-
83
-
84
- class Encoder(nn.Module):
85
- def __init__(
86
- self,
87
- hidden_channels,
88
- filter_channels,
89
- n_heads,
90
- n_layers,
91
- kernel_size=1,
92
- p_dropout=0.0,
93
- window_size=4,
94
- **kwargs
95
- ):
96
- super().__init__()
97
- self.hidden_channels = hidden_channels
98
- self.filter_channels = filter_channels
99
- self.n_heads = n_heads
100
- self.n_layers = n_layers
101
- self.kernel_size = kernel_size
102
- self.p_dropout = p_dropout
103
- self.window_size = window_size
104
-
105
- self.drop = nn.Dropout(p_dropout)
106
- self.attn_layers = nn.ModuleList()
107
- self.norm_layers_1 = nn.ModuleList()
108
- self.ffn_layers = nn.ModuleList()
109
- self.norm_layers_2 = nn.ModuleList()
110
- for i in range(self.n_layers):
111
- self.attn_layers.append(
112
- MultiHeadAttention(
113
- hidden_channels,
114
- hidden_channels,
115
- n_heads,
116
- p_dropout=p_dropout,
117
- window_size=window_size,
118
- )
119
- )
120
- self.norm_layers_1.append(LayerNorm(hidden_channels))
121
- self.ffn_layers.append(
122
- FFN(
123
- hidden_channels,
124
- hidden_channels,
125
- filter_channels,
126
- kernel_size,
127
- p_dropout=p_dropout,
128
- )
129
- )
130
- self.norm_layers_2.append(LayerNorm(hidden_channels))
131
-
132
- def forward(self, x, x_mask):
133
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
134
- x = x * x_mask
135
- for i in range(self.n_layers):
136
- y = self.attn_layers[i](x, x, attn_mask)
137
- y = self.drop(y)
138
- x = self.norm_layers_1[i](x + y)
139
-
140
- y = self.ffn_layers[i](x, x_mask)
141
- y = self.drop(y)
142
- x = self.norm_layers_2[i](x + y)
143
- x = x * x_mask
144
- return x
145
-
146
-
147
- class Decoder(nn.Module):
148
- def __init__(
149
- self,
150
- hidden_channels,
151
- filter_channels,
152
- n_heads,
153
- n_layers,
154
- kernel_size=1,
155
- p_dropout=0.0,
156
- proximal_bias=False,
157
- proximal_init=True,
158
- **kwargs
159
- ):
160
- super().__init__()
161
- self.hidden_channels = hidden_channels
162
- self.filter_channels = filter_channels
163
- self.n_heads = n_heads
164
- self.n_layers = n_layers
165
- self.kernel_size = kernel_size
166
- self.p_dropout = p_dropout
167
- self.proximal_bias = proximal_bias
168
- self.proximal_init = proximal_init
169
-
170
- self.drop = nn.Dropout(p_dropout)
171
- self.self_attn_layers = nn.ModuleList()
172
- self.norm_layers_0 = nn.ModuleList()
173
- self.encdec_attn_layers = nn.ModuleList()
174
- self.norm_layers_1 = nn.ModuleList()
175
- self.ffn_layers = nn.ModuleList()
176
- self.norm_layers_2 = nn.ModuleList()
177
- for i in range(self.n_layers):
178
- self.self_attn_layers.append(
179
- MultiHeadAttention(
180
- hidden_channels,
181
- hidden_channels,
182
- n_heads,
183
- p_dropout=p_dropout,
184
- proximal_bias=proximal_bias,
185
- proximal_init=proximal_init,
186
- )
187
- )
188
- self.norm_layers_0.append(LayerNorm(hidden_channels))
189
- self.encdec_attn_layers.append(
190
- MultiHeadAttention(
191
- hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
192
- )
193
- )
194
- self.norm_layers_1.append(LayerNorm(hidden_channels))
195
- self.ffn_layers.append(
196
- FFN(
197
- hidden_channels,
198
- hidden_channels,
199
- filter_channels,
200
- kernel_size,
201
- p_dropout=p_dropout,
202
- causal=True,
203
- )
204
- )
205
- self.norm_layers_2.append(LayerNorm(hidden_channels))
206
-
207
- def forward(self, x, x_mask, h, h_mask):
208
- """
209
- x: decoder input
210
- h: encoder output
211
- """
212
- self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
213
- device=x.device, dtype=x.dtype
214
- )
215
- encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
216
- x = x * x_mask
217
- for i in range(self.n_layers):
218
- y = self.self_attn_layers[i](x, x, self_attn_mask)
219
- y = self.drop(y)
220
- x = self.norm_layers_0[i](x + y)
221
-
222
- y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
223
- y = self.drop(y)
224
- x = self.norm_layers_1[i](x + y)
225
-
226
- y = self.ffn_layers[i](x, x_mask)
227
- y = self.drop(y)
228
- x = self.norm_layers_2[i](x + y)
229
- x = x * x_mask
230
- return x
231
-
232
-
233
- class MultiHeadAttention(nn.Module):
234
- def __init__(
235
- self,
236
- channels,
237
- out_channels,
238
- n_heads,
239
- p_dropout=0.0,
240
- window_size=None,
241
- heads_share=True,
242
- block_length=None,
243
- proximal_bias=False,
244
- proximal_init=False,
245
- ):
246
- super().__init__()
247
- assert channels % n_heads == 0
248
-
249
- self.channels = channels
250
- self.out_channels = out_channels
251
- self.n_heads = n_heads
252
- self.p_dropout = p_dropout
253
- self.window_size = window_size
254
- self.heads_share = heads_share
255
- self.block_length = block_length
256
- self.proximal_bias = proximal_bias
257
- self.proximal_init = proximal_init
258
- self.attn = None
259
-
260
- self.k_channels = channels // n_heads
261
- self.conv_q = nn.Conv1d(channels, channels, 1)
262
- self.conv_k = nn.Conv1d(channels, channels, 1)
263
- self.conv_v = nn.Conv1d(channels, channels, 1)
264
- self.conv_o = nn.Conv1d(channels, out_channels, 1)
265
- self.drop = nn.Dropout(p_dropout)
266
-
267
- if window_size is not None:
268
- n_heads_rel = 1 if heads_share else n_heads
269
- rel_stddev = self.k_channels**-0.5
270
- self.emb_rel_k = nn.Parameter(
271
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
272
- * rel_stddev
273
- )
274
- self.emb_rel_v = nn.Parameter(
275
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
276
- * rel_stddev
277
- )
278
-
279
- nn.init.xavier_uniform_(self.conv_q.weight)
280
- nn.init.xavier_uniform_(self.conv_k.weight)
281
- nn.init.xavier_uniform_(self.conv_v.weight)
282
- if proximal_init:
283
- with torch.no_grad():
284
- self.conv_k.weight.copy_(self.conv_q.weight)
285
- self.conv_k.bias.copy_(self.conv_q.bias)
286
-
287
- def forward(self, x, c, attn_mask=None):
288
- q = self.conv_q(x)
289
- k = self.conv_k(c)
290
- v = self.conv_v(c)
291
-
292
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
293
-
294
- x = self.conv_o(x)
295
- return x
296
-
297
- def attention(self, query, key, value, mask=None):
298
- # reshape [b, d, t] -> [b, n_h, t, d_k]
299
- b, d, t_s, t_t = (*key.size(), query.size(2))
300
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
301
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
302
- value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
303
-
304
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
305
- if self.window_size is not None:
306
- assert (
307
- t_s == t_t
308
- ), "Relative attention is only available for self-attention."
309
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
310
- rel_logits = self._matmul_with_relative_keys(
311
- query / math.sqrt(self.k_channels), key_relative_embeddings
312
- )
313
- scores_local = self._relative_position_to_absolute_position(rel_logits)
314
- scores = scores + scores_local
315
- if self.proximal_bias:
316
- assert t_s == t_t, "Proximal bias is only available for self-attention."
317
- scores = scores + self._attention_bias_proximal(t_s).to(
318
- device=scores.device, dtype=scores.dtype
319
- )
320
- if mask is not None:
321
- scores = scores.masked_fill(mask == 0, -1e4)
322
- if self.block_length is not None:
323
- assert (
324
- t_s == t_t
325
- ), "Local attention is only available for self-attention."
326
- block_mask = (
327
- torch.ones_like(scores)
328
- .triu(-self.block_length)
329
- .tril(self.block_length)
330
- )
331
- scores = scores.masked_fill(block_mask == 0, -1e4)
332
- p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
333
- p_attn = self.drop(p_attn)
334
- output = torch.matmul(p_attn, value)
335
- if self.window_size is not None:
336
- relative_weights = self._absolute_position_to_relative_position(p_attn)
337
- value_relative_embeddings = self._get_relative_embeddings(
338
- self.emb_rel_v, t_s
339
- )
340
- output = output + self._matmul_with_relative_values(
341
- relative_weights, value_relative_embeddings
342
- )
343
- output = (
344
- output.transpose(2, 3).contiguous().view(b, d, t_t)
345
- ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
346
- return output, p_attn
347
-
348
- def _matmul_with_relative_values(self, x, y):
349
- """
350
- x: [b, h, l, m]
351
- y: [h or 1, m, d]
352
- ret: [b, h, l, d]
353
- """
354
- ret = torch.matmul(x, y.unsqueeze(0))
355
- return ret
356
-
357
- def _matmul_with_relative_keys(self, x, y):
358
- """
359
- x: [b, h, l, d]
360
- y: [h or 1, m, d]
361
- ret: [b, h, l, m]
362
- """
363
- ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
364
- return ret
365
-
366
- def _get_relative_embeddings(self, relative_embeddings, length):
367
- 2 * self.window_size + 1
368
- # Pad first before slice to avoid using cond ops.
369
- pad_length = max(length - (self.window_size + 1), 0)
370
- slice_start_position = max((self.window_size + 1) - length, 0)
371
- slice_end_position = slice_start_position + 2 * length - 1
372
- if pad_length > 0:
373
- padded_relative_embeddings = F.pad(
374
- relative_embeddings,
375
- commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
376
- )
377
- else:
378
- padded_relative_embeddings = relative_embeddings
379
- used_relative_embeddings = padded_relative_embeddings[
380
- :, slice_start_position:slice_end_position
381
- ]
382
- return used_relative_embeddings
383
-
384
- def _relative_position_to_absolute_position(self, x):
385
- """
386
- x: [b, h, l, 2*l-1]
387
- ret: [b, h, l, l]
388
- """
389
- batch, heads, length, _ = x.size()
390
- # Concat columns of pad to shift from relative to absolute indexing.
391
- x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
392
-
393
- # Concat extra elements so to add up to shape (len+1, 2*len-1).
394
- x_flat = x.view([batch, heads, length * 2 * length])
395
- x_flat = F.pad(
396
- x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
397
- )
398
-
399
- # Reshape and slice out the padded elements.
400
- x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
401
- :, :, :length, length - 1 :
402
- ]
403
- return x_final
404
-
405
- def _absolute_position_to_relative_position(self, x):
406
- """
407
- x: [b, h, l, l]
408
- ret: [b, h, l, 2*l-1]
409
- """
410
- batch, heads, length, _ = x.size()
411
- # pad along column
412
- x = F.pad(
413
- x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
414
- )
415
- x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
416
- # add 0's in the beginning that will skew the elements after reshape
417
- x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
418
- x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
419
- return x_final
420
-
421
- def _attention_bias_proximal(self, length):
422
- """Bias for self-attention to encourage attention to close positions.
423
- Args:
424
- length: an integer scalar.
425
- Returns:
426
- a Tensor with shape [1, 1, length, length]
427
- """
428
- r = torch.arange(length, dtype=torch.float32)
429
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
430
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
431
-
432
-
433
- class FFN(nn.Module):
434
- def __init__(
435
- self,
436
- in_channels,
437
- out_channels,
438
- filter_channels,
439
- kernel_size,
440
- p_dropout=0.0,
441
- activation=None,
442
- causal=False,
443
- ):
444
- super().__init__()
445
- self.in_channels = in_channels
446
- self.out_channels = out_channels
447
- self.filter_channels = filter_channels
448
- self.kernel_size = kernel_size
449
- self.p_dropout = p_dropout
450
- self.activation = activation
451
- self.causal = causal
452
-
453
- if causal:
454
- self.padding = self._causal_padding
455
- else:
456
- self.padding = self._same_padding
457
-
458
- self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
459
- self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
460
- self.drop = nn.Dropout(p_dropout)
461
-
462
- def forward(self, x, x_mask):
463
- x = self.conv_1(self.padding(x * x_mask))
464
- if self.activation == "gelu":
465
- x = x * torch.sigmoid(1.702 * x)
466
- else:
467
- x = torch.relu(x)
468
- x = self.drop(x)
469
- x = self.conv_2(self.padding(x * x_mask))
470
- return x * x_mask
471
-
472
- def _causal_padding(self, x):
473
- if self.kernel_size == 1:
474
- return x
475
- pad_l = self.kernel_size - 1
476
- pad_r = 0
477
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
478
- x = F.pad(x, commons.convert_pad_shape(padding))
479
- return x
480
-
481
- def _same_padding(self, x):
482
- if self.kernel_size == 1:
483
- return x
484
- pad_l = (self.kernel_size - 1) // 2
485
- pad_r = self.kernel_size // 2
486
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
487
- x = F.pad(x, commons.convert_pad_shape(padding))
488
- return x
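A shape-level sketch of the relative-position Encoder above; the hyperparameters are illustrative only, and the import assumes the installed package keeps this module path:

import torch

from so_vits_svc_fork.modules.attentions import Encoder

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6)
x = torch.randn(1, 192, 100)    # [batch, hidden_channels, frames]
x_mask = torch.ones(1, 1, 100)  # [batch, 1, frames], 1 = valid frame
y = enc(x, x_mask)
print(y.shape)                  # torch.Size([1, 192, 100])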
so_vits_svc_fork/modules/commons.py DELETED
@@ -1,132 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import Tensor
6
-
7
-
8
- def slice_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
9
- if length is None:
10
- return x
11
- length = min(length, x.size(-1))
12
- x_slice = torch.zeros((x.size()[:-1] + (length,)), dtype=x.dtype, device=x.device)
13
- ends = starts + length
14
- for i, (start, end) in enumerate(zip(starts, ends)):
15
- # LOG.debug(i, start, end, x.size(), x[i, ..., start:end].size(), x_slice.size())
16
- # x_slice[i, ...] = x[i, ..., start:end] need to pad
17
- # x_slice[i, ..., :end - start] = x[i, ..., start:end] this does not work
18
- x_slice[i, ...] = F.pad(x[i, ..., start:end], (0, max(0, length - x.size(-1))))
19
- return x_slice
20
-
21
-
22
- def rand_slice_segments_with_pitch(
23
- x: Tensor, f0: Tensor, x_lengths: Tensor | int | None, segment_size: int | None
24
- ):
25
- if segment_size is None:
26
- return x, f0, torch.arange(x.size(0), device=x.device)
27
- if x_lengths is None:
28
- x_lengths = x.size(-1) * torch.ones(
29
- x.size(0), dtype=torch.long, device=x.device
30
- )
31
- # slice_starts = (torch.rand(z.size(0), device=z.device) * (z_lengths - segment_size)).long()
32
- slice_starts = (
33
- torch.rand(x.size(0), device=x.device)
34
- * torch.max(
35
- x_lengths - segment_size, torch.zeros_like(x_lengths, device=x.device)
36
- )
37
- ).long()
38
- z_slice = slice_segments(x, slice_starts, segment_size)
39
- f0_slice = slice_segments(f0, slice_starts, segment_size)
40
- return z_slice, f0_slice, slice_starts
41
-
42
-
43
- def slice_2d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
44
- batch_size, num_features, seq_len = x.shape
45
- ends = starts + length
46
- idxs = (
47
- torch.arange(seq_len, device=x.device)
48
- .unsqueeze(0)
49
- .unsqueeze(1)
50
- .repeat(batch_size, num_features, 1)
51
- )
52
- mask = (idxs >= starts.unsqueeze(-1).unsqueeze(-1)) & (
53
- idxs < ends.unsqueeze(-1).unsqueeze(-1)
54
- )
55
- return x[mask].reshape(batch_size, num_features, length)
56
-
57
-
58
- def slice_1d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
59
- batch_size, seq_len = x.shape
60
- ends = starts + length
61
- idxs = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)
62
- mask = (idxs >= starts.unsqueeze(-1)) & (idxs < ends.unsqueeze(-1))
63
- return x[mask].reshape(batch_size, length)
64
-
65
-
66
- def _slice_segments_v3(x: Tensor, starts: Tensor, length: int) -> Tensor:
67
- shape = x.shape[:-1] + (length,)
68
- ends = starts + length
69
- idxs = torch.arange(x.shape[-1], device=x.device).unsqueeze(0).unsqueeze(0)
70
- unsqueeze_dims = len(shape) - len(
71
- x.shape
72
- ) # calculate number of dimensions to unsqueeze
73
- starts = starts.reshape(starts.shape + (1,) * unsqueeze_dims)
74
- ends = ends.reshape(ends.shape + (1,) * unsqueeze_dims)
75
- mask = (idxs >= starts) & (idxs < ends)
76
- return x[mask].reshape(shape)
77
-
78
-
79
- def init_weights(m, mean=0.0, std=0.01):
80
- classname = m.__class__.__name__
81
- if classname.find("Conv") != -1:
82
- m.weight.data.normal_(mean, std)
83
-
84
-
85
- def get_padding(kernel_size, dilation=1):
86
- return int((kernel_size * dilation - dilation) / 2)
87
-
88
-
89
- def convert_pad_shape(pad_shape):
90
- l = pad_shape[::-1]
91
- pad_shape = [item for sublist in l for item in sublist]
92
- return pad_shape
93
-
94
-
95
- def subsequent_mask(length):
96
- mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97
- return mask
98
-
99
-
100
- @torch.jit.script
101
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102
- n_channels_int = n_channels[0]
103
- in_act = input_a + input_b
104
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
105
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106
- acts = t_act * s_act
107
- return acts
108
-
109
-
110
- def sequence_mask(length, max_length=None):
111
- if max_length is None:
112
- max_length = length.max()
113
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
114
- return x.unsqueeze(0) < length.unsqueeze(1)
115
-
116
-
117
- def clip_grad_value_(parameters, clip_value, norm_type=2):
118
- if isinstance(parameters, torch.Tensor):
119
- parameters = [parameters]
120
- parameters = list(filter(lambda p: p.grad is not None, parameters))
121
- norm_type = float(norm_type)
122
- if clip_value is not None:
123
- clip_value = float(clip_value)
124
-
125
- total_norm = 0
126
- for p in parameters:
127
- param_norm = p.grad.data.norm(norm_type)
128
- total_norm += param_norm.item() ** norm_type
129
- if clip_value is not None:
130
- p.grad.data.clamp_(min=-clip_value, max=clip_value)
131
- total_norm = total_norm ** (1.0 / norm_type)
132
- return total_norm
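For example, the sequence_mask helper above turns per-item lengths into a boolean frame mask (import path assumed unchanged in the installed package):

import torch

from so_vits_svc_fork.modules.commons import sequence_mask

lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths, max_length=6)
print(mask)
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])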
so_vits_svc_fork/modules/decoders/__init__.py DELETED
File without changes
so_vits_svc_fork/modules/decoders/f0.py DELETED
@@ -1,46 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
- from so_vits_svc_fork.modules import attentions as attentions
5
-
6
-
7
- class F0Decoder(nn.Module):
8
- def __init__(
9
- self,
10
- out_channels,
11
- hidden_channels,
12
- filter_channels,
13
- n_heads,
14
- n_layers,
15
- kernel_size,
16
- p_dropout,
17
- spk_channels=0,
18
- ):
19
- super().__init__()
20
- self.out_channels = out_channels
21
- self.hidden_channels = hidden_channels
22
- self.filter_channels = filter_channels
23
- self.n_heads = n_heads
24
- self.n_layers = n_layers
25
- self.kernel_size = kernel_size
26
- self.p_dropout = p_dropout
27
- self.spk_channels = spk_channels
28
-
29
- self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
30
- self.decoder = attentions.FFT(
31
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
32
- )
33
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
34
- self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
35
- self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
36
-
37
- def forward(self, x, norm_f0, x_mask, spk_emb=None):
38
- x = torch.detach(x)
39
- if spk_emb is not None:
40
- spk_emb = torch.detach(spk_emb)
41
- x = x + self.cond(spk_emb)
42
- x += self.f0_prenet(norm_f0)
43
- x = self.prenet(x) * x_mask
44
- x = self.decoder(x * x_mask, x_mask)
45
- x = self.proj(x) * x_mask
46
- return x
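A shape-level sketch of F0Decoder; all hyperparameters below are illustrative and the import path is an assumption about the installed package:

import torch

from so_vits_svc_fork.modules.decoders.f0 import F0Decoder

decoder = F0Decoder(
    out_channels=1,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=4,
    kernel_size=3,
    p_dropout=0.1,
    spk_channels=256,
)
x = torch.randn(1, 192, 50)       # content features [batch, hidden, frames]
norm_f0 = torch.randn(1, 1, 50)   # normalized f0 curve
x_mask = torch.ones(1, 1, 50)
spk_emb = torch.randn(1, 256, 1)  # speaker embedding, broadcast over frames
pred_f0 = decoder(x, norm_f0, x_mask, spk_emb=spk_emb)
print(pred_f0.shape)              # torch.Size([1, 1, 50])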
so_vits_svc_fork/modules/decoders/hifigan/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from ._models import NSFHifiGANGenerator
2
-
3
- __all__ = ["NSFHifiGANGenerator"]
so_vits_svc_fork/modules/decoders/hifigan/_models.py DELETED
@@ -1,311 +0,0 @@
1
- from logging import getLogger
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from torch.nn import Conv1d, ConvTranspose1d
8
- from torch.nn.utils import remove_weight_norm, weight_norm
9
-
10
- from ...modules import ResBlock1, ResBlock2
11
- from ._utils import init_weights
12
-
13
- LOG = getLogger(__name__)
14
-
15
- LRELU_SLOPE = 0.1
16
-
17
-
18
- def padDiff(x):
19
- return F.pad(
20
- F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0
21
- )
22
-
23
-
24
- class SineGen(torch.nn.Module):
25
- """Definition of sine generator
26
- SineGen(samp_rate, harmonic_num = 0,
27
- sine_amp = 0.1, noise_std = 0.003,
28
- voiced_threshold = 0,
29
- flag_for_pulse=False)
30
- samp_rate: sampling rate in Hz
31
- harmonic_num: number of harmonic overtones (default 0)
32
- sine_amp: amplitude of sine-wavefrom (default 0.1)
33
- noise_std: std of Gaussian noise (default 0.003)
34
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
35
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
36
- Note: when flag_for_pulse is True, the first time step of a voiced
37
- segment is always sin(np.pi) or cos(0)
38
- """
39
-
40
- def __init__(
41
- self,
42
- samp_rate,
43
- harmonic_num=0,
44
- sine_amp=0.1,
45
- noise_std=0.003,
46
- voiced_threshold=0,
47
- flag_for_pulse=False,
48
- ):
49
- super().__init__()
50
- self.sine_amp = sine_amp
51
- self.noise_std = noise_std
52
- self.harmonic_num = harmonic_num
53
- self.dim = self.harmonic_num + 1
54
- self.sampling_rate = samp_rate
55
- self.voiced_threshold = voiced_threshold
56
- self.flag_for_pulse = flag_for_pulse
57
-
58
- def _f02uv(self, f0):
59
- # generate uv signal
60
- uv = (f0 > self.voiced_threshold).type(torch.float32)
61
- return uv
62
-
63
- def _f02sine(self, f0_values):
64
- """f0_values: (batchsize, length, dim)
65
- where dim indicates fundamental tone and overtones
66
- """
67
- # convert to F0 in rad. The integer part n can be ignored
68
- # because 2 * np.pi * n doesn't affect phase
69
- rad_values = (f0_values / self.sampling_rate) % 1
70
-
71
- # initial phase noise (no noise for fundamental component)
72
- rand_ini = torch.rand(
73
- f0_values.shape[0], f0_values.shape[2], device=f0_values.device
74
- )
75
- rand_ini[:, 0] = 0
76
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
77
-
78
- # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
79
- if not self.flag_for_pulse:
80
- # for normal case
81
-
82
- # To prevent torch.cumsum numerical overflow,
83
- # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
84
- # Buffer tmp_over_one_idx indicates the time step to add -1.
85
- # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
86
- tmp_over_one = torch.cumsum(rad_values, 1) % 1
87
- tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
88
- cumsum_shift = torch.zeros_like(rad_values)
89
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
90
-
91
- sines = torch.sin(
92
- torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
93
- )
94
- else:
95
- # If necessary, make sure that the first time step of every
96
- # voiced segments is sin(pi) or cos(0)
97
- # This is used for pulse-train generation
98
-
99
- # identify the last time step in unvoiced segments
100
- uv = self._f02uv(f0_values)
101
- uv_1 = torch.roll(uv, shifts=-1, dims=1)
102
- uv_1[:, -1, :] = 1
103
- u_loc = (uv < 1) * (uv_1 > 0)
104
-
105
- # get the instantanouse phase
106
- tmp_cumsum = torch.cumsum(rad_values, dim=1)
107
- # different batch needs to be processed differently
108
- for idx in range(f0_values.shape[0]):
109
- temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
110
- temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
111
- # stores the accumulation of i.phase within
112
- # each voiced segments
113
- tmp_cumsum[idx, :, :] = 0
114
- tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
115
-
116
- # rad_values - tmp_cumsum: remove the accumulation of i.phase
117
- # within the previous voiced segment.
118
- i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
119
-
120
- # get the sines
121
- sines = torch.cos(i_phase * 2 * np.pi)
122
- return sines
123
-
124
- def forward(self, f0):
125
- """sine_tensor, uv = forward(f0)
126
- input F0: tensor(batchsize=1, length, dim=1)
127
- f0 for unvoiced steps should be 0
128
- output sine_tensor: tensor(batchsize=1, length, dim)
129
- output uv: tensor(batchsize=1, length, 1)
130
- """
131
- with torch.no_grad():
132
- # f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
133
- # fundamental component
134
- # fn = torch.multiply(
135
- # f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
136
- # )
137
- fn = torch.multiply(
138
- f0, torch.arange(1, self.harmonic_num + 2).to(f0.device).to(f0.dtype)
139
- )
140
-
141
- # generate sine waveforms
142
- sine_waves = self._f02sine(fn) * self.sine_amp
143
-
144
- # generate uv signal
145
- # uv = torch.ones(f0.shape)
146
- # uv = uv * (f0 > self.voiced_threshold)
147
- uv = self._f02uv(f0)
148
-
149
- # noise: for unvoiced should be similar to sine_amp
150
- # std = self.sine_amp/3 -> max value ~ self.sine_amp
151
- # . for voiced regions is self.noise_std
152
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
153
- noise = noise_amp * torch.randn_like(sine_waves)
154
-
155
- # first: set the unvoiced part to 0 by uv
156
- # then: additive noise
157
- sine_waves = sine_waves * uv + noise
158
- return sine_waves, uv, noise
159
-
160
-
161
- class SourceModuleHnNSF(torch.nn.Module):
162
- """SourceModule for hn-nsf
163
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
164
- add_noise_std=0.003, voiced_threshod=0)
165
- sampling_rate: sampling_rate in Hz
166
- harmonic_num: number of harmonic above F0 (default: 0)
167
- sine_amp: amplitude of sine source signal (default: 0.1)
168
- add_noise_std: std of additive Gaussian noise (default: 0.003)
169
- note that amplitude of noise in unvoiced is decided
170
- by sine_amp
171
- voiced_threshold: threshold to set U/V given F0 (default: 0)
172
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
173
- F0_sampled (batchsize, length, 1)
174
- Sine_source (batchsize, length, 1)
175
- noise_source (batchsize, length 1)
176
- uv (batchsize, length, 1)
177
- """
178
-
179
- def __init__(
180
- self,
181
- sampling_rate,
182
- harmonic_num=0,
183
- sine_amp=0.1,
184
- add_noise_std=0.003,
185
- voiced_threshod=0,
186
- ):
187
- super().__init__()
188
-
189
- self.sine_amp = sine_amp
190
- self.noise_std = add_noise_std
191
-
192
- # to produce sine waveforms
193
- self.l_sin_gen = SineGen(
194
- sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
195
- )
196
-
197
- # to merge source harmonics into a single excitation
198
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
199
- self.l_tanh = torch.nn.Tanh()
200
-
201
- def forward(self, x):
202
- """
203
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
204
- F0_sampled (batchsize, length, 1)
205
- Sine_source (batchsize, length, 1)
206
- noise_source (batchsize, length 1)
207
- """
208
- # source for harmonic branch
209
- sine_wavs, uv, _ = self.l_sin_gen(x)
210
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
211
-
212
- # source for noise branch, in the same shape as uv
213
- noise = torch.randn_like(uv) * self.sine_amp / 3
214
- return sine_merge, noise, uv
215
-
216
-
217
- class NSFHifiGANGenerator(torch.nn.Module):
218
- def __init__(self, h):
219
- super().__init__()
220
- self.h = h
221
-
222
- self.num_kernels = len(h["resblock_kernel_sizes"])
223
- self.num_upsamples = len(h["upsample_rates"])
224
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"]))
225
- self.m_source = SourceModuleHnNSF(
226
- sampling_rate=h["sampling_rate"], harmonic_num=8
227
- )
228
- self.noise_convs = nn.ModuleList()
229
- self.conv_pre = weight_norm(
230
- Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3)
231
- )
232
- resblock = ResBlock1 if h["resblock"] == "1" else ResBlock2
233
- self.ups = nn.ModuleList()
234
- for i, (u, k) in enumerate(
235
- zip(h["upsample_rates"], h["upsample_kernel_sizes"])
236
- ):
237
- c_cur = h["upsample_initial_channel"] // (2 ** (i + 1))
238
- self.ups.append(
239
- weight_norm(
240
- ConvTranspose1d(
241
- h["upsample_initial_channel"] // (2**i),
242
- h["upsample_initial_channel"] // (2 ** (i + 1)),
243
- k,
244
- u,
245
- padding=(k - u) // 2,
246
- )
247
- )
248
- )
249
- if i + 1 < len(h["upsample_rates"]): #
250
- stride_f0 = np.prod(h["upsample_rates"][i + 1 :])
251
- self.noise_convs.append(
252
- Conv1d(
253
- 1,
254
- c_cur,
255
- kernel_size=stride_f0 * 2,
256
- stride=stride_f0,
257
- padding=stride_f0 // 2,
258
- )
259
- )
260
- else:
261
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
262
- self.resblocks = nn.ModuleList()
263
- for i in range(len(self.ups)):
264
- ch = h["upsample_initial_channel"] // (2 ** (i + 1))
265
- for j, (k, d) in enumerate(
266
- zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])
267
- ):
268
- self.resblocks.append(resblock(ch, k, d))
269
-
270
- self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
271
- self.ups.apply(init_weights)
272
- self.conv_post.apply(init_weights)
273
- self.cond = nn.Conv1d(h["gin_channels"], h["upsample_initial_channel"], 1)
274
-
275
- def forward(self, x, f0, g=None):
276
- # LOG.info(1,x.shape,f0.shape,f0[:, None].shape)
277
- f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
278
- # LOG.info(2,f0.shape)
279
- har_source, noi_source, uv = self.m_source(f0)
280
- har_source = har_source.transpose(1, 2)
281
- x = self.conv_pre(x)
282
- x = x + self.cond(g)
283
- # LOG.info(124,x.shape,har_source.shape)
284
- for i in range(self.num_upsamples):
285
- x = F.leaky_relu(x, LRELU_SLOPE)
286
- # LOG.info(3,x.shape)
287
- x = self.ups[i](x)
288
- x_source = self.noise_convs[i](har_source)
289
- # LOG.info(4,x_source.shape,har_source.shape,x.shape)
290
- x = x + x_source
291
- xs = None
292
- for j in range(self.num_kernels):
293
- if xs is None:
294
- xs = self.resblocks[i * self.num_kernels + j](x)
295
- else:
296
- xs += self.resblocks[i * self.num_kernels + j](x)
297
- x = xs / self.num_kernels
298
- x = F.leaky_relu(x)
299
- x = self.conv_post(x)
300
- x = torch.tanh(x)
301
-
302
- return x
303
-
304
- def remove_weight_norm(self):
305
- LOG.info("Removing weight norm...")
306
- for l in self.ups:
307
- remove_weight_norm(l)
308
- for l in self.resblocks:
309
- l.remove_weight_norm()
310
- remove_weight_norm(self.conv_pre)
311
- remove_weight_norm(self.conv_post)
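A quick sketch of the harmonic-plus-noise source above, turning an f0 contour (already upsampled to the audio sample rate) into sine excitation; the private import path is an assumption about the installed package:

import torch

from so_vits_svc_fork.modules.decoders.hifigan._models import SourceModuleHnNSF

source = SourceModuleHnNSF(sampling_rate=44100, harmonic_num=8)
f0 = torch.full((1, 44100, 1), 220.0)  # one second of a 220 Hz (A3) contour
sine_merge, noise, uv = source(f0)
print(sine_merge.shape, uv.shape)      # torch.Size([1, 44100, 1]) for both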
so_vits_svc_fork/modules/decoders/hifigan/_utils.py DELETED
@@ -1,15 +0,0 @@
1
- from logging import getLogger
2
-
3
- # matplotlib.use("Agg")
4
-
5
- LOG = getLogger(__name__)
6
-
7
-
8
- def init_weights(m, mean=0.0, std=0.01):
9
- classname = m.__class__.__name__
10
- if classname.find("Conv") != -1:
11
- m.weight.data.normal_(mean, std)
12
-
13
-
14
- def get_padding(kernel_size, dilation=1):
15
- return int((kernel_size * dilation - dilation) / 2)
so_vits_svc_fork/modules/decoders/mb_istft/__init__.py DELETED
@@ -1,15 +0,0 @@
1
- from ._generators import (
2
- Multiband_iSTFT_Generator,
3
- Multistream_iSTFT_Generator,
4
- iSTFT_Generator,
5
- )
6
- from ._loss import subband_stft_loss
7
- from ._pqmf import PQMF
8
-
9
- __all__ = [
10
- "subband_stft_loss",
11
- "PQMF",
12
- "iSTFT_Generator",
13
- "Multiband_iSTFT_Generator",
14
- "Multistream_iSTFT_Generator",
15
- ]
so_vits_svc_fork/modules/decoders/mb_istft/_generators.py DELETED
@@ -1,376 +0,0 @@
1
- import math
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import Conv1d, ConvTranspose1d
6
- from torch.nn import functional as F
7
- from torch.nn.utils import remove_weight_norm, weight_norm
8
-
9
- from ....modules import modules
10
- from ....modules.commons import get_padding, init_weights
11
- from ._pqmf import PQMF
12
- from ._stft import TorchSTFT
13
-
14
-
15
- class iSTFT_Generator(torch.nn.Module):
16
- def __init__(
17
- self,
18
- initial_channel,
19
- resblock,
20
- resblock_kernel_sizes,
21
- resblock_dilation_sizes,
22
- upsample_rates,
23
- upsample_initial_channel,
24
- upsample_kernel_sizes,
25
- gen_istft_n_fft,
26
- gen_istft_hop_size,
27
- gin_channels=0,
28
- ):
29
- super().__init__()
30
- # self.h = h
31
- self.gen_istft_n_fft = gen_istft_n_fft
32
- self.gen_istft_hop_size = gen_istft_hop_size
33
-
34
- self.num_kernels = len(resblock_kernel_sizes)
35
- self.num_upsamples = len(upsample_rates)
36
- self.conv_pre = weight_norm(
37
- Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
38
- )
39
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
40
-
41
- self.ups = nn.ModuleList()
42
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
43
- self.ups.append(
44
- weight_norm(
45
- ConvTranspose1d(
46
- upsample_initial_channel // (2**i),
47
- upsample_initial_channel // (2 ** (i + 1)),
48
- k,
49
- u,
50
- padding=(k - u) // 2,
51
- )
52
- )
53
- )
54
-
55
- self.resblocks = nn.ModuleList()
56
- for i in range(len(self.ups)):
57
- ch = upsample_initial_channel // (2 ** (i + 1))
58
- for j, (k, d) in enumerate(
59
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
60
- ):
61
- self.resblocks.append(resblock(ch, k, d))
62
-
63
- self.post_n_fft = self.gen_istft_n_fft
64
- self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
65
- self.ups.apply(init_weights)
66
- self.conv_post.apply(init_weights)
67
- self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
68
- self.stft = TorchSTFT(
69
- filter_length=self.gen_istft_n_fft,
70
- hop_length=self.gen_istft_hop_size,
71
- win_length=self.gen_istft_n_fft,
72
- )
73
-
74
- def forward(self, x, g=None):
75
- x = self.conv_pre(x)
76
- for i in range(self.num_upsamples):
77
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
78
- x = self.ups[i](x)
79
- xs = None
80
- for j in range(self.num_kernels):
81
- if xs is None:
82
- xs = self.resblocks[i * self.num_kernels + j](x)
83
- else:
84
- xs += self.resblocks[i * self.num_kernels + j](x)
85
- x = xs / self.num_kernels
86
- x = F.leaky_relu(x)
87
- x = self.reflection_pad(x)
88
- x = self.conv_post(x)
89
- spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :])
90
- phase = math.pi * torch.sin(x[:, self.post_n_fft // 2 + 1 :, :])
91
- out = self.stft.inverse(spec, phase).to(x.device)
92
- return out, None
93
-
94
- def remove_weight_norm(self):
95
- print("Removing weight norm...")
96
- for l in self.ups:
97
- remove_weight_norm(l)
98
- for l in self.resblocks:
99
- l.remove_weight_norm()
100
- remove_weight_norm(self.conv_pre)
101
- remove_weight_norm(self.conv_post)
102
-
103
-
104
- class Multiband_iSTFT_Generator(torch.nn.Module):
105
- def __init__(
106
- self,
107
- initial_channel,
108
- resblock,
109
- resblock_kernel_sizes,
110
- resblock_dilation_sizes,
111
- upsample_rates,
112
- upsample_initial_channel,
113
- upsample_kernel_sizes,
114
- gen_istft_n_fft,
115
- gen_istft_hop_size,
116
- subbands,
117
- gin_channels=0,
118
- ):
119
- super().__init__()
120
- # self.h = h
121
- self.subbands = subbands
122
- self.num_kernels = len(resblock_kernel_sizes)
123
- self.num_upsamples = len(upsample_rates)
124
- self.conv_pre = weight_norm(
125
- Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
126
- )
127
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
128
-
129
- self.ups = nn.ModuleList()
130
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
131
- self.ups.append(
132
- weight_norm(
133
- ConvTranspose1d(
134
- upsample_initial_channel // (2**i),
135
- upsample_initial_channel // (2 ** (i + 1)),
136
- k,
137
- u,
138
- padding=(k - u) // 2,
139
- )
140
- )
141
- )
142
-
143
- self.resblocks = nn.ModuleList()
144
- for i in range(len(self.ups)):
145
- ch = upsample_initial_channel // (2 ** (i + 1))
146
- for j, (k, d) in enumerate(
147
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
148
- ):
149
- self.resblocks.append(resblock(ch, k, d))
150
-
151
- self.post_n_fft = gen_istft_n_fft
152
- self.ups.apply(init_weights)
153
- self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
154
- self.reshape_pixelshuffle = []
155
-
156
- self.subband_conv_post = weight_norm(
157
- Conv1d(ch, self.subbands * (self.post_n_fft + 2), 7, 1, padding=3)
158
- )
159
-
160
- self.subband_conv_post.apply(init_weights)
161
-
162
- self.gen_istft_n_fft = gen_istft_n_fft
163
- self.gen_istft_hop_size = gen_istft_hop_size
164
-
165
- def forward(self, x, g=None):
166
- stft = TorchSTFT(
167
- filter_length=self.gen_istft_n_fft,
168
- hop_length=self.gen_istft_hop_size,
169
- win_length=self.gen_istft_n_fft,
170
- ).to(x.device)
171
- pqmf = PQMF(x.device, subbands=self.subbands).to(x.device, dtype=x.dtype)
172
-
173
- x = self.conv_pre(x) # [B, ch, length]
174
-
175
- for i in range(self.num_upsamples):
176
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
177
- x = self.ups[i](x)
178
-
179
- xs = None
180
- for j in range(self.num_kernels):
181
- if xs is None:
182
- xs = self.resblocks[i * self.num_kernels + j](x)
183
- else:
184
- xs += self.resblocks[i * self.num_kernels + j](x)
185
- x = xs / self.num_kernels
186
-
187
- x = F.leaky_relu(x)
188
- x = self.reflection_pad(x)
189
- x = self.subband_conv_post(x)
190
- x = torch.reshape(
191
- x, (x.shape[0], self.subbands, x.shape[1] // self.subbands, x.shape[-1])
192
- )
193
-
194
- spec = torch.exp(x[:, :, : self.post_n_fft // 2 + 1, :])
195
- phase = math.pi * torch.sin(x[:, :, self.post_n_fft // 2 + 1 :, :])
196
-
197
- y_mb_hat = stft.inverse(
198
- torch.reshape(
199
- spec,
200
- (
201
- spec.shape[0] * self.subbands,
202
- self.gen_istft_n_fft // 2 + 1,
203
- spec.shape[-1],
204
- ),
205
- ),
206
- torch.reshape(
207
- phase,
208
- (
209
- phase.shape[0] * self.subbands,
210
- self.gen_istft_n_fft // 2 + 1,
211
- phase.shape[-1],
212
- ),
213
- ),
214
- )
215
- y_mb_hat = torch.reshape(
216
- y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1])
217
- )
218
- y_mb_hat = y_mb_hat.squeeze(-2)
219
-
220
- y_g_hat = pqmf.synthesis(y_mb_hat)
221
-
222
- return y_g_hat, y_mb_hat
223
-
224
- def remove_weight_norm(self):
225
- print("Removing weight norm...")
226
- for l in self.ups:
227
- remove_weight_norm(l)
228
- for l in self.resblocks:
229
- l.remove_weight_norm()
230
-
231
-
232
- class Multistream_iSTFT_Generator(torch.nn.Module):
233
- def __init__(
234
- self,
235
- initial_channel,
236
- resblock,
237
- resblock_kernel_sizes,
238
- resblock_dilation_sizes,
239
- upsample_rates,
240
- upsample_initial_channel,
241
- upsample_kernel_sizes,
242
- gen_istft_n_fft,
243
- gen_istft_hop_size,
244
- subbands,
245
- gin_channels=0,
246
- ):
247
- super().__init__()
248
- # self.h = h
249
- self.subbands = subbands
250
- self.num_kernels = len(resblock_kernel_sizes)
251
- self.num_upsamples = len(upsample_rates)
252
- self.conv_pre = weight_norm(
253
- Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
254
- )
255
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
256
-
257
- self.ups = nn.ModuleList()
258
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
259
- self.ups.append(
260
- weight_norm(
261
- ConvTranspose1d(
262
- upsample_initial_channel // (2**i),
263
- upsample_initial_channel // (2 ** (i + 1)),
264
- k,
265
- u,
266
- padding=(k - u) // 2,
267
- )
268
- )
269
- )
270
-
271
- self.resblocks = nn.ModuleList()
272
- for i in range(len(self.ups)):
273
- ch = upsample_initial_channel // (2 ** (i + 1))
274
- for j, (k, d) in enumerate(
275
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
276
- ):
277
- self.resblocks.append(resblock(ch, k, d))
278
-
279
- self.post_n_fft = gen_istft_n_fft
280
- self.ups.apply(init_weights)
281
- self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
282
- self.reshape_pixelshuffle = []
283
-
284
- self.subband_conv_post = weight_norm(
285
- Conv1d(ch, self.subbands * (self.post_n_fft + 2), 7, 1, padding=3)
286
- )
287
-
288
- self.subband_conv_post.apply(init_weights)
289
-
290
- self.gen_istft_n_fft = gen_istft_n_fft
291
- self.gen_istft_hop_size = gen_istft_hop_size
292
-
293
- updown_filter = torch.zeros(
294
- (self.subbands, self.subbands, self.subbands)
295
- ).float()
296
- for k in range(self.subbands):
297
- updown_filter[k, k, 0] = 1.0
298
- self.register_buffer("updown_filter", updown_filter)
299
- self.multistream_conv_post = weight_norm(
300
- Conv1d(
301
- self.subbands, 1, kernel_size=63, bias=False, padding=get_padding(63, 1)
302
- )
303
- )
304
- self.multistream_conv_post.apply(init_weights)
305
-
306
- def forward(self, x, g=None):
307
- stft = TorchSTFT(
308
- filter_length=self.gen_istft_n_fft,
309
- hop_length=self.gen_istft_hop_size,
310
- win_length=self.gen_istft_n_fft,
311
- ).to(x.device)
312
- # pqmf = PQMF(x.device)
313
-
314
- x = self.conv_pre(x) # [B, ch, length]
315
-
316
- for i in range(self.num_upsamples):
317
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
318
- x = self.ups[i](x)
319
-
320
- xs = None
321
- for j in range(self.num_kernels):
322
- if xs is None:
323
- xs = self.resblocks[i * self.num_kernels + j](x)
324
- else:
325
- xs += self.resblocks[i * self.num_kernels + j](x)
326
- x = xs / self.num_kernels
327
-
328
- x = F.leaky_relu(x)
329
- x = self.reflection_pad(x)
330
- x = self.subband_conv_post(x)
331
- x = torch.reshape(
332
- x, (x.shape[0], self.subbands, x.shape[1] // self.subbands, x.shape[-1])
333
- )
334
-
335
- spec = torch.exp(x[:, :, : self.post_n_fft // 2 + 1, :])
336
- phase = math.pi * torch.sin(x[:, :, self.post_n_fft // 2 + 1 :, :])
337
-
338
- y_mb_hat = stft.inverse(
339
- torch.reshape(
340
- spec,
341
- (
342
- spec.shape[0] * self.subbands,
343
- self.gen_istft_n_fft // 2 + 1,
344
- spec.shape[-1],
345
- ),
346
- ),
347
- torch.reshape(
348
- phase,
349
- (
350
- phase.shape[0] * self.subbands,
351
- self.gen_istft_n_fft // 2 + 1,
352
- phase.shape[-1],
353
- ),
354
- ),
355
- )
356
- y_mb_hat = torch.reshape(
357
- y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1])
358
- )
359
- y_mb_hat = y_mb_hat.squeeze(-2)
360
-
361
- y_mb_hat = F.conv_transpose1d(
362
- y_mb_hat,
363
- self.updown_filter.to(x.device) * self.subbands,
364
- stride=self.subbands,
365
- )
366
-
367
- y_g_hat = self.multistream_conv_post(y_mb_hat)
368
-
369
- return y_g_hat, y_mb_hat
370
-
371
- def remove_weight_norm(self):
372
- print("Removing weight norm...")
373
- for l in self.ups:
374
- remove_weight_norm(l)
375
- for l in self.resblocks:
376
- l.remove_weight_norm()
so_vits_svc_fork/modules/decoders/mb_istft/_loss.py DELETED
@@ -1,11 +0,0 @@
1
- from ._stft_loss import MultiResolutionSTFTLoss
2
-
3
-
4
- def subband_stft_loss(h, y_mb, y_hat_mb):
5
- sub_stft_loss = MultiResolutionSTFTLoss(
6
- h.train.fft_sizes, h.train.hop_sizes, h.train.win_lengths
7
- )
8
- y_mb = y_mb.view(-1, y_mb.size(2))
9
- y_hat_mb = y_hat_mb.view(-1, y_hat_mb.size(2))
10
- sub_sc_loss, sub_mag_loss = sub_stft_loss(y_hat_mb[:, : y_mb.size(-1)], y_mb)
11
- return sub_sc_loss + sub_mag_loss
so_vits_svc_fork/modules/decoders/mb_istft/_pqmf.py DELETED
@@ -1,128 +0,0 @@
1
- # Copyright 2020 Tomoki Hayashi
2
- # MIT License (https://opensource.org/licenses/MIT)
3
-
4
- """Pseudo QMF modules."""
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- from scipy.signal import kaiser
10
-
11
-
12
- def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
13
- """Design prototype filter for PQMF.
14
- This method is based on `A Kaiser window approach for the design of prototype
15
- filters of cosine modulated filterbanks`_.
16
- Args:
17
- taps (int): The number of filter taps.
18
- cutoff_ratio (float): Cut-off frequency ratio.
19
- beta (float): Beta coefficient for kaiser window.
20
- Returns:
21
- ndarray: Impluse response of prototype filter (taps + 1,).
22
- .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
23
- https://ieeexplore.ieee.org/abstract/document/681427
24
- """
25
- # check the arguments are valid
26
- assert taps % 2 == 0, "The number of taps mush be even number."
27
- assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
28
-
29
- # make initial filter
30
- omega_c = np.pi * cutoff_ratio
31
- with np.errstate(invalid="ignore"):
32
- h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
33
- np.pi * (np.arange(taps + 1) - 0.5 * taps)
34
- )
35
- h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
36
-
37
- # apply kaiser window
38
- w = kaiser(taps + 1, beta)
39
- h = h_i * w
40
-
41
- return h
42
-
43
-
44
- class PQMF(torch.nn.Module):
45
- """PQMF module.
46
- This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
47
- .. _`Near-perfect-reconstruction pseudo-QMF banks`:
48
- https://ieeexplore.ieee.org/document/258122
49
- """
50
-
51
- def __init__(self, device, subbands=8, taps=62, cutoff_ratio=0.15, beta=9.0):
52
- """Initialize PQMF module.
53
- Args:
54
- subbands (int): The number of subbands.
55
- taps (int): The number of filter taps.
56
- cutoff_ratio (float): Cut-off frequency ratio.
57
- beta (float): Beta coefficient for kaiser window.
58
- """
59
- super().__init__()
60
-
61
- # define filter coefficient
62
- h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
63
- h_analysis = np.zeros((subbands, len(h_proto)))
64
- h_synthesis = np.zeros((subbands, len(h_proto)))
65
- for k in range(subbands):
66
- h_analysis[k] = (
67
- 2
68
- * h_proto
69
- * np.cos(
70
- (2 * k + 1)
71
- * (np.pi / (2 * subbands))
72
- * (np.arange(taps + 1) - ((taps - 1) / 2))
73
- + (-1) ** k * np.pi / 4
74
- )
75
- )
76
- h_synthesis[k] = (
77
- 2
78
- * h_proto
79
- * np.cos(
80
- (2 * k + 1)
81
- * (np.pi / (2 * subbands))
82
- * (np.arange(taps + 1) - ((taps - 1) / 2))
83
- - (-1) ** k * np.pi / 4
84
- )
85
- )
86
-
87
- # convert to tensor
88
- analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1).to(device)
89
- synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0).to(device)
90
-
91
- # register coefficients as buffer
92
- self.register_buffer("analysis_filter", analysis_filter)
93
- self.register_buffer("synthesis_filter", synthesis_filter)
94
-
95
- # filter for downsampling & upsampling
96
- updown_filter = torch.zeros((subbands, subbands, subbands)).float().to(device)
97
- for k in range(subbands):
98
- updown_filter[k, k, 0] = 1.0
99
- self.register_buffer("updown_filter", updown_filter)
100
- self.subbands = subbands
101
-
102
- # keep padding info
103
- self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
104
-
105
- def analysis(self, x):
106
- """Analysis with PQMF.
107
- Args:
108
- x (Tensor): Input tensor (B, 1, T).
109
- Returns:
110
- Tensor: Output tensor (B, subbands, T // subbands).
111
- """
112
- x = F.conv1d(self.pad_fn(x), self.analysis_filter)
113
- return F.conv1d(x, self.updown_filter, stride=self.subbands)
114
-
115
- def synthesis(self, x):
116
- """Synthesis with PQMF.
117
- Args:
118
- x (Tensor): Input tensor (B, subbands, T // subbands).
119
- Returns:
120
- Tensor: Output tensor (B, 1, T).
121
- """
122
- # NOTE(kan-bayashi): Power will be dreased so here multiply by # subbands.
123
- # Not sure this is the correct way, it is better to check again.
124
- # TODO(kan-bayashi): Understand the reconstruction procedure
125
- x = F.conv_transpose1d(
126
- x, self.updown_filter * self.subbands, stride=self.subbands
127
- )
128
- return F.conv1d(self.pad_fn(x), self.synthesis_filter)
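And a short round-trip sketch of the PQMF filterbank above (analysis into subbands, then near-perfect synthesis); again, the private import path is an assumption about the installed package:

import torch

from so_vits_svc_fork.modules.decoders.mb_istft._pqmf import PQMF

pqmf = PQMF(torch.device("cpu"), subbands=4)
x = torch.randn(1, 1, 8192)    # (batch, 1, samples)
bands = pqmf.analysis(x)       # (batch, 4, samples // 4)
recon = pqmf.synthesis(bands)  # (batch, 1, samples)
print(bands.shape, recon.shape)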
so_vits_svc_fork/modules/decoders/mb_istft/_stft.py DELETED
@@ -1,244 +0,0 @@
1
- """
2
- BSD 3-Clause License
3
- Copyright (c) 2017, Prem Seetharaman
4
- All rights reserved.
5
- * Redistribution and use in source and binary forms, with or without
6
- modification, are permitted provided that the following conditions are met:
7
- * Redistributions of source code must retain the above copyright notice,
8
- this list of conditions and the following disclaimer.
9
- * Redistributions in binary form must reproduce the above copyright notice, this
10
- list of conditions and the following disclaimer in the
11
- documentation and/or other materials provided with the distribution.
12
- * Neither the name of the copyright holder nor the names of its
13
- contributors may be used to endorse or promote products derived from this
14
- software without specific prior written permission.
15
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
- ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
- WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
- ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
- ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- """
-
- import librosa.util as librosa_util
- import numpy as np
- import torch
- import torch.nn.functional as F
- from librosa.util import pad_center, tiny
- from scipy.signal import get_window
- from torch.autograd import Variable
-
-
- def window_sumsquare(
-     window,
-     n_frames,
-     hop_length=200,
-     win_length=800,
-     n_fft=800,
-     dtype=np.float32,
-     norm=None,
- ):
-     """
-     # from librosa 0.6
-     Compute the sum-square envelope of a window function at a given hop length.
-     This is used to estimate modulation effects induced by windowing
-     observations in short-time fourier transforms.
-     Parameters
-     ----------
-     window : string, tuple, number, callable, or list-like
-         Window specification, as in `get_window`
-     n_frames : int > 0
-         The number of analysis frames
-     hop_length : int > 0
-         The number of samples to advance between frames
-     win_length : [optional]
-         The length of the window function. By default, this matches `n_fft`.
-     n_fft : int > 0
-         The length of each analysis frame.
-     dtype : np.dtype
-         The data type of the output
-     Returns
-     -------
-     wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
-         The sum-squared envelope of the window function
-     """
-     if win_length is None:
-         win_length = n_fft
-
-     n = n_fft + hop_length * (n_frames - 1)
-     x = np.zeros(n, dtype=dtype)
-
-     # Compute the squared window at the desired length
-     win_sq = get_window(window, win_length, fftbins=True)
-     win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
-     win_sq = librosa_util.pad_center(win_sq, n_fft)
-
-     # Fill the envelope
-     for i in range(n_frames):
-         sample = i * hop_length
-         x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
-     return x
-
-
- class STFT(torch.nn.Module):
-     """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
-
-     def __init__(
-         self, filter_length=800, hop_length=200, win_length=800, window="hann"
-     ):
-         super().__init__()
-         self.filter_length = filter_length
-         self.hop_length = hop_length
-         self.win_length = win_length
-         self.window = window
-         self.forward_transform = None
-         scale = self.filter_length / self.hop_length
-         fourier_basis = np.fft.fft(np.eye(self.filter_length))
-
-         cutoff = int(self.filter_length / 2 + 1)
-         fourier_basis = np.vstack(
-             [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
-         )
-
-         forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
-         inverse_basis = torch.FloatTensor(
-             np.linalg.pinv(scale * fourier_basis).T[:, None, :]
-         )
-
-         if window is not None:
-             assert filter_length >= win_length
-             # get window and zero center pad it to filter_length
-             fft_window = get_window(window, win_length, fftbins=True)
-             fft_window = pad_center(fft_window, filter_length)
-             fft_window = torch.from_numpy(fft_window).float()
-
-             # window the bases
-             forward_basis *= fft_window
-             inverse_basis *= fft_window
-
-         self.register_buffer("forward_basis", forward_basis.float())
-         self.register_buffer("inverse_basis", inverse_basis.float())
-
-     def transform(self, input_data):
-         num_batches = input_data.size(0)
-         num_samples = input_data.size(1)
-
-         self.num_samples = num_samples
-
-         # similar to librosa, reflect-pad the input
-         input_data = input_data.view(num_batches, 1, num_samples)
-         input_data = F.pad(
-             input_data.unsqueeze(1),
-             (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
-             mode="reflect",
-         )
-         input_data = input_data.squeeze(1)
-
-         forward_transform = F.conv1d(
-             input_data,
-             Variable(self.forward_basis, requires_grad=False),
-             stride=self.hop_length,
-             padding=0,
-         )
-
-         cutoff = int((self.filter_length / 2) + 1)
-         real_part = forward_transform[:, :cutoff, :]
-         imag_part = forward_transform[:, cutoff:, :]
-
-         magnitude = torch.sqrt(real_part**2 + imag_part**2)
-         phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
-
-         return magnitude, phase
-
-     def inverse(self, magnitude, phase):
-         recombine_magnitude_phase = torch.cat(
-             [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
-         )
-
-         inverse_transform = F.conv_transpose1d(
-             recombine_magnitude_phase,
-             Variable(self.inverse_basis, requires_grad=False),
-             stride=self.hop_length,
-             padding=0,
-         )
-
-         if self.window is not None:
-             window_sum = window_sumsquare(
-                 self.window,
-                 magnitude.size(-1),
-                 hop_length=self.hop_length,
-                 win_length=self.win_length,
-                 n_fft=self.filter_length,
-                 dtype=np.float32,
-             )
-             # remove modulation effects
-             approx_nonzero_indices = torch.from_numpy(
-                 np.where(window_sum > tiny(window_sum))[0]
-             )
-             window_sum = torch.autograd.Variable(
-                 torch.from_numpy(window_sum), requires_grad=False
-             )
-             window_sum = window_sum.to(inverse_transform.device())
-             inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
-                 approx_nonzero_indices
-             ]
-
-             # scale by hop ratio
-             inverse_transform *= float(self.filter_length) / self.hop_length
-
-         inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
-         inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
-
-         return inverse_transform
-
-     def forward(self, input_data):
-         self.magnitude, self.phase = self.transform(input_data)
-         reconstruction = self.inverse(self.magnitude, self.phase)
-         return reconstruction
-
-
- class TorchSTFT(torch.nn.Module):
-     def __init__(
-         self, filter_length=800, hop_length=200, win_length=800, window="hann"
-     ):
-         super().__init__()
-         self.filter_length = filter_length
-         self.hop_length = hop_length
-         self.win_length = win_length
-         self.window = torch.from_numpy(
-             get_window(window, win_length, fftbins=True).astype(np.float32)
-         )
-
-     def transform(self, input_data):
-         forward_transform = torch.stft(
-             input_data,
-             self.filter_length,
-             self.hop_length,
-             self.win_length,
-             window=self.window,
-             return_complex=True,
-         )
-
-         return torch.abs(forward_transform), torch.angle(forward_transform)
-
-     def inverse(self, magnitude, phase):
-         inverse_transform = torch.istft(
-             magnitude * torch.exp(phase * 1j),
-             self.filter_length,
-             self.hop_length,
-             self.win_length,
-             window=self.window.to(magnitude.device),
-         )
-
-         return inverse_transform.unsqueeze(
-             -2
-         )  # unsqueeze to stay consistent with conv_transpose1d implementation
-
-     def forward(self, input_data):
-         self.magnitude, self.phase = self.transform(input_data)
-         reconstruction = self.inverse(self.magnitude, self.phase)
-         return reconstruction
 
so_vits_svc_fork/modules/decoders/mb_istft/_stft_loss.py DELETED
@@ -1,142 +0,0 @@
- # Copyright 2019 Tomoki Hayashi
- # MIT License (https://opensource.org/licenses/MIT)
-
- """STFT-based Loss modules."""
-
- import torch
- import torch.nn.functional as F
-
-
- def stft(x, fft_size, hop_size, win_length, window):
-     """Perform STFT and convert to magnitude spectrogram.
-     Args:
-         x (Tensor): Input signal tensor (B, T).
-         fft_size (int): FFT size.
-         hop_size (int): Hop size.
-         win_length (int): Window length.
-         window (str): Window function type.
-     Returns:
-         Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
-     """
-     x_stft = torch.stft(
-         x, fft_size, hop_size, win_length, window.to(x.device), return_complex=False
-     )
-     real = x_stft[..., 0]
-     imag = x_stft[..., 1]
-
-     # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
-     return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
-
-
- class SpectralConvergengeLoss(torch.nn.Module):
-     """Spectral convergence loss module."""
-
-     def __init__(self):
-         """Initialize spectral convergence loss module."""
-         super().__init__()
-
-     def forward(self, x_mag, y_mag):
-         """Calculate forward propagation.
-         Args:
-             x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
-             y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
-         Returns:
-             Tensor: Spectral convergence loss value.
-         """
-         return torch.norm(y_mag - x_mag) / torch.norm(
-             y_mag
-         )  # MB-iSTFT-VITS changed here due to codespell
-
-
- class LogSTFTMagnitudeLoss(torch.nn.Module):
-     """Log STFT magnitude loss module."""
-
-     def __init__(self):
-         """Initialize los STFT magnitude loss module."""
-         super().__init__()
-
-     def forward(self, x_mag, y_mag):
-         """Calculate forward propagation.
-         Args:
-             x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
-             y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
-         Returns:
-             Tensor: Log STFT magnitude loss value.
-         """
-         return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
-
-
- class STFTLoss(torch.nn.Module):
-     """STFT loss module."""
-
-     def __init__(
-         self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"
-     ):
-         """Initialize STFT loss module."""
-         super().__init__()
-         self.fft_size = fft_size
-         self.shift_size = shift_size
-         self.win_length = win_length
-         self.window = getattr(torch, window)(win_length)
-         self.spectral_convergenge_loss = SpectralConvergengeLoss()
-         self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
-
-     def forward(self, x, y):
-         """Calculate forward propagation.
-         Args:
-             x (Tensor): Predicted signal (B, T).
-             y (Tensor): Groundtruth signal (B, T).
-         Returns:
-             Tensor: Spectral convergence loss value.
-             Tensor: Log STFT magnitude loss value.
-         """
-         x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
-         y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
-         sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
-         mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
-
-         return sc_loss, mag_loss
-
-
- class MultiResolutionSTFTLoss(torch.nn.Module):
-     """Multi resolution STFT loss module."""
-
-     def __init__(
-         self,
-         fft_sizes=[1024, 2048, 512],
-         hop_sizes=[120, 240, 50],
-         win_lengths=[600, 1200, 240],
-         window="hann_window",
-     ):
-         """Initialize Multi resolution STFT loss module.
-         Args:
-             fft_sizes (list): List of FFT sizes.
-             hop_sizes (list): List of hop sizes.
-             win_lengths (list): List of window lengths.
-             window (str): Window function type.
-         """
-         super().__init__()
-         assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
-         self.stft_losses = torch.nn.ModuleList()
-         for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
-             self.stft_losses += [STFTLoss(fs, ss, wl, window)]
-
-     def forward(self, x, y):
-         """Calculate forward propagation.
-         Args:
-             x (Tensor): Predicted signal (B, T).
-             y (Tensor): Groundtruth signal (B, T).
-         Returns:
-             Tensor: Multi resolution spectral convergence loss value.
-             Tensor: Multi resolution log STFT magnitude loss value.
-         """
-         sc_loss = 0.0
-         mag_loss = 0.0
-         for f in self.stft_losses:
-             sc_l, mag_l = f(x, y)
-             sc_loss += sc_l
-             mag_loss += mag_l
-         sc_loss /= len(self.stft_losses)
-         mag_loss /= len(self.stft_losses)
-
-         return sc_loss, mag_loss
 
so_vits_svc_fork/modules/descriminators.py DELETED
@@ -1,177 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn import AvgPool1d, Conv1d, Conv2d
4
- from torch.nn import functional as F
5
- from torch.nn.utils import spectral_norm, weight_norm
6
-
7
- from so_vits_svc_fork.modules import modules as modules
8
- from so_vits_svc_fork.modules.commons import get_padding
9
-
10
-
11
- class DiscriminatorP(torch.nn.Module):
12
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
13
- super().__init__()
14
- self.period = period
15
- self.use_spectral_norm = use_spectral_norm
16
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
17
- self.convs = nn.ModuleList(
18
- [
19
- norm_f(
20
- Conv2d(
21
- 1,
22
- 32,
23
- (kernel_size, 1),
24
- (stride, 1),
25
- padding=(get_padding(kernel_size, 1), 0),
26
- )
27
- ),
28
- norm_f(
29
- Conv2d(
30
- 32,
31
- 128,
32
- (kernel_size, 1),
33
- (stride, 1),
34
- padding=(get_padding(kernel_size, 1), 0),
35
- )
36
- ),
37
- norm_f(
38
- Conv2d(
39
- 128,
40
- 512,
41
- (kernel_size, 1),
42
- (stride, 1),
43
- padding=(get_padding(kernel_size, 1), 0),
44
- )
45
- ),
46
- norm_f(
47
- Conv2d(
48
- 512,
49
- 1024,
50
- (kernel_size, 1),
51
- (stride, 1),
52
- padding=(get_padding(kernel_size, 1), 0),
53
- )
54
- ),
55
- norm_f(
56
- Conv2d(
57
- 1024,
58
- 1024,
59
- (kernel_size, 1),
60
- 1,
61
- padding=(get_padding(kernel_size, 1), 0),
62
- )
63
- ),
64
- ]
65
- )
66
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
67
-
68
- def forward(self, x):
69
- fmap = []
70
-
71
- # 1d to 2d
72
- b, c, t = x.shape
73
- if t % self.period != 0: # pad first
74
- n_pad = self.period - (t % self.period)
75
- x = F.pad(x, (0, n_pad), "reflect")
76
- t = t + n_pad
77
- x = x.view(b, c, t // self.period, self.period)
78
-
79
- for l in self.convs:
80
- x = l(x)
81
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
82
- fmap.append(x)
83
- x = self.conv_post(x)
84
- fmap.append(x)
85
- x = torch.flatten(x, 1, -1)
86
-
87
- return x, fmap
88
-
89
-
90
- class DiscriminatorS(torch.nn.Module):
91
- def __init__(self, use_spectral_norm=False):
92
- super().__init__()
93
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
94
- self.convs = nn.ModuleList(
95
- [
96
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
97
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
98
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
99
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
100
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
101
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
102
- ]
103
- )
104
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
105
-
106
- def forward(self, x):
107
- fmap = []
108
-
109
- for l in self.convs:
110
- x = l(x)
111
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
112
- fmap.append(x)
113
- x = self.conv_post(x)
114
- fmap.append(x)
115
- x = torch.flatten(x, 1, -1)
116
-
117
- return x, fmap
118
-
119
-
120
- class MultiPeriodDiscriminator(torch.nn.Module):
121
- def __init__(self, use_spectral_norm=False):
122
- super().__init__()
123
- periods = [2, 3, 5, 7, 11]
124
-
125
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
126
- discs = discs + [
127
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
128
- ]
129
- self.discriminators = nn.ModuleList(discs)
130
-
131
- def forward(self, y, y_hat):
132
- y_d_rs = []
133
- y_d_gs = []
134
- fmap_rs = []
135
- fmap_gs = []
136
- for i, d in enumerate(self.discriminators):
137
- y_d_r, fmap_r = d(y)
138
- y_d_g, fmap_g = d(y_hat)
139
- y_d_rs.append(y_d_r)
140
- y_d_gs.append(y_d_g)
141
- fmap_rs.append(fmap_r)
142
- fmap_gs.append(fmap_g)
143
-
144
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
145
-
146
-
147
- class MultiScaleDiscriminator(torch.nn.Module):
148
- def __init__(self):
149
- super().__init__()
150
- self.discriminators = nn.ModuleList(
151
- [
152
- DiscriminatorS(use_spectral_norm=True),
153
- DiscriminatorS(),
154
- DiscriminatorS(),
155
- ]
156
- )
157
- self.meanpools = nn.ModuleList(
158
- [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
159
- )
160
-
161
- def forward(self, y, y_hat):
162
- y_d_rs = []
163
- y_d_gs = []
164
- fmap_rs = []
165
- fmap_gs = []
166
- for i, d in enumerate(self.discriminators):
167
- if i != 0:
168
- y = self.meanpools[i - 1](y)
169
- y_hat = self.meanpools[i - 1](y_hat)
170
- y_d_r, fmap_r = d(y)
171
- y_d_g, fmap_g = d(y_hat)
172
- y_d_rs.append(y_d_r)
173
- fmap_rs.append(fmap_r)
174
- y_d_gs.append(y_d_g)
175
- fmap_gs.append(fmap_g)
176
-
177
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
 
so_vits_svc_fork/modules/encoders.py DELETED
@@ -1,136 +0,0 @@
- import torch
- from torch import nn
-
- from so_vits_svc_fork.modules import attentions as attentions
- from so_vits_svc_fork.modules import commons as commons
- from so_vits_svc_fork.modules import modules as modules
-
-
- class SpeakerEncoder(torch.nn.Module):
-     def __init__(
-         self,
-         mel_n_channels=80,
-         model_num_layers=3,
-         model_hidden_size=256,
-         model_embedding_size=256,
-     ):
-         super().__init__()
-         self.lstm = nn.LSTM(
-             mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
-         )
-         self.linear = nn.Linear(model_hidden_size, model_embedding_size)
-         self.relu = nn.ReLU()
-
-     def forward(self, mels):
-         self.lstm.flatten_parameters()
-         _, (hidden, _) = self.lstm(mels)
-         embeds_raw = self.relu(self.linear(hidden[-1]))
-         return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
-
-     def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
-         mel_slices = []
-         for i in range(0, total_frames - partial_frames, partial_hop):
-             mel_range = torch.arange(i, i + partial_frames)
-             mel_slices.append(mel_range)
-
-         return mel_slices
-
-     def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
-         mel_len = mel.size(1)
-         last_mel = mel[:, -partial_frames:]
-
-         if mel_len > partial_frames:
-             mel_slices = self.compute_partial_slices(
-                 mel_len, partial_frames, partial_hop
-             )
-             mels = list(mel[:, s] for s in mel_slices)
-             mels.append(last_mel)
-             mels = torch.stack(tuple(mels), 0).squeeze(1)
-
-             with torch.no_grad():
-                 partial_embeds = self(mels)
-             embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
-             # embed = embed / torch.linalg.norm(embed, 2)
-         else:
-             with torch.no_grad():
-                 embed = self(last_mel)
-
-         return embed
-
-
- class Encoder(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         out_channels,
-         hidden_channels,
-         kernel_size,
-         dilation_rate,
-         n_layers,
-         gin_channels=0,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-         self.hidden_channels = hidden_channels
-         self.kernel_size = kernel_size
-         self.dilation_rate = dilation_rate
-         self.n_layers = n_layers
-         self.gin_channels = gin_channels
-
-         self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
-         self.enc = modules.WN(
-             hidden_channels,
-             kernel_size,
-             dilation_rate,
-             n_layers,
-             gin_channels=gin_channels,
-         )
-         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-     def forward(self, x, x_lengths, g=None):
-         # print(x.shape,x_lengths.shape)
-         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
-             x.dtype
-         )
-         x = self.pre(x) * x_mask
-         x = self.enc(x, x_mask, g=g)
-         stats = self.proj(x) * x_mask
-         m, logs = torch.split(stats, self.out_channels, dim=1)
-         z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
-         return z, m, logs, x_mask
-
-
- class TextEncoder(nn.Module):
-     def __init__(
-         self,
-         out_channels,
-         hidden_channels,
-         kernel_size,
-         n_layers,
-         gin_channels=0,
-         filter_channels=None,
-         n_heads=None,
-         p_dropout=None,
-     ):
-         super().__init__()
-         self.out_channels = out_channels
-         self.hidden_channels = hidden_channels
-         self.kernel_size = kernel_size
-         self.n_layers = n_layers
-         self.gin_channels = gin_channels
-         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-         self.f0_emb = nn.Embedding(256, hidden_channels)
-
-         self.enc_ = attentions.Encoder(
-             hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
-         )
-
-     def forward(self, x, x_mask, f0=None, noice_scale=1):
-         x = x + self.f0_emb(f0).transpose(1, 2)
-         x = self.enc_(x * x_mask, x_mask)
-         stats = self.proj(x) * x_mask
-         m, logs = torch.split(stats, self.out_channels, dim=1)
-         z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask
-
-         return z, m, logs, x_mask
 
so_vits_svc_fork/modules/flows.py DELETED
@@ -1,48 +0,0 @@
- from torch import nn
-
- from so_vits_svc_fork.modules import modules as modules
-
-
- class ResidualCouplingBlock(nn.Module):
-     def __init__(
-         self,
-         channels,
-         hidden_channels,
-         kernel_size,
-         dilation_rate,
-         n_layers,
-         n_flows=4,
-         gin_channels=0,
-     ):
-         super().__init__()
-         self.channels = channels
-         self.hidden_channels = hidden_channels
-         self.kernel_size = kernel_size
-         self.dilation_rate = dilation_rate
-         self.n_layers = n_layers
-         self.n_flows = n_flows
-         self.gin_channels = gin_channels
-
-         self.flows = nn.ModuleList()
-         for i in range(n_flows):
-             self.flows.append(
-                 modules.ResidualCouplingLayer(
-                     channels,
-                     hidden_channels,
-                     kernel_size,
-                     dilation_rate,
-                     n_layers,
-                     gin_channels=gin_channels,
-                     mean_only=True,
-                 )
-             )
-             self.flows.append(modules.Flip())
-
-     def forward(self, x, x_mask, g=None, reverse=False):
-         if not reverse:
-             for flow in self.flows:
-                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
-         else:
-             for flow in reversed(self.flows):
-                 x = flow(x, x_mask, g=g, reverse=reverse)
-         return x
 
so_vits_svc_fork/modules/losses.py DELETED
@@ -1,58 +0,0 @@
- import torch
-
-
- def feature_loss(fmap_r, fmap_g):
-     loss = 0
-     for dr, dg in zip(fmap_r, fmap_g):
-         for rl, gl in zip(dr, dg):
-             rl = rl.float().detach()
-             gl = gl.float()
-             loss += torch.mean(torch.abs(rl - gl))
-
-     return loss * 2
-
-
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-     loss = 0
-     r_losses = []
-     g_losses = []
-     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-         dr = dr.float()
-         dg = dg.float()
-         r_loss = torch.mean((1 - dr) ** 2)
-         g_loss = torch.mean(dg**2)
-         loss += r_loss + g_loss
-         r_losses.append(r_loss.item())
-         g_losses.append(g_loss.item())
-
-     return loss, r_losses, g_losses
-
-
- def generator_loss(disc_outputs):
-     loss = 0
-     gen_losses = []
-     for dg in disc_outputs:
-         dg = dg.float()
-         l = torch.mean((1 - dg) ** 2)
-         gen_losses.append(l)
-         loss += l
-
-     return loss, gen_losses
-
-
- def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
-     """
-     z_p, logs_q: [b, h, t_t]
-     m_p, logs_p: [b, h, t_t]
-     """
-     z_p = z_p.float()
-     logs_q = logs_q.float()
-     m_p = m_p.float()
-     logs_p = logs_p.float()
-     z_mask = z_mask.float()
-     # print(logs_p)
-     kl = logs_p - logs_q - 0.5
-     kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
-     kl = torch.sum(kl * z_mask)
-     l = kl / torch.sum(z_mask)
-     return l
 
so_vits_svc_fork/modules/mel_processing.py DELETED
@@ -1,205 +0,0 @@
1
- """from logging import getLogger
2
-
3
- import torch
4
- import torch.utils.data
5
- import torchaudio
6
-
7
- LOG = getLogger(__name__)
8
-
9
-
10
- from ..hparams import HParams
11
-
12
-
13
- def spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
14
- return torchaudio.transforms.Spectrogram(
15
- n_fft=hps.data.filter_length,
16
- win_length=hps.data.win_length,
17
- hop_length=hps.data.hop_length,
18
- power=1.0,
19
- window_fn=torch.hann_window,
20
- normalized=False,
21
- ).to(audio.device)(audio)
22
-
23
-
24
- def spec_to_mel_torch(spec: torch.Tensor, hps: HParams) -> torch.Tensor:
25
- return torchaudio.transforms.MelScale(
26
- n_mels=hps.data.n_mel_channels,
27
- sample_rate=hps.data.sampling_rate,
28
- f_min=hps.data.mel_fmin,
29
- f_max=hps.data.mel_fmax,
30
- ).to(spec.device)(spec)
31
-
32
-
33
- def mel_spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
34
- return torchaudio.transforms.MelSpectrogram(
35
- sample_rate=hps.data.sampling_rate,
36
- n_fft=hps.data.filter_length,
37
- n_mels=hps.data.n_mel_channels,
38
- win_length=hps.data.win_length,
39
- hop_length=hps.data.hop_length,
40
- f_min=hps.data.mel_fmin,
41
- f_max=hps.data.mel_fmax,
42
- power=1.0,
43
- window_fn=torch.hann_window,
44
- normalized=False,
45
- ).to(audio.device)(audio)"""
46
-
47
- from logging import getLogger
48
-
49
- import torch
50
- import torch.utils.data
51
- from librosa.filters import mel as librosa_mel_fn
52
-
53
- LOG = getLogger(__name__)
54
-
55
- MAX_WAV_VALUE = 32768.0
56
-
57
-
58
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
59
- """
60
- PARAMS
61
- ------
62
- C: compression factor
63
- """
64
- return torch.log(torch.clamp(x, min=clip_val) * C)
65
-
66
-
67
- def dynamic_range_decompression_torch(x, C=1):
68
- """
69
- PARAMS
70
- ------
71
- C: compression factor used to compress
72
- """
73
- return torch.exp(x) / C
74
-
75
-
76
- def spectral_normalize_torch(magnitudes):
77
- output = dynamic_range_compression_torch(magnitudes)
78
- return output
79
-
80
-
81
- def spectral_de_normalize_torch(magnitudes):
82
- output = dynamic_range_decompression_torch(magnitudes)
83
- return output
84
-
85
-
86
- mel_basis = {}
87
- hann_window = {}
88
-
89
-
90
- def spectrogram_torch(y, hps, center=False):
91
- if torch.min(y) < -1.0:
92
- LOG.info("min value is ", torch.min(y))
93
- if torch.max(y) > 1.0:
94
- LOG.info("max value is ", torch.max(y))
95
- n_fft = hps.data.filter_length
96
- hop_size = hps.data.hop_length
97
- win_size = hps.data.win_length
98
- global hann_window
99
- dtype_device = str(y.dtype) + "_" + str(y.device)
100
- wnsize_dtype_device = str(win_size) + "_" + dtype_device
101
- if wnsize_dtype_device not in hann_window:
102
- hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
103
- dtype=y.dtype, device=y.device
104
- )
105
-
106
- y = torch.nn.functional.pad(
107
- y.unsqueeze(1),
108
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
109
- mode="reflect",
110
- )
111
- y = y.squeeze(1)
112
-
113
- spec = torch.stft(
114
- y,
115
- n_fft,
116
- hop_length=hop_size,
117
- win_length=win_size,
118
- window=hann_window[wnsize_dtype_device],
119
- center=center,
120
- pad_mode="reflect",
121
- normalized=False,
122
- onesided=True,
123
- return_complex=False,
124
- )
125
-
126
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
127
- return spec
128
-
129
-
130
- def spec_to_mel_torch(spec, hps):
131
- sampling_rate = hps.data.sampling_rate
132
- n_fft = hps.data.filter_length
133
- num_mels = hps.data.n_mel_channels
134
- fmin = hps.data.mel_fmin
135
- fmax = hps.data.mel_fmax
136
- global mel_basis
137
- dtype_device = str(spec.dtype) + "_" + str(spec.device)
138
- fmax_dtype_device = str(fmax) + "_" + dtype_device
139
- if fmax_dtype_device not in mel_basis:
140
- mel = librosa_mel_fn(
141
- sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
142
- )
143
- mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
144
- dtype=spec.dtype, device=spec.device
145
- )
146
- spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
147
- spec = spectral_normalize_torch(spec)
148
- return spec
149
-
150
-
151
- def mel_spectrogram_torch(y, hps, center=False):
152
- sampling_rate = hps.data.sampling_rate
153
- n_fft = hps.data.filter_length
154
- num_mels = hps.data.n_mel_channels
155
- fmin = hps.data.mel_fmin
156
- fmax = hps.data.mel_fmax
157
- hop_size = hps.data.hop_length
158
- win_size = hps.data.win_length
159
- if torch.min(y) < -1.0:
160
- LOG.info(f"min value is {torch.min(y)}")
161
- if torch.max(y) > 1.0:
162
- LOG.info(f"max value is {torch.max(y)}")
163
-
164
- global mel_basis, hann_window
165
- dtype_device = str(y.dtype) + "_" + str(y.device)
166
- fmax_dtype_device = str(fmax) + "_" + dtype_device
167
- wnsize_dtype_device = str(win_size) + "_" + dtype_device
168
- if fmax_dtype_device not in mel_basis:
169
- mel = librosa_mel_fn(
170
- sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
171
- )
172
- mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
173
- dtype=y.dtype, device=y.device
174
- )
175
- if wnsize_dtype_device not in hann_window:
176
- hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
177
- dtype=y.dtype, device=y.device
178
- )
179
-
180
- y = torch.nn.functional.pad(
181
- y.unsqueeze(1),
182
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
183
- mode="reflect",
184
- )
185
- y = y.squeeze(1)
186
-
187
- spec = torch.stft(
188
- y,
189
- n_fft,
190
- hop_length=hop_size,
191
- win_length=win_size,
192
- window=hann_window[wnsize_dtype_device],
193
- center=center,
194
- pad_mode="reflect",
195
- normalized=False,
196
- onesided=True,
197
- return_complex=False,
198
- )
199
-
200
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
201
-
202
- spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
203
- spec = spectral_normalize_torch(spec)
204
-
205
- return spec
 
so_vits_svc_fork/modules/modules.py DELETED
@@ -1,452 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn import Conv1d
4
- from torch.nn import functional as F
5
- from torch.nn.utils import remove_weight_norm, weight_norm
6
-
7
- from so_vits_svc_fork.modules import commons
8
- from so_vits_svc_fork.modules.commons import get_padding, init_weights
9
-
10
- LRELU_SLOPE = 0.1
11
-
12
-
13
- class LayerNorm(nn.Module):
14
- def __init__(self, channels, eps=1e-5):
15
- super().__init__()
16
- self.channels = channels
17
- self.eps = eps
18
-
19
- self.gamma = nn.Parameter(torch.ones(channels))
20
- self.beta = nn.Parameter(torch.zeros(channels))
21
-
22
- def forward(self, x):
23
- x = x.transpose(1, -1)
24
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
25
- return x.transpose(1, -1)
26
-
27
-
28
- class ConvReluNorm(nn.Module):
29
- def __init__(
30
- self,
31
- in_channels,
32
- hidden_channels,
33
- out_channels,
34
- kernel_size,
35
- n_layers,
36
- p_dropout,
37
- ):
38
- super().__init__()
39
- self.in_channels = in_channels
40
- self.hidden_channels = hidden_channels
41
- self.out_channels = out_channels
42
- self.kernel_size = kernel_size
43
- self.n_layers = n_layers
44
- self.p_dropout = p_dropout
45
- assert n_layers > 1, "Number of layers should be larger than 0."
46
-
47
- self.conv_layers = nn.ModuleList()
48
- self.norm_layers = nn.ModuleList()
49
- self.conv_layers.append(
50
- nn.Conv1d(
51
- in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
52
- )
53
- )
54
- self.norm_layers.append(LayerNorm(hidden_channels))
55
- self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
56
- for _ in range(n_layers - 1):
57
- self.conv_layers.append(
58
- nn.Conv1d(
59
- hidden_channels,
60
- hidden_channels,
61
- kernel_size,
62
- padding=kernel_size // 2,
63
- )
64
- )
65
- self.norm_layers.append(LayerNorm(hidden_channels))
66
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
67
- self.proj.weight.data.zero_()
68
- self.proj.bias.data.zero_()
69
-
70
- def forward(self, x, x_mask):
71
- x_org = x
72
- for i in range(self.n_layers):
73
- x = self.conv_layers[i](x * x_mask)
74
- x = self.norm_layers[i](x)
75
- x = self.relu_drop(x)
76
- x = x_org + self.proj(x)
77
- return x * x_mask
78
-
79
-
80
- class DDSConv(nn.Module):
81
- """
82
- Dialted and Depth-Separable Convolution
83
- """
84
-
85
- def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
86
- super().__init__()
87
- self.channels = channels
88
- self.kernel_size = kernel_size
89
- self.n_layers = n_layers
90
- self.p_dropout = p_dropout
91
-
92
- self.drop = nn.Dropout(p_dropout)
93
- self.convs_sep = nn.ModuleList()
94
- self.convs_1x1 = nn.ModuleList()
95
- self.norms_1 = nn.ModuleList()
96
- self.norms_2 = nn.ModuleList()
97
- for i in range(n_layers):
98
- dilation = kernel_size**i
99
- padding = (kernel_size * dilation - dilation) // 2
100
- self.convs_sep.append(
101
- nn.Conv1d(
102
- channels,
103
- channels,
104
- kernel_size,
105
- groups=channels,
106
- dilation=dilation,
107
- padding=padding,
108
- )
109
- )
110
- self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
111
- self.norms_1.append(LayerNorm(channels))
112
- self.norms_2.append(LayerNorm(channels))
113
-
114
- def forward(self, x, x_mask, g=None):
115
- if g is not None:
116
- x = x + g
117
- for i in range(self.n_layers):
118
- y = self.convs_sep[i](x * x_mask)
119
- y = self.norms_1[i](y)
120
- y = F.gelu(y)
121
- y = self.convs_1x1[i](y)
122
- y = self.norms_2[i](y)
123
- y = F.gelu(y)
124
- y = self.drop(y)
125
- x = x + y
126
- return x * x_mask
127
-
128
-
129
- class WN(torch.nn.Module):
130
- def __init__(
131
- self,
132
- hidden_channels,
133
- kernel_size,
134
- dilation_rate,
135
- n_layers,
136
- gin_channels=0,
137
- p_dropout=0,
138
- ):
139
- super().__init__()
140
- assert kernel_size % 2 == 1
141
- self.hidden_channels = hidden_channels
142
- self.kernel_size = (kernel_size,)
143
- self.dilation_rate = dilation_rate
144
- self.n_layers = n_layers
145
- self.gin_channels = gin_channels
146
- self.p_dropout = p_dropout
147
-
148
- self.in_layers = torch.nn.ModuleList()
149
- self.res_skip_layers = torch.nn.ModuleList()
150
- self.drop = nn.Dropout(p_dropout)
151
-
152
- if gin_channels != 0:
153
- cond_layer = torch.nn.Conv1d(
154
- gin_channels, 2 * hidden_channels * n_layers, 1
155
- )
156
- self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
157
-
158
- for i in range(n_layers):
159
- dilation = dilation_rate**i
160
- padding = int((kernel_size * dilation - dilation) / 2)
161
- in_layer = torch.nn.Conv1d(
162
- hidden_channels,
163
- 2 * hidden_channels,
164
- kernel_size,
165
- dilation=dilation,
166
- padding=padding,
167
- )
168
- in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
169
- self.in_layers.append(in_layer)
170
-
171
- # last one is not necessary
172
- if i < n_layers - 1:
173
- res_skip_channels = 2 * hidden_channels
174
- else:
175
- res_skip_channels = hidden_channels
176
-
177
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
178
- res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
179
- self.res_skip_layers.append(res_skip_layer)
180
-
181
- def forward(self, x, x_mask, g=None, **kwargs):
182
- output = torch.zeros_like(x)
183
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
184
-
185
- if g is not None:
186
- g = self.cond_layer(g)
187
-
188
- for i in range(self.n_layers):
189
- x_in = self.in_layers[i](x)
190
- if g is not None:
191
- cond_offset = i * 2 * self.hidden_channels
192
- g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
193
- else:
194
- g_l = torch.zeros_like(x_in)
195
-
196
- acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
197
- acts = self.drop(acts)
198
-
199
- res_skip_acts = self.res_skip_layers[i](acts)
200
- if i < self.n_layers - 1:
201
- res_acts = res_skip_acts[:, : self.hidden_channels, :]
202
- x = (x + res_acts) * x_mask
203
- output = output + res_skip_acts[:, self.hidden_channels :, :]
204
- else:
205
- output = output + res_skip_acts
206
- return output * x_mask
207
-
208
- def remove_weight_norm(self):
209
- if self.gin_channels != 0:
210
- torch.nn.utils.remove_weight_norm(self.cond_layer)
211
- for l in self.in_layers:
212
- torch.nn.utils.remove_weight_norm(l)
213
- for l in self.res_skip_layers:
214
- torch.nn.utils.remove_weight_norm(l)
215
-
216
-
217
- class ResBlock1(torch.nn.Module):
218
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
219
- super().__init__()
220
- self.convs1 = nn.ModuleList(
221
- [
222
- weight_norm(
223
- Conv1d(
224
- channels,
225
- channels,
226
- kernel_size,
227
- 1,
228
- dilation=dilation[0],
229
- padding=get_padding(kernel_size, dilation[0]),
230
- )
231
- ),
232
- weight_norm(
233
- Conv1d(
234
- channels,
235
- channels,
236
- kernel_size,
237
- 1,
238
- dilation=dilation[1],
239
- padding=get_padding(kernel_size, dilation[1]),
240
- )
241
- ),
242
- weight_norm(
243
- Conv1d(
244
- channels,
245
- channels,
246
- kernel_size,
247
- 1,
248
- dilation=dilation[2],
249
- padding=get_padding(kernel_size, dilation[2]),
250
- )
251
- ),
252
- ]
253
- )
254
- self.convs1.apply(init_weights)
255
-
256
- self.convs2 = nn.ModuleList(
257
- [
258
- weight_norm(
259
- Conv1d(
260
- channels,
261
- channels,
262
- kernel_size,
263
- 1,
264
- dilation=1,
265
- padding=get_padding(kernel_size, 1),
266
- )
267
- ),
268
- weight_norm(
269
- Conv1d(
270
- channels,
271
- channels,
272
- kernel_size,
273
- 1,
274
- dilation=1,
275
- padding=get_padding(kernel_size, 1),
276
- )
277
- ),
278
- weight_norm(
279
- Conv1d(
280
- channels,
281
- channels,
282
- kernel_size,
283
- 1,
284
- dilation=1,
285
- padding=get_padding(kernel_size, 1),
286
- )
287
- ),
288
- ]
289
- )
290
- self.convs2.apply(init_weights)
291
-
292
- def forward(self, x, x_mask=None):
293
- for c1, c2 in zip(self.convs1, self.convs2):
294
- xt = F.leaky_relu(x, LRELU_SLOPE)
295
- if x_mask is not None:
296
- xt = xt * x_mask
297
- xt = c1(xt)
298
- xt = F.leaky_relu(xt, LRELU_SLOPE)
299
- if x_mask is not None:
300
- xt = xt * x_mask
301
- xt = c2(xt)
302
- x = xt + x
303
- if x_mask is not None:
304
- x = x * x_mask
305
- return x
306
-
307
- def remove_weight_norm(self):
308
- for l in self.convs1:
309
- remove_weight_norm(l)
310
- for l in self.convs2:
311
- remove_weight_norm(l)
312
-
313
-
314
- class ResBlock2(torch.nn.Module):
315
- def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
316
- super().__init__()
317
- self.convs = nn.ModuleList(
318
- [
319
- weight_norm(
320
- Conv1d(
321
- channels,
322
- channels,
323
- kernel_size,
324
- 1,
325
- dilation=dilation[0],
326
- padding=get_padding(kernel_size, dilation[0]),
327
- )
328
- ),
329
- weight_norm(
330
- Conv1d(
331
- channels,
332
- channels,
333
- kernel_size,
334
- 1,
335
- dilation=dilation[1],
336
- padding=get_padding(kernel_size, dilation[1]),
337
- )
338
- ),
339
- ]
340
- )
341
- self.convs.apply(init_weights)
342
-
343
- def forward(self, x, x_mask=None):
344
- for c in self.convs:
345
- xt = F.leaky_relu(x, LRELU_SLOPE)
346
- if x_mask is not None:
347
- xt = xt * x_mask
348
- xt = c(xt)
349
- x = xt + x
350
- if x_mask is not None:
351
- x = x * x_mask
352
- return x
353
-
354
- def remove_weight_norm(self):
355
- for l in self.convs:
356
- remove_weight_norm(l)
357
-
358
-
359
- class Log(nn.Module):
360
- def forward(self, x, x_mask, reverse=False, **kwargs):
361
- if not reverse:
362
- y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
363
- logdet = torch.sum(-y, [1, 2])
364
- return y, logdet
365
- else:
366
- x = torch.exp(x) * x_mask
367
- return x
368
-
369
-
370
- class Flip(nn.Module):
371
- def forward(self, x, *args, reverse=False, **kwargs):
372
- x = torch.flip(x, [1])
373
- if not reverse:
374
- logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
375
- return x, logdet
376
- else:
377
- return x
378
-
379
-
380
- class ElementwiseAffine(nn.Module):
381
- def __init__(self, channels):
382
- super().__init__()
383
- self.channels = channels
384
- self.m = nn.Parameter(torch.zeros(channels, 1))
385
- self.logs = nn.Parameter(torch.zeros(channels, 1))
386
-
387
- def forward(self, x, x_mask, reverse=False, **kwargs):
388
- if not reverse:
389
- y = self.m + torch.exp(self.logs) * x
390
- y = y * x_mask
391
- logdet = torch.sum(self.logs * x_mask, [1, 2])
392
- return y, logdet
393
- else:
394
- x = (x - self.m) * torch.exp(-self.logs) * x_mask
395
- return x
396
-
397
-
398
- class ResidualCouplingLayer(nn.Module):
399
- def __init__(
400
- self,
401
- channels,
402
- hidden_channels,
403
- kernel_size,
404
- dilation_rate,
405
- n_layers,
406
- p_dropout=0,
407
- gin_channels=0,
408
- mean_only=False,
409
- ):
410
- assert channels % 2 == 0, "channels should be divisible by 2"
411
- super().__init__()
412
- self.channels = channels
413
- self.hidden_channels = hidden_channels
414
- self.kernel_size = kernel_size
415
- self.dilation_rate = dilation_rate
416
- self.n_layers = n_layers
417
- self.half_channels = channels // 2
418
- self.mean_only = mean_only
419
-
420
- self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
421
- self.enc = WN(
422
- hidden_channels,
423
- kernel_size,
424
- dilation_rate,
425
- n_layers,
426
- p_dropout=p_dropout,
427
- gin_channels=gin_channels,
428
- )
429
- self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
430
- self.post.weight.data.zero_()
431
- self.post.bias.data.zero_()
432
-
433
- def forward(self, x, x_mask, g=None, reverse=False):
434
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
435
- h = self.pre(x0) * x_mask
436
- h = self.enc(h, x_mask, g=g)
437
- stats = self.post(h) * x_mask
438
- if not self.mean_only:
439
- m, logs = torch.split(stats, [self.half_channels] * 2, 1)
440
- else:
441
- m = stats
442
- logs = torch.zeros_like(m)
443
-
444
- if not reverse:
445
- x1 = m + x1 * torch.exp(logs) * x_mask
446
- x = torch.cat([x0, x1], 1)
447
- logdet = torch.sum(logs, [1, 2])
448
- return x, logdet
449
- else:
450
- x1 = (x1 - m) * torch.exp(-logs) * x_mask
451
- x = torch.cat([x0, x1], 1)
452
- return x
 
so_vits_svc_fork/modules/synthesizers.py DELETED
@@ -1,233 +0,0 @@
1
- import warnings
2
- from logging import getLogger
3
- from typing import Any, Literal, Sequence
4
-
5
- import torch
6
- from torch import nn
7
-
8
- import so_vits_svc_fork.f0
9
- from so_vits_svc_fork.f0 import f0_to_coarse
10
- from so_vits_svc_fork.modules import commons as commons
11
- from so_vits_svc_fork.modules.decoders.f0 import F0Decoder
12
- from so_vits_svc_fork.modules.decoders.hifigan import NSFHifiGANGenerator
13
- from so_vits_svc_fork.modules.decoders.mb_istft import (
14
- Multiband_iSTFT_Generator,
15
- Multistream_iSTFT_Generator,
16
- iSTFT_Generator,
17
- )
18
- from so_vits_svc_fork.modules.encoders import Encoder, TextEncoder
19
- from so_vits_svc_fork.modules.flows import ResidualCouplingBlock
20
-
21
- LOG = getLogger(__name__)
22
-
23
-
24
- class SynthesizerTrn(nn.Module):
25
- """
26
- Synthesizer for Training
27
- """
28
-
29
- def __init__(
30
- self,
31
- spec_channels: int,
32
- segment_size: int,
33
- inter_channels: int,
34
- hidden_channels: int,
35
- filter_channels: int,
36
- n_heads: int,
37
- n_layers: int,
38
- kernel_size: int,
39
- p_dropout: int,
40
- resblock: str,
41
- resblock_kernel_sizes: Sequence[int],
42
- resblock_dilation_sizes: Sequence[Sequence[int]],
43
- upsample_rates: Sequence[int],
44
- upsample_initial_channel: int,
45
- upsample_kernel_sizes: Sequence[int],
46
- gin_channels: int,
47
- ssl_dim: int,
48
- n_speakers: int,
49
- sampling_rate: int = 44100,
50
- type_: Literal["hifi-gan", "istft", "ms-istft", "mb-istft"] = "hifi-gan",
51
- gen_istft_n_fft: int = 16,
52
- gen_istft_hop_size: int = 4,
53
- subbands: int = 4,
54
- **kwargs: Any,
55
- ):
56
- super().__init__()
57
- self.spec_channels = spec_channels
58
- self.inter_channels = inter_channels
59
- self.hidden_channels = hidden_channels
60
- self.filter_channels = filter_channels
61
- self.n_heads = n_heads
62
- self.n_layers = n_layers
63
- self.kernel_size = kernel_size
64
- self.p_dropout = p_dropout
65
- self.resblock = resblock
66
- self.resblock_kernel_sizes = resblock_kernel_sizes
67
- self.resblock_dilation_sizes = resblock_dilation_sizes
68
- self.upsample_rates = upsample_rates
69
- self.upsample_initial_channel = upsample_initial_channel
70
- self.upsample_kernel_sizes = upsample_kernel_sizes
71
- self.segment_size = segment_size
72
- self.gin_channels = gin_channels
73
- self.ssl_dim = ssl_dim
74
- self.n_speakers = n_speakers
75
- self.sampling_rate = sampling_rate
76
- self.type_ = type_
77
- self.gen_istft_n_fft = gen_istft_n_fft
78
- self.gen_istft_hop_size = gen_istft_hop_size
79
- self.subbands = subbands
80
- if kwargs:
81
- warnings.warn(f"Unused arguments: {kwargs}")
82
-
83
- self.emb_g = nn.Embedding(n_speakers, gin_channels)
84
-
85
- if ssl_dim is None:
86
- self.pre = nn.LazyConv1d(hidden_channels, kernel_size=5, padding=2)
87
- else:
88
- self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
89
-
90
- self.enc_p = TextEncoder(
91
- inter_channels,
92
- hidden_channels,
93
- filter_channels=filter_channels,
94
- n_heads=n_heads,
95
- n_layers=n_layers,
96
- kernel_size=kernel_size,
97
- p_dropout=p_dropout,
98
- )
99
-
100
- LOG.info(f"Decoder type: {type_}")
101
- if type_ == "hifi-gan":
102
- hps = {
103
- "sampling_rate": sampling_rate,
104
- "inter_channels": inter_channels,
105
- "resblock": resblock,
106
- "resblock_kernel_sizes": resblock_kernel_sizes,
107
- "resblock_dilation_sizes": resblock_dilation_sizes,
108
- "upsample_rates": upsample_rates,
109
- "upsample_initial_channel": upsample_initial_channel,
110
- "upsample_kernel_sizes": upsample_kernel_sizes,
111
- "gin_channels": gin_channels,
112
- }
113
- self.dec = NSFHifiGANGenerator(h=hps)
114
- self.mb = False
115
- else:
116
- hps = {
117
- "initial_channel": inter_channels,
118
- "resblock": resblock,
119
- "resblock_kernel_sizes": resblock_kernel_sizes,
120
- "resblock_dilation_sizes": resblock_dilation_sizes,
121
- "upsample_rates": upsample_rates,
122
- "upsample_initial_channel": upsample_initial_channel,
123
- "upsample_kernel_sizes": upsample_kernel_sizes,
124
- "gin_channels": gin_channels,
125
- "gen_istft_n_fft": gen_istft_n_fft,
126
- "gen_istft_hop_size": gen_istft_hop_size,
127
- "subbands": subbands,
128
- }
129
-
130
- # gen_istft_n_fft, gen_istft_hop_size, subbands
131
- if type_ == "istft":
132
- del hps["subbands"]
133
- self.dec = iSTFT_Generator(**hps)
134
- elif type_ == "ms-istft":
135
- self.dec = Multistream_iSTFT_Generator(**hps)
136
- elif type_ == "mb-istft":
137
- self.dec = Multiband_iSTFT_Generator(**hps)
138
- else:
139
- raise ValueError(f"Unknown type: {type_}")
140
- self.mb = True
141
-
142
- self.enc_q = Encoder(
143
- spec_channels,
144
- inter_channels,
145
- hidden_channels,
146
- 5,
147
- 1,
148
- 16,
149
- gin_channels=gin_channels,
150
- )
151
- self.flow = ResidualCouplingBlock(
152
- inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
153
- )
154
- self.f0_decoder = F0Decoder(
155
- 1,
156
- hidden_channels,
157
- filter_channels,
158
- n_heads,
159
- n_layers,
160
- kernel_size,
161
- p_dropout,
162
- spk_channels=gin_channels,
163
- )
164
- self.emb_uv = nn.Embedding(2, hidden_channels)
165
-
166
- def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None):
167
- g = self.emb_g(g).transpose(1, 2)
168
- # ssl prenet
169
- x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(
170
- c.dtype
171
- )
172
- x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
173
-
174
- # f0 predict
175
- lf0 = 2595.0 * torch.log10(1.0 + f0.unsqueeze(1) / 700.0) / 500
176
- norm_lf0 = so_vits_svc_fork.f0.normalize_f0(lf0, x_mask, uv)
177
- pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
178
-
179
- # encoder
180
- z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0))
181
- z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
182
-
183
- # flow
184
- z_p = self.flow(z, spec_mask, g=g)
185
- z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch(
186
- z, f0, spec_lengths, self.segment_size
187
- )
188
-
189
- # MB-iSTFT-VITS
190
- if self.mb:
191
- o, o_mb = self.dec(z_slice, g=g)
192
- # HiFi-GAN
193
- else:
194
- o = self.dec(z_slice, g=g, f0=pitch_slice)
195
- o_mb = None
196
- return (
197
- o,
198
- o_mb,
199
- ids_slice,
200
- spec_mask,
201
- (z, z_p, m_p, logs_p, m_q, logs_q),
202
- pred_lf0,
203
- norm_lf0,
204
- lf0,
205
- )
206
-
207
- def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False):
208
- c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
209
- g = self.emb_g(g).transpose(1, 2)
210
- x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(
211
- c.dtype
212
- )
213
- x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
214
-
215
- if predict_f0:
216
- lf0 = 2595.0 * torch.log10(1.0 + f0.unsqueeze(1) / 700.0) / 500
217
- norm_lf0 = so_vits_svc_fork.f0.normalize_f0(
218
- lf0, x_mask, uv, random_scale=False
219
- )
220
- pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
221
- f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)
222
-
223
- z_p, m_p, logs_p, c_mask = self.enc_p(
224
- x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale
225
- )
226
- z = self.flow(z_p, c_mask, g=g, reverse=True)
227
-
228
- # MB-iSTFT-VITS
229
- if self.mb:
230
- o, o_mb = self.dec(z * c_mask, g=g)
231
- else:
232
- o = self.dec(z * c_mask, g=g, f0=f0)
233
- return o
 
so_vits_svc_fork/preprocessing/__init__.py DELETED
File without changes
so_vits_svc_fork/preprocessing/config_templates/quickvc.json DELETED
@@ -1,78 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 100,
4
- "eval_interval": 200,
5
- "seed": 1234,
6
- "epochs": 10000,
7
- "learning_rate": 0.0001,
8
- "betas": [0.8, 0.99],
9
- "eps": 1e-9,
10
- "batch_size": 16,
11
- "fp16_run": false,
12
- "bf16_run": false,
13
- "lr_decay": 0.999875,
14
- "segment_size": 10240,
15
- "init_lr_ratio": 1,
16
- "warmup_epochs": 0,
17
- "c_mel": 45,
18
- "c_kl": 1.0,
19
- "use_sr": true,
20
- "max_speclen": 512,
21
- "port": "8001",
22
- "keep_ckpts": 3,
23
- "fft_sizes": [768, 1366, 342],
24
- "hop_sizes": [60, 120, 20],
25
- "win_lengths": [300, 600, 120],
26
- "window": "hann_window",
27
- "num_workers": 4,
28
- "log_version": 0,
29
- "ckpt_name_by_step": false,
30
- "accumulate_grad_batches": 1
31
- },
32
- "data": {
33
- "training_files": "filelists/44k/train.txt",
34
- "validation_files": "filelists/44k/val.txt",
35
- "max_wav_value": 32768.0,
36
- "sampling_rate": 44100,
37
- "filter_length": 2048,
38
- "hop_length": 512,
39
- "win_length": 2048,
40
- "n_mel_channels": 80,
41
- "mel_fmin": 0.0,
42
- "mel_fmax": 22050,
43
- "contentvec_final_proj": false
44
- },
45
- "model": {
46
- "inter_channels": 192,
47
- "hidden_channels": 192,
48
- "filter_channels": 768,
49
- "n_heads": 2,
50
- "n_layers": 6,
51
- "kernel_size": 3,
52
- "p_dropout": 0.1,
53
- "resblock": "1",
54
- "resblock_kernel_sizes": [3, 7, 11],
55
- "resblock_dilation_sizes": [
56
- [1, 3, 5],
57
- [1, 3, 5],
58
- [1, 3, 5]
59
- ],
60
- "upsample_rates": [8, 4],
61
- "upsample_initial_channel": 512,
62
- "upsample_kernel_sizes": [32, 16],
63
- "n_layers_q": 3,
64
- "use_spectral_norm": false,
65
- "gin_channels": 256,
66
- "ssl_dim": 768,
67
- "n_speakers": 200,
68
- "type_": "ms-istft",
69
- "gen_istft_n_fft": 16,
70
- "gen_istft_hop_size": 4,
71
- "subbands": 4,
72
- "pretrained": {
73
- "D_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth",
74
- "G_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth"
75
- }
76
- },
77
- "spk": {}
78
- }
 
so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json DELETED
@@ -1,69 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 200,
4
- "eval_interval": 800,
5
- "seed": 1234,
6
- "epochs": 10000,
7
- "learning_rate": 0.0001,
8
- "betas": [0.8, 0.99],
9
- "eps": 1e-9,
10
- "batch_size": 16,
11
- "fp16_run": false,
12
- "bf16_run": false,
13
- "lr_decay": 0.999875,
14
- "segment_size": 10240,
15
- "init_lr_ratio": 1,
16
- "warmup_epochs": 0,
17
- "c_mel": 45,
18
- "c_kl": 1.0,
19
- "use_sr": true,
20
- "max_speclen": 512,
21
- "port": "8001",
22
- "keep_ckpts": 3,
23
- "num_workers": 4,
24
- "log_version": 0,
25
- "ckpt_name_by_step": false,
26
- "accumulate_grad_batches": 1
27
- },
28
- "data": {
29
- "training_files": "filelists/44k/train.txt",
30
- "validation_files": "filelists/44k/val.txt",
31
- "max_wav_value": 32768.0,
32
- "sampling_rate": 44100,
33
- "filter_length": 2048,
34
- "hop_length": 512,
35
- "win_length": 2048,
36
- "n_mel_channels": 80,
37
- "mel_fmin": 0.0,
38
- "mel_fmax": 22050
39
- },
40
- "model": {
41
- "inter_channels": 192,
42
- "hidden_channels": 192,
43
- "filter_channels": 768,
44
- "n_heads": 2,
45
- "n_layers": 6,
46
- "kernel_size": 3,
47
- "p_dropout": 0.1,
48
- "resblock": "1",
49
- "resblock_kernel_sizes": [3, 7, 11],
50
- "resblock_dilation_sizes": [
51
- [1, 3, 5],
52
- [1, 3, 5],
53
- [1, 3, 5]
54
- ],
55
- "upsample_rates": [8, 8, 2, 2, 2],
56
- "upsample_initial_channel": 512,
57
- "upsample_kernel_sizes": [16, 16, 4, 4, 4],
58
- "n_layers_q": 3,
59
- "use_spectral_norm": false,
60
- "gin_channels": 256,
61
- "ssl_dim": 256,
62
- "n_speakers": 200,
63
- "pretrained": {
64
- "D_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
65
- "G_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth"
66
- }
67
- },
68
- "spk": {}
69
- }
 
so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json DELETED
@@ -1,71 +0,0 @@
1
- {
2
- "train": {
3
- "log_interval": 100,
4
- "eval_interval": 200,
5
- "seed": 1234,
6
- "epochs": 10000,
7
- "learning_rate": 0.0001,
8
- "betas": [0.8, 0.99],
9
- "eps": 1e-9,
10
- "batch_size": 16,
11
- "fp16_run": false,
12
- "bf16_run": false,
13
- "lr_decay": 0.999875,
14
- "segment_size": 10240,
15
- "init_lr_ratio": 1,
16
- "warmup_epochs": 0,
17
- "c_mel": 45,
18
- "c_kl": 1.0,
19
- "use_sr": true,
20
- "max_speclen": 512,
21
- "port": "8001",
22
- "keep_ckpts": 3,
23
- "num_workers": 4,
24
- "log_version": 0,
25
- "ckpt_name_by_step": false,
26
- "accumulate_grad_batches": 1
27
- },
28
- "data": {
29
- "training_files": "filelists/44k/train.txt",
30
- "validation_files": "filelists/44k/val.txt",
31
- "max_wav_value": 32768.0,
32
- "sampling_rate": 44100,
33
- "filter_length": 2048,
34
- "hop_length": 512,
35
- "win_length": 2048,
36
- "n_mel_channels": 80,
37
- "mel_fmin": 0.0,
38
- "mel_fmax": 22050,
39
- "contentvec_final_proj": false
40
- },
41
- "model": {
42
- "inter_channels": 192,
43
- "hidden_channels": 192,
44
- "filter_channels": 768,
45
- "n_heads": 2,
46
- "n_layers": 6,
47
- "kernel_size": 3,
48
- "p_dropout": 0.1,
49
- "resblock": "1",
50
- "resblock_kernel_sizes": [3, 7, 11],
51
- "resblock_dilation_sizes": [
52
- [1, 3, 5],
53
- [1, 3, 5],
54
- [1, 3, 5]
55
- ],
56
- "upsample_rates": [8, 8, 2, 2, 2],
57
- "upsample_initial_channel": 512,
58
- "upsample_kernel_sizes": [16, 16, 4, 4, 4],
59
- "n_layers_q": 3,
60
- "use_spectral_norm": false,
61
- "gin_channels": 256,
62
- "ssl_dim": 768,
63
- "n_speakers": 200,
64
- "type_": "hifi-gan",
65
- "pretrained": {
66
- "D_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth",
67
- "G_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth"
68
- }
69
- },
70
- "spk": {}
71
- }
 
so_vits_svc_fork/preprocessing/preprocess_classify.py DELETED
@@ -1,95 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from logging import getLogger
4
- from pathlib import Path
5
-
6
- import keyboard
7
- import librosa
8
- import sounddevice as sd
9
- import soundfile as sf
10
- from rich.console import Console
11
- from tqdm.rich import tqdm
12
-
13
- LOG = getLogger(__name__)
14
-
15
-
16
- def preprocess_classify(
17
- input_dir: Path | str, output_dir: Path | str, create_new: bool = True
18
- ) -> None:
19
- # paths
20
- input_dir_ = Path(input_dir)
21
- output_dir_ = Path(output_dir)
22
- speed = 1
23
- if not input_dir_.is_dir():
24
- raise ValueError(f"{input_dir} is not a directory.")
25
- output_dir_.mkdir(exist_ok=True)
26
-
27
- console = Console()
28
- # get audio paths and folders
29
- audio_paths = list(input_dir_.glob("*.*"))
30
- last_folders = [x for x in output_dir_.glob("*") if x.is_dir()]
31
- console.print("Press ↑ or ↓ to change speed. Press any other key to classify.")
32
- console.print(f"Folders: {[x.name for x in last_folders]}")
33
-
34
- pbar_description = ""
35
-
36
- pbar = tqdm(audio_paths)
37
- for audio_path in pbar:
38
- # read file
39
- audio, sr = sf.read(audio_path)
40
-
41
- # update description
42
- duration = librosa.get_duration(y=audio, sr=sr)
43
- pbar_description = f"{duration:.1f} {pbar_description}"
44
- pbar.set_description(pbar_description)
45
-
46
- while True:
47
- # start playing
48
- sd.play(librosa.effects.time_stretch(audio, rate=speed), sr, loop=True)
49
-
50
- # wait for key press
51
- key = str(keyboard.read_key())
52
- if key == "down":
53
- speed /= 1.1
54
- console.print(f"Speed: {speed:.2f}")
55
- elif key == "up":
56
- speed *= 1.1
57
- console.print(f"Speed: {speed:.2f}")
58
- else:
59
- break
60
-
61
- # stop playing
62
- sd.stop()
63
-
64
- # print if folder changed
65
- folders = [x for x in output_dir_.glob("*") if x.is_dir()]
66
- if folders != last_folders:
67
- console.print(f"Folders updated: {[x.name for x in folders]}")
68
- last_folders = folders
69
-
70
- # get folder
71
- folder_candidates = [x for x in folders if x.name.startswith(key)]
72
- if len(folder_candidates) == 0:
73
- if create_new:
74
- folder = output_dir_ / key
75
- else:
76
- console.print(f"No folder starts with {key}.")
77
- continue
78
- else:
79
- if len(folder_candidates) > 1:
80
- LOG.warning(
81
- f"Multiple folders ({[x.name for x in folder_candidates]}) start with {key}. "
82
- f"Using first one ({folder_candidates[0].name})."
83
- )
84
- folder = folder_candidates[0]
85
- folder.mkdir(exist_ok=True)
86
-
87
- # move file
88
- new_path = folder / audio_path.name
89
- audio_path.rename(new_path)
90
-
91
- # update description
92
- pbar_description = f"Last: {audio_path.name} -> {folder.name}"
93
-
94
- # yield result
95
- # yield audio_path, key, folder, new_path
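
This removed helper plays each clip and moves it into a folder chosen by the key you press. A minimal usage sketch, assuming these modules remain importable from the so_vits_svc_fork package (e.g. installed as a dependency); the directory names are placeholders:

from so_vits_svc_fork.preprocessing.preprocess_classify import preprocess_classify

preprocess_classify(
    input_dir="dataset_raw_raw",  # unsorted clips (placeholder name)
    output_dir="dataset_raw",     # one subfolder per key pressed
    create_new=True,              # create a folder when no existing one matches the key
)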
so_vits_svc_fork/preprocessing/preprocess_flist_config.py DELETED
@@ -1,86 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import os
5
- from copy import deepcopy
6
- from logging import getLogger
7
- from pathlib import Path
8
-
9
- import numpy as np
10
- from librosa import get_duration
11
- from tqdm import tqdm
12
-
13
- LOG = getLogger(__name__)
14
- CONFIG_TEMPLATE_DIR = Path(__file__).parent / "config_templates"
15
-
16
-
17
- def preprocess_config(
18
- input_dir: Path | str,
19
- train_list_path: Path | str,
20
- val_list_path: Path | str,
21
- test_list_path: Path | str,
22
- config_path: Path | str,
23
- config_name: str,
24
- ):
25
- input_dir = Path(input_dir)
26
- train_list_path = Path(train_list_path)
27
- val_list_path = Path(val_list_path)
28
- test_list_path = Path(test_list_path)
29
- config_path = Path(config_path)
30
- train = []
31
- val = []
32
- test = []
33
- spk_dict = {}
34
- spk_id = 0
35
- random = np.random.RandomState(1234)
36
- for speaker in os.listdir(input_dir):
37
- spk_dict[speaker] = spk_id
38
- spk_id += 1
39
- paths = []
40
- for path in tqdm(list((input_dir / speaker).rglob("*.wav"))):
41
- if get_duration(filename=path) < 0.3:
42
- LOG.warning(f"skip {path} because it is too short.")
43
- continue
44
- paths.append(path)
45
- random.shuffle(paths)
46
- if len(paths) <= 4:
47
- raise ValueError(
48
- f"too few files in {input_dir / speaker} (expected at least 5)."
49
- )
50
- train += paths[2:-2]
51
- val += paths[:2]
52
- test += paths[-2:]
53
-
54
- LOG.info(f"Writing {train_list_path}")
55
- train_list_path.parent.mkdir(parents=True, exist_ok=True)
56
- train_list_path.write_text(
57
- "\n".join([x.as_posix() for x in train]), encoding="utf-8"
58
- )
59
-
60
- LOG.info(f"Writing {val_list_path}")
61
- val_list_path.parent.mkdir(parents=True, exist_ok=True)
62
- val_list_path.write_text("\n".join([x.as_posix() for x in val]), encoding="utf-8")
63
-
64
- LOG.info(f"Writing {test_list_path}")
65
- test_list_path.parent.mkdir(parents=True, exist_ok=True)
66
- test_list_path.write_text("\n".join([x.as_posix() for x in test]), encoding="utf-8")
67
-
68
- config = deepcopy(
69
- json.loads(
70
- (
71
- CONFIG_TEMPLATE_DIR
72
- / (
73
- config_name
74
- if config_name.endswith(".json")
75
- else config_name + ".json"
76
- )
77
- ).read_text(encoding="utf-8")
78
- )
79
- )
80
- config["spk"] = spk_dict
81
- config["data"]["training_files"] = train_list_path.as_posix()
82
- config["data"]["validation_files"] = val_list_path.as_posix()
83
- LOG.info(f"Writing {config_path}")
84
- config_path.parent.mkdir(parents=True, exist_ok=True)
85
- with config_path.open("w", encoding="utf-8") as f:
86
- json.dump(config, f, indent=2)
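
This removed module builds the train/val/test file lists and writes a config from one of the JSON templates above. A usage sketch; the paths follow the conventional 44k layout from the templates but are not mandated by the function:

from so_vits_svc_fork.preprocessing.preprocess_flist_config import preprocess_config

preprocess_config(
    input_dir="dataset/44k",
    train_list_path="filelists/44k/train.txt",
    val_list_path="filelists/44k/val.txt",
    test_list_path="filelists/44k/test.txt",
    config_path="configs/44k/config.json",
    config_name="so-vits-svc-4.0v1",  # resolved against config_templates/, ".json" optional
)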
so_vits_svc_fork/preprocessing/preprocess_hubert_f0.py DELETED
@@ -1,157 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from logging import getLogger
4
- from pathlib import Path
5
- from random import shuffle
6
- from typing import Iterable, Literal
7
-
8
- import librosa
9
- import numpy as np
10
- import torch
11
- import torchaudio
12
- from joblib import Parallel, cpu_count, delayed
13
- from tqdm import tqdm
14
- from transformers import HubertModel
15
-
16
- import so_vits_svc_fork.f0
17
- from so_vits_svc_fork import utils
18
-
19
- from ..hparams import HParams
20
- from ..modules.mel_processing import spec_to_mel_torch, spectrogram_torch
21
- from ..utils import get_optimal_device, get_total_gpu_memory
22
- from .preprocess_utils import check_hubert_min_duration
23
-
24
- LOG = getLogger(__name__)
25
- HUBERT_MEMORY = 2900
26
- HUBERT_MEMORY_CREPE = 3900
27
-
28
-
29
- def _process_one(
30
- *,
31
- filepath: Path,
32
- content_model: HubertModel,
33
- device: torch.device | str = get_optimal_device(),
34
- f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
35
- force_rebuild: bool = False,
36
- hps: HParams,
37
- ):
38
- audio, sr = librosa.load(filepath, sr=hps.data.sampling_rate, mono=True)
39
-
40
- if not check_hubert_min_duration(audio, sr):
41
- LOG.info(f"Skip {filepath} because it is too short.")
42
- return
43
-
44
- data_path = filepath.parent / (filepath.name + ".data.pt")
45
- if data_path.exists() and not force_rebuild:
46
- return
47
-
48
- # Compute f0
49
- f0 = so_vits_svc_fork.f0.compute_f0(
50
- audio, sampling_rate=sr, hop_length=hps.data.hop_length, method=f0_method
51
- )
52
- f0, uv = so_vits_svc_fork.f0.interpolate_f0(f0)
53
- f0 = torch.from_numpy(f0).float()
54
- uv = torch.from_numpy(uv).float()
55
-
56
- # Compute HuBERT content
57
- audio = torch.from_numpy(audio).float().to(device)
58
- c = utils.get_content(
59
- content_model,
60
- audio,
61
- device,
62
- sr=sr,
63
- legacy_final_proj=hps.data.get("contentvec_final_proj", True),
64
- )
65
- c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
66
- torch.cuda.empty_cache()
67
-
68
- # Compute spectrogram
69
- audio, sr = torchaudio.load(filepath)
70
- spec = spectrogram_torch(audio, hps).squeeze(0)
71
- mel_spec = spec_to_mel_torch(spec, hps)
72
- torch.cuda.empty_cache()
73
-
74
- # fix lengths
75
- lmin = min(spec.shape[1], mel_spec.shape[1], f0.shape[0], uv.shape[0], c.shape[1])
76
- spec, mel_spec, f0, uv, c = (
77
- spec[:, :lmin],
78
- mel_spec[:, :lmin],
79
- f0[:lmin],
80
- uv[:lmin],
81
- c[:, :lmin],
82
- )
83
-
84
- # get speaker id
85
- spk_name = filepath.parent.name
86
- spk = hps.spk.__dict__[spk_name]
87
- spk = torch.tensor(spk).long()
88
- assert (
89
- spec.shape[1] == mel_spec.shape[1] == f0.shape[0] == uv.shape[0] == c.shape[1]
90
- ), (spec.shape, mel_spec.shape, f0.shape, uv.shape, c.shape)
91
- data = {
92
- "spec": spec,
93
- "mel_spec": mel_spec,
94
- "f0": f0,
95
- "uv": uv,
96
- "content": c,
97
- "audio": audio,
98
- "spk": spk,
99
- }
100
- data = {k: v.cpu() for k, v in data.items()}
101
- with data_path.open("wb") as f:
102
- torch.save(data, f)
103
-
104
-
105
- def _process_batch(filepaths: Iterable[Path], pbar_position: int, **kwargs):
106
- hps = kwargs["hps"]
107
- content_model = utils.get_hubert_model(
108
- get_optimal_device(), hps.data.get("contentvec_final_proj", True)
109
- )
110
-
111
- for filepath in tqdm(filepaths, position=pbar_position):
112
- _process_one(
113
- content_model=content_model,
114
- filepath=filepath,
115
- **kwargs,
116
- )
117
-
118
-
119
- def preprocess_hubert_f0(
120
- input_dir: Path | str,
121
- config_path: Path | str,
122
- n_jobs: int | None = None,
123
- f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
124
- force_rebuild: bool = False,
125
- ):
126
- input_dir = Path(input_dir)
127
- config_path = Path(config_path)
128
- hps = utils.get_hparams(config_path)
129
- if n_jobs is None:
130
- # add cpu_count() to avoid SIGKILL
131
- memory = get_total_gpu_memory("total")
132
- n_jobs = min(
133
- max(
134
- memory
135
- // (HUBERT_MEMORY_CREPE if f0_method == "crepe" else HUBERT_MEMORY)
136
- if memory is not None
137
- else 1,
138
- 1,
139
- ),
140
- cpu_count(),
141
- )
142
- LOG.info(f"n_jobs automatically set to {n_jobs}, memory: {memory} MiB")
143
-
144
- filepaths = list(input_dir.rglob("*.wav"))
145
- n_jobs = min(len(filepaths) // 16 + 1, n_jobs)
146
- shuffle(filepaths)
147
- filepath_chunks = np.array_split(filepaths, n_jobs)
148
- Parallel(n_jobs=n_jobs)(
149
- delayed(_process_batch)(
150
- filepaths=chunk,
151
- pbar_position=pbar_position,
152
- f0_method=f0_method,
153
- force_rebuild=force_rebuild,
154
- hps=hps,
155
- )
156
- for (pbar_position, chunk) in enumerate(filepath_chunks)
157
- )
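
This removed step caches f0, HuBERT content, spectrograms and the speaker id into a .data.pt file next to each wav. A usage sketch under the same assumptions as above (placeholder paths); n_jobs=None lets the code size the worker count from GPU memory:

from so_vits_svc_fork.preprocessing.preprocess_hubert_f0 import preprocess_hubert_f0

preprocess_hubert_f0(
    input_dir="dataset/44k",
    config_path="configs/44k/config.json",
    n_jobs=None,          # auto: bounded by GPU memory budget and cpu_count()
    f0_method="dio",      # or "crepe", "crepe-tiny", "parselmouth", "harvest"
    force_rebuild=False,  # skip files that already have a .data.pt
)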
so_vits_svc_fork/preprocessing/preprocess_resample.py DELETED
@@ -1,144 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import warnings
4
- from logging import getLogger
5
- from pathlib import Path
6
- from typing import Iterable
7
-
8
- import librosa
9
- import soundfile
10
- from joblib import Parallel, delayed
11
- from tqdm_joblib import tqdm_joblib
12
-
13
- from .preprocess_utils import check_hubert_min_duration
14
-
15
- LOG = getLogger(__name__)
16
-
17
- # input_dir and output_dir exists.
18
- # write code to convert input dir audio files to output dir audio files,
19
- # without changing folder structure. Use joblib to parallelize.
20
- # Converting audio files includes:
21
- # - resampling to specified sampling rate
22
- # - trim silence
23
- # - adjust volume in a smart way
24
- # - save as 16-bit wav file
25
-
26
-
27
- def _get_unique_filename(path: Path, existing_paths: Iterable[Path]) -> Path:
28
- """Return a unique path by appending a number to the original path."""
29
- if path not in existing_paths:
30
- return path
31
- i = 1
32
- while True:
33
- new_path = path.parent / f"{path.stem}_{i}{path.suffix}"
34
- if new_path not in existing_paths:
35
- return new_path
36
- i += 1
37
-
38
-
39
- def is_relative_to(path: Path, *other):
40
- """Return True if the path is relative to another path or False.
41
- Python 3.9+ has Path.is_relative_to() method, but we need to support Python 3.8.
42
- """
43
- try:
44
- path.relative_to(*other)
45
- return True
46
- except ValueError:
47
- return False
48
-
49
-
50
- def _preprocess_one(
51
- input_path: Path,
52
- output_path: Path,
53
- sr: int,
54
- *,
55
- top_db: int,
56
- frame_seconds: float,
57
- hop_seconds: float,
58
- ) -> None:
59
- """Preprocess one audio file."""
60
-
61
- try:
62
- audio, sr = librosa.load(input_path, sr=sr, mono=True)
63
-
64
- # Audioread is the last backend it will attempt, so this is the exception thrown on failure
65
- except Exception as e:
66
- # Failure due to attempting to load a file that is not audio, so return early
67
- LOG.warning(f"Failed to load {input_path} due to {e}")
68
- return
69
-
70
- if not check_hubert_min_duration(audio, sr):
71
- LOG.info(f"Skip {input_path} because it is too short.")
72
- return
73
-
74
- # Adjust volume
75
- audio /= max(audio.max(), -audio.min())
76
-
77
- # Trim silence
78
- audio, _ = librosa.effects.trim(
79
- audio,
80
- top_db=top_db,
81
- frame_length=int(frame_seconds * sr),
82
- hop_length=int(hop_seconds * sr),
83
- )
84
-
85
- if not check_hubert_min_duration(audio, sr):
86
- LOG.info(f"Skip {input_path} because it is too short.")
87
- return
88
-
89
- soundfile.write(output_path, audio, samplerate=sr, subtype="PCM_16")
90
-
91
-
92
- def preprocess_resample(
93
- input_dir: Path | str,
94
- output_dir: Path | str,
95
- sampling_rate: int,
96
- n_jobs: int = -1,
97
- *,
98
- top_db: int = 30,
99
- frame_seconds: float = 0.1,
100
- hop_seconds: float = 0.05,
101
- ) -> None:
102
- input_dir = Path(input_dir)
103
- output_dir = Path(output_dir)
104
- """Preprocess audio files in input_dir and save them to output_dir."""
105
-
106
- out_paths = []
107
- in_paths = list(input_dir.rglob("*.*"))
108
- if not in_paths:
109
- raise ValueError(f"No audio files found in {input_dir}")
110
- for in_path in in_paths:
111
- in_path_relative = in_path.relative_to(input_dir)
112
- if not in_path.is_absolute() and is_relative_to(
113
- in_path, Path("dataset_raw") / "44k"
114
- ):
115
- new_in_path_relative = in_path_relative.relative_to("44k")
116
- warnings.warn(
117
- f"Recommended folder structure has changed since v1.0.0. "
118
- "Please move your dataset directly under dataset_raw folder. "
119
- f"Recognized {in_path_relative} as {new_in_path_relative}"
120
- )
121
- in_path_relative = new_in_path_relative
122
-
123
- if len(in_path_relative.parts) < 2:
124
- continue
125
- speaker_name = in_path_relative.parts[0]
126
- file_name = in_path_relative.with_suffix(".wav").name
127
- out_path = output_dir / speaker_name / file_name
128
- out_path = _get_unique_filename(out_path, out_paths)
129
- out_path.parent.mkdir(parents=True, exist_ok=True)
130
- out_paths.append(out_path)
131
-
132
- in_and_out_paths = list(zip(in_paths, out_paths))
133
-
134
- with tqdm_joblib(desc="Preprocessing", total=len(in_and_out_paths)):
135
- Parallel(n_jobs=n_jobs)(
136
- delayed(_preprocess_one)(
137
- *args,
138
- sr=sampling_rate,
139
- top_db=top_db,
140
- frame_seconds=frame_seconds,
141
- hop_seconds=hop_seconds,
142
- )
143
- for args in in_and_out_paths
144
- )
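
This removed step resamples, trims and normalizes the raw dataset into 16-bit wavs under per-speaker subfolders. Usage sketch (placeholder directories; keyword arguments mirror the defaults above):

from so_vits_svc_fork.preprocessing.preprocess_resample import preprocess_resample

preprocess_resample(
    input_dir="dataset_raw",   # expects dataset_raw/<speaker>/<file>
    output_dir="dataset/44k",
    sampling_rate=44100,
    top_db=30,
    frame_seconds=0.1,
    hop_seconds=0.05,
)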
so_vits_svc_fork/preprocessing/preprocess_speaker_diarization.py DELETED
@@ -1,93 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from collections import defaultdict
4
- from logging import getLogger
5
- from pathlib import Path
6
-
7
- import librosa
8
- import soundfile as sf
9
- import torch
10
- from joblib import Parallel, delayed
11
- from pyannote.audio import Pipeline
12
- from tqdm import tqdm
13
- from tqdm_joblib import tqdm_joblib
14
-
15
- LOG = getLogger(__name__)
16
-
17
-
18
- def _process_one(
19
- input_path: Path,
20
- output_dir: Path,
21
- sr: int,
22
- *,
23
- min_speakers: int = 1,
24
- max_speakers: int = 1,
25
- huggingface_token: str | None = None,
26
- ) -> None:
27
- try:
28
- audio, sr = librosa.load(input_path, sr=sr, mono=True)
29
- except Exception as e:
30
- LOG.warning(f"Failed to read {input_path}: {e}")
31
- return
32
- pipeline = Pipeline.from_pretrained(
33
- "pyannote/speaker-diarization", use_auth_token=huggingface_token
34
- )
35
- if pipeline is None:
36
- raise ValueError("Failed to load pipeline")
37
-
38
- LOG.info(f"Processing {input_path}. This may take a while...")
39
- diarization = pipeline(
40
- input_path, min_speakers=min_speakers, max_speakers=max_speakers
41
- )
42
-
43
- LOG.info(f"Found {len(diarization)} tracks, writing to {output_dir}")
44
- speaker_count = defaultdict(int)
45
-
46
- output_dir.mkdir(parents=True, exist_ok=True)
47
- for segment, track, speaker in tqdm(
48
- list(diarization.itertracks(yield_label=True)), desc=f"Writing {input_path}"
49
- ):
50
- if segment.end - segment.start < 1:
51
- continue
52
- speaker_count[speaker] += 1
53
- audio_cut = audio[int(segment.start * sr) : int(segment.end * sr)]
54
- sf.write(
55
- (output_dir / f"{speaker}_{speaker_count[speaker]}.wav"),
56
- audio_cut,
57
- sr,
58
- )
59
-
60
- LOG.info(f"Speaker count: {speaker_count}")
61
-
62
-
63
- def preprocess_speaker_diarization(
64
- input_dir: Path | str,
65
- output_dir: Path | str,
66
- sr: int,
67
- *,
68
- min_speakers: int = 1,
69
- max_speakers: int = 1,
70
- huggingface_token: str | None = None,
71
- n_jobs: int = -1,
72
- ) -> None:
73
- if huggingface_token is not None and not huggingface_token.startswith("hf_"):
74
- LOG.warning("Huggingface token probably should start with hf_")
75
- if not torch.cuda.is_available():
76
- LOG.warning("CUDA is not available. This will be extremely slow.")
77
- input_dir = Path(input_dir)
78
- output_dir = Path(output_dir)
79
- input_dir.mkdir(parents=True, exist_ok=True)
80
- output_dir.mkdir(parents=True, exist_ok=True)
81
- input_paths = list(input_dir.rglob("*.*"))
82
- with tqdm_joblib(desc="Preprocessing speaker diarization", total=len(input_paths)):
83
- Parallel(n_jobs=n_jobs)(
84
- delayed(_process_one)(
85
- input_path,
86
- output_dir / input_path.relative_to(input_dir).parent / input_path.stem,
87
- sr,
88
- max_speakers=max_speakers,
89
- min_speakers=min_speakers,
90
- huggingface_token=huggingface_token,
91
- )
92
- for input_path in input_paths
93
- )
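
This removed helper splits long multi-speaker recordings into per-speaker clips with pyannote. Usage sketch; the token value is a placeholder (pyannote's gated models need a real Hugging Face token, which the code above expects to start with hf_):

from so_vits_svc_fork.preprocessing.preprocess_speaker_diarization import (
    preprocess_speaker_diarization,
)

preprocess_speaker_diarization(
    input_dir="dataset_raw_raw",
    output_dir="dataset_raw",
    sr=44100,
    min_speakers=1,
    max_speakers=3,
    huggingface_token="hf_xxx",  # placeholder token
)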
so_vits_svc_fork/preprocessing/preprocess_split.py DELETED
@@ -1,78 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from logging import getLogger
4
- from pathlib import Path
5
-
6
- import librosa
7
- import soundfile as sf
8
- from joblib import Parallel, delayed
9
- from tqdm import tqdm
10
- from tqdm_joblib import tqdm_joblib
11
-
12
- LOG = getLogger(__name__)
13
-
14
-
15
- def _process_one(
16
- input_path: Path,
17
- output_dir: Path,
18
- sr: int,
19
- *,
20
- max_length: float = 10.0,
21
- top_db: int = 30,
22
- frame_seconds: float = 0.5,
23
- hop_seconds: float = 0.1,
24
- ):
25
- try:
26
- audio, sr = librosa.load(input_path, sr=sr, mono=True)
27
- except Exception as e:
28
- LOG.warning(f"Failed to read {input_path}: {e}")
29
- return
30
- intervals = librosa.effects.split(
31
- audio,
32
- top_db=top_db,
33
- frame_length=int(sr * frame_seconds),
34
- hop_length=int(sr * hop_seconds),
35
- )
36
- output_dir.mkdir(parents=True, exist_ok=True)
37
- for start, end in tqdm(intervals, desc=f"Writing {input_path}"):
38
- for sub_start in range(start, end, int(sr * max_length)):
39
- sub_end = min(sub_start + int(sr * max_length), end)
40
- audio_cut = audio[sub_start:sub_end]
41
- sf.write(
42
- (
43
- output_dir
44
- / f"{input_path.stem}_{sub_start / sr:.3f}_{sub_end / sr:.3f}.wav"
45
- ),
46
- audio_cut,
47
- sr,
48
- )
49
-
50
-
51
- def preprocess_split(
52
- input_dir: Path | str,
53
- output_dir: Path | str,
54
- sr: int,
55
- *,
56
- max_length: float = 10.0,
57
- top_db: int = 30,
58
- frame_seconds: float = 0.5,
59
- hop_seconds: float = 0.1,
60
- n_jobs: int = -1,
61
- ):
62
- input_dir = Path(input_dir)
63
- output_dir = Path(output_dir)
64
- output_dir.mkdir(parents=True, exist_ok=True)
65
- input_paths = list(input_dir.rglob("*.*"))
66
- with tqdm_joblib(desc="Splitting", total=len(input_paths)):
67
- Parallel(n_jobs=n_jobs)(
68
- delayed(_process_one)(
69
- input_path,
70
- output_dir / input_path.relative_to(input_dir).parent,
71
- sr,
72
- max_length=max_length,
73
- top_db=top_db,
74
- frame_seconds=frame_seconds,
75
- hop_seconds=hop_seconds,
76
- )
77
- for input_path in input_paths
78
- )
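
This removed helper cuts recordings on silence and caps each written segment at max_length seconds. Usage sketch (placeholder directories, defaults as above):

from so_vits_svc_fork.preprocessing.preprocess_split import preprocess_split

preprocess_split(
    input_dir="dataset_raw_raw",
    output_dir="dataset_raw",
    sr=44100,
    max_length=10.0,  # seconds per written chunk
    top_db=30,
)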
so_vits_svc_fork/preprocessing/preprocess_utils.py DELETED
@@ -1,5 +0,0 @@
1
- from numpy import ndarray
2
-
3
-
4
- def check_hubert_min_duration(audio: ndarray, sr: int) -> bool:
5
- return len(audio) / sr >= 0.3
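
The single helper here enforces the 0.3 s minimum duration used throughout the preprocessing steps; a quick check of its behaviour:

import numpy as np
from so_vits_svc_fork.preprocessing.preprocess_utils import check_hubert_min_duration

sr = 44100
assert check_hubert_min_duration(np.zeros(sr), sr)            # 1.0 s -> True
assert not check_hubert_min_duration(np.zeros(sr // 10), sr)  # 0.1 s -> False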
so_vits_svc_fork/py.typed DELETED
File without changes
so_vits_svc_fork/train.py DELETED
@@ -1,571 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- import warnings
5
- from logging import getLogger
6
- from multiprocessing import cpu_count
7
- from pathlib import Path
8
- from typing import Any
9
-
10
- import lightning.pytorch as pl
11
- import torch
12
- from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
13
- from lightning.pytorch.callbacks import DeviceStatsMonitor
14
- from lightning.pytorch.loggers import TensorBoardLogger
15
- from lightning.pytorch.strategies.ddp import DDPStrategy
16
- from lightning.pytorch.tuner import Tuner
17
- from torch.cuda.amp import autocast
18
- from torch.nn import functional as F
19
- from torch.utils.data import DataLoader
20
- from torch.utils.tensorboard.writer import SummaryWriter
21
-
22
- import so_vits_svc_fork.f0
23
- import so_vits_svc_fork.modules.commons as commons
24
- import so_vits_svc_fork.utils
25
-
26
- from . import utils
27
- from .dataset import TextAudioCollate, TextAudioDataset
28
- from .logger import is_notebook
29
- from .modules.descriminators import MultiPeriodDiscriminator
30
- from .modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
31
- from .modules.mel_processing import mel_spectrogram_torch
32
- from .modules.synthesizers import SynthesizerTrn
33
-
34
- LOG = getLogger(__name__)
35
- torch.set_float32_matmul_precision("high")
36
-
37
-
38
- class VCDataModule(pl.LightningDataModule):
39
- batch_size: int
40
-
41
- def __init__(self, hparams: Any):
42
- super().__init__()
43
- self.__hparams = hparams
44
- self.batch_size = hparams.train.batch_size
45
- if not isinstance(self.batch_size, int):
46
- self.batch_size = 1
47
- self.collate_fn = TextAudioCollate()
48
-
49
- # these should be called in setup(), but we need to calculate check_val_every_n_epoch
50
- self.train_dataset = TextAudioDataset(self.__hparams, is_validation=False)
51
- self.val_dataset = TextAudioDataset(self.__hparams, is_validation=True)
52
-
53
- def train_dataloader(self):
54
- return DataLoader(
55
- self.train_dataset,
56
- num_workers=min(cpu_count(), self.__hparams.train.get("num_workers", 8)),
57
- batch_size=self.batch_size,
58
- collate_fn=self.collate_fn,
59
- persistent_workers=True,
60
- )
61
-
62
- def val_dataloader(self):
63
- return DataLoader(
64
- self.val_dataset,
65
- batch_size=1,
66
- collate_fn=self.collate_fn,
67
- )
68
-
69
-
70
- def train(
71
- config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
72
- ):
73
- config_path = Path(config_path)
74
- model_path = Path(model_path)
75
-
76
- hparams = utils.get_backup_hparams(config_path, model_path)
77
- utils.ensure_pretrained_model(
78
- model_path,
79
- hparams.model.get(
80
- "pretrained",
81
- {
82
- "D_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
83
- "G_0.pth": "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth",
84
- },
85
- ),
86
- )
87
-
88
- datamodule = VCDataModule(hparams)
89
- strategy = (
90
- (
91
- "ddp_find_unused_parameters_true"
92
- if os.name != "nt"
93
- else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
94
- )
95
- if torch.cuda.device_count() > 1
96
- else "auto"
97
- )
98
- LOG.info(f"Using strategy: {strategy}")
99
- trainer = pl.Trainer(
100
- logger=TensorBoardLogger(
101
- model_path, "lightning_logs", hparams.train.get("log_version", 0)
102
- ),
103
- # profiler="simple",
104
- val_check_interval=hparams.train.eval_interval,
105
- max_epochs=hparams.train.epochs,
106
- check_val_every_n_epoch=None,
107
- precision="16-mixed"
108
- if hparams.train.fp16_run
109
- else "bf16-mixed"
110
- if hparams.train.get("bf16_run", False)
111
- else 32,
112
- strategy=strategy,
113
- callbacks=([pl.callbacks.RichProgressBar()] if not is_notebook() else [])
114
- + [DeviceStatsMonitor()],
115
- benchmark=True,
116
- enable_checkpointing=False,
117
- )
118
- tuner = Tuner(trainer)
119
- model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)
120
-
121
- # automatic batch size scaling
122
- batch_size = hparams.train.batch_size
123
- batch_split = str(batch_size).split("-")
124
- batch_size = batch_split[0]
125
- init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
126
- max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
127
- if batch_size == "auto":
128
- batch_size = "binsearch"
129
- if batch_size in ["power", "binsearch"]:
130
- model.tuning = True
131
- tuner.scale_batch_size(
132
- model,
133
- mode=batch_size,
134
- datamodule=datamodule,
135
- steps_per_trial=1,
136
- init_val=init_val,
137
- max_trials=max_trials,
138
- )
139
- model.tuning = False
140
- else:
141
- batch_size = int(batch_size)
142
- # automatic learning rate scaling is not supported for multiple optimizers
143
- """if hparams.train.learning_rate == "auto":
144
- lr_finder = tuner.lr_find(model)
145
- LOG.info(lr_finder.results)
146
- fig = lr_finder.plot(suggest=True)
147
- fig.savefig(model_path / "lr_finder.png")"""
148
-
149
- trainer.fit(model, datamodule=datamodule)
150
-
151
-
152
- class VitsLightning(pl.LightningModule):
153
- def __init__(self, reset_optimizer: bool = False, **hparams: Any):
154
- super().__init__()
155
- self._temp_epoch = 0 # Add this line to initialize the _temp_epoch attribute
156
- self.save_hyperparameters("reset_optimizer")
157
- self.save_hyperparameters(*[k for k in hparams.keys()])
158
- torch.manual_seed(self.hparams.train.seed)
159
- self.net_g = SynthesizerTrn(
160
- self.hparams.data.filter_length // 2 + 1,
161
- self.hparams.train.segment_size // self.hparams.data.hop_length,
162
- **self.hparams.model,
163
- )
164
- self.net_d = MultiPeriodDiscriminator(self.hparams.model.use_spectral_norm)
165
- self.automatic_optimization = False
166
- self.learning_rate = self.hparams.train.learning_rate
167
- self.optim_g = torch.optim.AdamW(
168
- self.net_g.parameters(),
169
- self.learning_rate,
170
- betas=self.hparams.train.betas,
171
- eps=self.hparams.train.eps,
172
- )
173
- self.optim_d = torch.optim.AdamW(
174
- self.net_d.parameters(),
175
- self.learning_rate,
176
- betas=self.hparams.train.betas,
177
- eps=self.hparams.train.eps,
178
- )
179
- self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
180
- self.optim_g, gamma=self.hparams.train.lr_decay
181
- )
182
- self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
183
- self.optim_d, gamma=self.hparams.train.lr_decay
184
- )
185
- self.optimizers_count = 2
186
- self.load(reset_optimizer)
187
- self.tuning = False
188
-
189
- def on_train_start(self) -> None:
190
- if not self.tuning:
191
- self.set_current_epoch(self._temp_epoch)
192
- total_batch_idx = self._temp_epoch * len(self.trainer.train_dataloader)
193
- self.set_total_batch_idx(total_batch_idx)
194
- global_step = total_batch_idx * self.optimizers_count
195
- self.set_global_step(global_step)
196
-
197
- # check if using tpu or mps
198
- if isinstance(self.trainer.accelerator, (TPUAccelerator, MPSAccelerator)):
199
- # patch torch.stft to use cpu
200
- LOG.warning("Using TPU/MPS. Patching torch.stft to use cpu.")
201
-
202
- def stft(
203
- input: torch.Tensor,
204
- n_fft: int,
205
- hop_length: int | None = None,
206
- win_length: int | None = None,
207
- window: torch.Tensor | None = None,
208
- center: bool = True,
209
- pad_mode: str = "reflect",
210
- normalized: bool = False,
211
- onesided: bool | None = None,
212
- return_complex: bool | None = None,
213
- ) -> torch.Tensor:
214
- device = input.device
215
- input = input.cpu()
216
- if window is not None:
217
- window = window.cpu()
218
- return torch.functional.stft(
219
- input,
220
- n_fft,
221
- hop_length,
222
- win_length,
223
- window,
224
- center,
225
- pad_mode,
226
- normalized,
227
- onesided,
228
- return_complex,
229
- ).to(device)
230
-
231
- torch.stft = stft
232
-
233
- elif "bf" in self.trainer.precision:
234
- LOG.warning("Using bf. Patching torch.stft to use fp32.")
235
-
236
- def stft(
237
- input: torch.Tensor,
238
- n_fft: int,
239
- hop_length: int | None = None,
240
- win_length: int | None = None,
241
- window: torch.Tensor | None = None,
242
- center: bool = True,
243
- pad_mode: str = "reflect",
244
- normalized: bool = False,
245
- onesided: bool | None = None,
246
- return_complex: bool | None = None,
247
- ) -> torch.Tensor:
248
- dtype = input.dtype
249
- input = input.float()
250
- if window is not None:
251
- window = window.float()
252
- return torch.functional.stft(
253
- input,
254
- n_fft,
255
- hop_length,
256
- win_length,
257
- window,
258
- center,
259
- pad_mode,
260
- normalized,
261
- onesided,
262
- return_complex,
263
- ).to(dtype)
264
-
265
- torch.stft = stft
266
-
267
- def on_train_end(self) -> None:
268
- self.save_checkpoints(adjust=0)
269
-
270
- def save_checkpoints(self, adjust=1):
271
- if self.tuning or self.trainer.sanity_checking:
272
- return
273
-
274
- # only save checkpoints if we are on the main device
275
- if (
276
- hasattr(self.device, "index")
277
- and self.device.index != None
278
- and self.device.index != 0
279
- ):
280
- return
281
-
282
- # `on_train_end` will be the actual epoch, not a -1, so we have to call it with `adjust = 0`
283
- current_epoch = self.current_epoch + adjust
284
- total_batch_idx = self.total_batch_idx - 1 + adjust
285
-
286
- utils.save_checkpoint(
287
- self.net_g,
288
- self.optim_g,
289
- self.learning_rate,
290
- current_epoch,
291
- Path(self.hparams.model_dir)
292
- / f"G_{total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else current_epoch}.pth",
293
- )
294
- utils.save_checkpoint(
295
- self.net_d,
296
- self.optim_d,
297
- self.learning_rate,
298
- current_epoch,
299
- Path(self.hparams.model_dir)
300
- / f"D_{total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else current_epoch}.pth",
301
- )
302
- keep_ckpts = self.hparams.train.get("keep_ckpts", 0)
303
- if keep_ckpts > 0:
304
- utils.clean_checkpoints(
305
- path_to_models=self.hparams.model_dir,
306
- n_ckpts_to_keep=keep_ckpts,
307
- sort_by_time=True,
308
- )
309
-
310
- def set_current_epoch(self, epoch: int):
311
- LOG.info(f"Setting current epoch to {epoch}")
312
- self.trainer.fit_loop.epoch_progress.current.completed = epoch
313
- self.trainer.fit_loop.epoch_progress.current.processed = epoch
314
- assert self.current_epoch == epoch, f"{self.current_epoch} != {epoch}"
315
-
316
- def set_global_step(self, global_step: int):
317
- LOG.info(f"Setting global step to {global_step}")
318
- self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed = (
319
- global_step
320
- )
321
- self.trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = (
322
- global_step
323
- )
324
- assert self.global_step == global_step, f"{self.global_step} != {global_step}"
325
-
326
- def set_total_batch_idx(self, total_batch_idx: int):
327
- LOG.info(f"Setting total batch idx to {total_batch_idx}")
328
- self.trainer.fit_loop.epoch_loop.batch_progress.total.ready = (
329
- total_batch_idx + 1
330
- )
331
- self.trainer.fit_loop.epoch_loop.batch_progress.total.completed = (
332
- total_batch_idx
333
- )
334
- assert (
335
- self.total_batch_idx == total_batch_idx + 1
336
- ), f"{self.total_batch_idx} != {total_batch_idx + 1}"
337
-
338
- @property
339
- def total_batch_idx(self) -> int:
340
- return self.trainer.fit_loop.epoch_loop.total_batch_idx + 1
341
-
342
- def load(self, reset_optimizer: bool = False):
343
- latest_g_path = utils.latest_checkpoint_path(self.hparams.model_dir, "G_*.pth")
344
- latest_d_path = utils.latest_checkpoint_path(self.hparams.model_dir, "D_*.pth")
345
- if latest_g_path is not None and latest_d_path is not None:
346
- try:
347
- _, _, _, epoch = utils.load_checkpoint(
348
- latest_g_path,
349
- self.net_g,
350
- self.optim_g,
351
- reset_optimizer,
352
- )
353
- _, _, _, epoch = utils.load_checkpoint(
354
- latest_d_path,
355
- self.net_d,
356
- self.optim_d,
357
- reset_optimizer,
358
- )
359
- self._temp_epoch = epoch
360
- self.scheduler_g.last_epoch = epoch - 1
361
- self.scheduler_d.last_epoch = epoch - 1
362
- except Exception as e:
363
- raise RuntimeError("Failed to load checkpoint") from e
364
- else:
365
- LOG.warning("No checkpoint found. Start from scratch.")
366
-
367
- def configure_optimizers(self):
368
- return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]
369
-
370
- def log_image_dict(
371
- self, image_dict: dict[str, Any], dataformats: str = "HWC"
372
- ) -> None:
373
- if not isinstance(self.logger, TensorBoardLogger):
374
- warnings.warn("Image logging is only supported with TensorBoardLogger.")
375
- return
376
- writer: SummaryWriter = self.logger.experiment
377
- for k, v in image_dict.items():
378
- try:
379
- writer.add_image(k, v, self.total_batch_idx, dataformats=dataformats)
380
- except Exception as e:
381
- warnings.warn(f"Failed to log image {k}: {e}")
382
-
383
- def log_audio_dict(self, audio_dict: dict[str, Any]) -> None:
384
- if not isinstance(self.logger, TensorBoardLogger):
385
- warnings.warn("Audio logging is only supported with TensorBoardLogger.")
386
- return
387
- writer: SummaryWriter = self.logger.experiment
388
- for k, v in audio_dict.items():
389
- writer.add_audio(
390
- k,
391
- v.float(),
392
- self.total_batch_idx,
393
- sample_rate=self.hparams.data.sampling_rate,
394
- )
395
-
396
- def log_dict_(self, log_dict: dict[str, Any], **kwargs) -> None:
397
- if not isinstance(self.logger, TensorBoardLogger):
398
- warnings.warn("Logging is only supported with TensorBoardLogger.")
399
- return
400
- writer: SummaryWriter = self.logger.experiment
401
- for k, v in log_dict.items():
402
- writer.add_scalar(k, v, self.total_batch_idx)
403
- kwargs["logger"] = False
404
- self.log_dict(log_dict, **kwargs)
405
-
406
- def log_(self, key: str, value: Any, **kwargs) -> None:
407
- self.log_dict_({key: value}, **kwargs)
408
-
409
- def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
410
- self.net_g.train()
411
- self.net_d.train()
412
-
413
- # get optims
414
- optim_g, optim_d = self.optimizers()
415
-
416
- # Generator
417
- # train
418
- self.toggle_optimizer(optim_g)
419
- c, f0, spec, mel, y, g, lengths, uv = batch
420
- (
421
- y_hat,
422
- y_hat_mb,
423
- ids_slice,
424
- z_mask,
425
- (z, z_p, m_p, logs_p, m_q, logs_q),
426
- pred_lf0,
427
- norm_lf0,
428
- lf0,
429
- ) = self.net_g(c, f0, uv, spec, g=g, c_lengths=lengths, spec_lengths=lengths)
430
- y_mel = commons.slice_segments(
431
- mel,
432
- ids_slice,
433
- self.hparams.train.segment_size // self.hparams.data.hop_length,
434
- )
435
- y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1), self.hparams)
436
- y_mel = y_mel[..., : y_hat_mel.shape[-1]]
437
- y = commons.slice_segments(
438
- y,
439
- ids_slice * self.hparams.data.hop_length,
440
- self.hparams.train.segment_size,
441
- )
442
- y = y[..., : y_hat.shape[-1]]
443
-
444
- # generator loss
445
- y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(y, y_hat)
446
-
447
- with autocast(enabled=False):
448
- loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.train.c_mel
449
- loss_kl = (
450
- kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.hparams.train.c_kl
451
- )
452
- loss_fm = feature_loss(fmap_r, fmap_g)
453
- loss_gen, losses_gen = generator_loss(y_d_hat_g)
454
- loss_lf0 = F.mse_loss(pred_lf0, lf0)
455
- loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0
456
-
457
- # MB-iSTFT-VITS
458
- loss_subband = torch.tensor(0.0)
459
- if self.hparams.model.get("type_") == "mb-istft":
460
- from .modules.decoders.mb_istft import PQMF, subband_stft_loss
461
-
462
- y_mb = PQMF(y.device, self.hparams.model.subbands).analysis(y)
463
- loss_subband = subband_stft_loss(self.hparams, y_mb, y_hat_mb)
464
- loss_gen_all += loss_subband
465
-
466
- # log loss
467
- self.log_("lr", self.optim_g.param_groups[0]["lr"])
468
- self.log_dict_(
469
- {
470
- "loss/g/total": loss_gen_all,
471
- "loss/g/fm": loss_fm,
472
- "loss/g/mel": loss_mel,
473
- "loss/g/kl": loss_kl,
474
- "loss/g/lf0": loss_lf0,
475
- },
476
- prog_bar=True,
477
- )
478
- if self.hparams.model.get("type_") == "mb-istft":
479
- self.log_("loss/g/subband", loss_subband)
480
- if self.total_batch_idx % self.hparams.train.log_interval == 0:
481
- self.log_image_dict(
482
- {
483
- "slice/mel_org": utils.plot_spectrogram_to_numpy(
484
- y_mel[0].data.cpu().float().numpy()
485
- ),
486
- "slice/mel_gen": utils.plot_spectrogram_to_numpy(
487
- y_hat_mel[0].data.cpu().float().numpy()
488
- ),
489
- "all/mel": utils.plot_spectrogram_to_numpy(
490
- mel[0].data.cpu().float().numpy()
491
- ),
492
- "all/lf0": so_vits_svc_fork.utils.plot_data_to_numpy(
493
- lf0[0, 0, :].cpu().float().numpy(),
494
- pred_lf0[0, 0, :].detach().cpu().float().numpy(),
495
- ),
496
- "all/norm_lf0": so_vits_svc_fork.utils.plot_data_to_numpy(
497
- lf0[0, 0, :].cpu().float().numpy(),
498
- norm_lf0[0, 0, :].detach().cpu().float().numpy(),
499
- ),
500
- }
501
- )
502
-
503
- accumulate_grad_batches = self.hparams.train.get("accumulate_grad_batches", 1)
504
- should_update = (
505
- batch_idx + 1
506
- ) % accumulate_grad_batches == 0 or self.trainer.is_last_batch
507
- # optimizer
508
- self.manual_backward(loss_gen_all / accumulate_grad_batches)
509
- if should_update:
510
- self.log_(
511
- "grad_norm_g", commons.clip_grad_value_(self.net_g.parameters(), None)
512
- )
513
- optim_g.step()
514
- optim_g.zero_grad()
515
- self.untoggle_optimizer(optim_g)
516
-
517
- # Discriminator
518
- # train
519
- self.toggle_optimizer(optim_d)
520
- y_d_hat_r, y_d_hat_g, _, _ = self.net_d(y, y_hat.detach())
521
-
522
- # discriminator loss
523
- with autocast(enabled=False):
524
- loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
525
- y_d_hat_r, y_d_hat_g
526
- )
527
- loss_disc_all = loss_disc
528
-
529
- # log loss
530
- self.log_("loss/d/total", loss_disc_all, prog_bar=True)
531
-
532
- # optimizer
533
- self.manual_backward(loss_disc_all / accumulate_grad_batches)
534
- if should_update:
535
- self.log_(
536
- "grad_norm_d", commons.clip_grad_value_(self.net_d.parameters(), None)
537
- )
538
- optim_d.step()
539
- optim_d.zero_grad()
540
- self.untoggle_optimizer(optim_d)
541
-
542
- # end of epoch
543
- if self.trainer.is_last_batch:
544
- self.scheduler_g.step()
545
- self.scheduler_d.step()
546
-
547
- def validation_step(self, batch, batch_idx):
548
- # avoid logging with wrong global step
549
- if self.global_step == 0:
550
- return
551
- with torch.no_grad():
552
- self.net_g.eval()
553
- c, f0, _, mel, y, g, _, uv = batch
554
- y_hat = self.net_g.infer(c, f0, uv, g=g)
555
- y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1).float(), self.hparams)
556
- self.log_audio_dict(
557
- {f"gen/audio_{batch_idx}": y_hat[0], f"gt/audio_{batch_idx}": y[0]}
558
- )
559
- self.log_image_dict(
560
- {
561
- "gen/mel": utils.plot_spectrogram_to_numpy(
562
- y_hat_mel[0].cpu().float().numpy()
563
- ),
564
- "gt/mel": utils.plot_spectrogram_to_numpy(
565
- mel[0].cpu().float().numpy()
566
- ),
567
- }
568
- )
569
-
570
- def on_validation_end(self) -> None:
571
- self.save_checkpoints()
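
The removed train() entry point wraps the Lightning trainer defined above. A usage sketch, assuming the same function remains available from the installed package; paths are placeholders. Note that, per the code above, train.batch_size in the config may also be a string such as "auto" or "binsearch-2-25" to trigger automatic batch-size tuning.

from so_vits_svc_fork.train import train

train(
    config_path="configs/44k/config.json",
    model_path="logs/44k",  # G_*.pth / D_*.pth checkpoints and TensorBoard logs land here
    reset_optimizer=False,
)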
so_vits_svc_fork/utils.py DELETED
@@ -1,478 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import os
5
- import re
6
- import subprocess
7
- import warnings
8
- from itertools import groupby
9
- from logging import getLogger
10
- from pathlib import Path
11
- from typing import Any, Literal, Sequence
12
-
13
- import matplotlib
14
- import matplotlib.pylab as plt
15
- import numpy as np
16
- import requests
17
- import torch
18
- import torch.backends.mps
19
- import torch.nn as nn
20
- import torchaudio
21
- from cm_time import timer
22
- from numpy import ndarray
23
- from tqdm import tqdm
24
- from transformers import HubertModel
25
-
26
- from so_vits_svc_fork.hparams import HParams
27
-
28
- LOG = getLogger(__name__)
29
- HUBERT_SAMPLING_RATE = 16000
30
- IS_COLAB = os.getenv("COLAB_RELEASE_TAG", False)
31
-
32
-
33
- def get_optimal_device(index: int = 0) -> torch.device:
34
- if torch.cuda.is_available():
35
- return torch.device(f"cuda:{index % torch.cuda.device_count()}")
36
- elif torch.backends.mps.is_available():
37
- return torch.device("mps")
38
- else:
39
- try:
40
- import torch_xla.core.xla_model as xm # noqa
41
-
42
- if xm.xrt_world_size() > 0:
43
- return torch.device("xla")
44
- # return xm.xla_device()
45
- except ImportError:
46
- pass
47
- return torch.device("cpu")
48
-
49
-
50
- def download_file(
51
- url: str,
52
- filepath: Path | str,
53
- chunk_size: int = 64 * 1024,
54
- tqdm_cls: type = tqdm,
55
- skip_if_exists: bool = False,
56
- overwrite: bool = False,
57
- **tqdm_kwargs: Any,
58
- ):
59
- if skip_if_exists is True and overwrite is True:
60
- raise ValueError("skip_if_exists and overwrite cannot be both True")
61
- filepath = Path(filepath)
62
- filepath.parent.mkdir(parents=True, exist_ok=True)
63
- temppath = filepath.parent / f"{filepath.name}.download"
64
- if filepath.exists():
65
- if skip_if_exists:
66
- return
67
- elif not overwrite:
68
- filepath.unlink()
69
- else:
70
- raise FileExistsError(f"{filepath} already exists")
71
- temppath.unlink(missing_ok=True)
72
- resp = requests.get(url, stream=True)
73
- total = int(resp.headers.get("content-length", 0))
74
- kwargs = dict(
75
- total=total,
76
- unit="iB",
77
- unit_scale=True,
78
- unit_divisor=1024,
79
- desc=f"Downloading {filepath.name}",
80
- )
81
- kwargs.update(tqdm_kwargs)
82
- with temppath.open("wb") as f, tqdm_cls(**kwargs) as pbar:
83
- for data in resp.iter_content(chunk_size=chunk_size):
84
- size = f.write(data)
85
- pbar.update(size)
86
- temppath.rename(filepath)
87
-
88
-
89
- PRETRAINED_MODEL_URLS = {
90
- "hifi-gan": [
91
- [
92
- "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth",
93
- "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth",
94
- ],
95
- [
96
- "https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/D_0.pth",
97
- "https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/G_0.pth",
98
- ],
99
- ],
100
- "contentvec": [
101
- [
102
- "https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/checkpoint_best_legacy_500.pt"
103
- ],
104
- [
105
- "https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/checkpoint_best_legacy_500.pt"
106
- ],
107
- [
108
- "http://obs.cstcloud.cn/share/obs/sankagenkeshi/checkpoint_best_legacy_500.pt"
109
- ],
110
- ],
111
- }
112
- from joblib import Parallel, delayed
113
-
114
-
115
- def ensure_pretrained_model(
116
- folder_path: Path | str, type_: str | dict[str, str], **tqdm_kwargs: Any
117
- ) -> tuple[Path, ...] | None:
118
- folder_path = Path(folder_path)
119
-
120
- # new code
121
- if not isinstance(type_, str):
122
- try:
123
- Parallel(n_jobs=len(type_))(
124
- [
125
- delayed(download_file)(
126
- url,
127
- folder_path / filename,
128
- position=i,
129
- skip_if_exists=True,
130
- **tqdm_kwargs,
131
- )
132
- for i, (filename, url) in enumerate(type_.items())
133
- ]
134
- )
135
- return tuple(folder_path / filename for filename in type_.values())
136
- except Exception as e:
137
- LOG.error(f"Failed to download {type_}")
138
- LOG.exception(e)
139
-
140
- # old code
141
- models_candidates = PRETRAINED_MODEL_URLS.get(type_, None)
142
- if models_candidates is None:
143
- LOG.warning(f"Unknown pretrained model type: {type_}")
144
- return
145
- for model_urls in models_candidates:
146
- paths = [folder_path / model_url.split("/")[-1] for model_url in model_urls]
147
- try:
148
- Parallel(n_jobs=len(paths))(
149
- [
150
- delayed(download_file)(
151
- url, path, position=i, skip_if_exists=True, **tqdm_kwargs
152
- )
153
- for i, (url, path) in enumerate(zip(model_urls, paths))
154
- ]
155
- )
156
- return tuple(paths)
157
- except Exception as e:
158
- LOG.error(f"Failed to download {model_urls}")
159
- LOG.exception(e)
160
-
161
-
162
- class HubertModelWithFinalProj(HubertModel):
163
- def __init__(self, config):
164
- super().__init__(config)
165
-
166
- # The final projection layer is only used for backward compatibility.
167
- # Following https://github.com/auspicious3000/contentvec/issues/6
168
- # Removing this layer is necessary to achieve the desired outcome.
169
- self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
170
-
171
-
172
- def remove_weight_norm_if_exists(module, name: str = "weight"):
173
- r"""Removes the weight normalization reparameterization from a module.
174
-
175
- Args:
176
- module (Module): containing module
177
- name (str, optional): name of weight parameter
178
-
179
- Example:
180
- >>> m = weight_norm(nn.Linear(20, 40))
181
- >>> remove_weight_norm(m)
182
- """
183
- from torch.nn.utils.weight_norm import WeightNorm
184
-
185
- for k, hook in module._forward_pre_hooks.items():
186
- if isinstance(hook, WeightNorm) and hook.name == name:
187
- hook.remove(module)
188
- del module._forward_pre_hooks[k]
189
- return module
190
-
191
-
192
- def get_hubert_model(
193
- device: str | torch.device, final_proj: bool = True
194
- ) -> HubertModel:
195
- if final_proj:
196
- model = HubertModelWithFinalProj.from_pretrained("lengyue233/content-vec-best")
197
- else:
198
- model = HubertModel.from_pretrained("lengyue233/content-vec-best")
199
- # Hubert is always used in inference mode, we can safely remove weight-norms
200
- for m in model.modules():
201
- if isinstance(m, (nn.Conv2d, nn.Conv1d)):
202
- remove_weight_norm_if_exists(m)
203
-
204
- return model.to(device)
205
-
206
-
207
- def get_content(
208
- cmodel: HubertModel,
209
- audio: torch.Tensor | ndarray[Any, Any],
210
- device: torch.device | str,
211
- sr: int,
212
- legacy_final_proj: bool = False,
213
- ) -> torch.Tensor:
214
- audio = torch.as_tensor(audio)
215
- if sr != HUBERT_SAMPLING_RATE:
216
- audio = (
217
- torchaudio.transforms.Resample(sr, HUBERT_SAMPLING_RATE)
218
- .to(audio.device)(audio)
219
- .to(device)
220
- )
221
- if audio.ndim == 1:
222
- audio = audio.unsqueeze(0)
223
- with torch.no_grad(), timer() as t:
224
- if legacy_final_proj:
225
- warnings.warn("legacy_final_proj is deprecated")
226
- if not hasattr(cmodel, "final_proj"):
227
- raise ValueError("HubertModel does not have final_proj")
228
- c = cmodel(audio, output_hidden_states=True)["hidden_states"][9]
229
- c = cmodel.final_proj(c)
230
- else:
231
- c = cmodel(audio)["last_hidden_state"]
232
- c = c.transpose(1, 2)
233
- wav_len = audio.shape[-1] / HUBERT_SAMPLING_RATE
234
- LOG.info(
235
- f"HuBERT inference time : {t.elapsed:.3f}s, RTF: {t.elapsed / wav_len:.3f}"
236
- )
237
- return c
238
-
239
-
240
- def _substitute_if_same_shape(to_: dict[str, Any], from_: dict[str, Any]) -> None:
241
- not_in_to = list(filter(lambda x: x not in to_, from_.keys()))
242
- not_in_from = list(filter(lambda x: x not in from_, to_.keys()))
243
- if not_in_to:
244
- warnings.warn(f"Keys not found in model state dict:" f"{not_in_to}")
245
- if not_in_from:
246
- warnings.warn(f"Keys not found in checkpoint state dict:" f"{not_in_from}")
247
- shape_missmatch = []
248
- for k, v in from_.items():
249
- if k not in to_:
250
- pass
251
- elif hasattr(v, "shape"):
252
- if not hasattr(to_[k], "shape"):
253
- raise ValueError(f"Key {k} is not a tensor")
254
- if to_[k].shape == v.shape:
255
- to_[k] = v
256
- else:
257
- shape_missmatch.append((k, to_[k].shape, v.shape))
258
- elif isinstance(v, dict):
259
- assert isinstance(to_[k], dict)
260
- _substitute_if_same_shape(to_[k], v)
261
- else:
262
- to_[k] = v
263
- if shape_missmatch:
264
- warnings.warn(
265
- f"Shape mismatch: {[f'{k}: {v1} -> {v2}' for k, v1, v2 in shape_missmatch]}"
266
- )
267
-
268
-
269
- def safe_load(model: torch.nn.Module, state_dict: dict[str, Any]) -> None:
270
- model_state_dict = model.state_dict()
271
- _substitute_if_same_shape(model_state_dict, state_dict)
272
- model.load_state_dict(model_state_dict)
273
-
274
-
275
- def load_checkpoint(
276
- checkpoint_path: Path | str,
277
- model: torch.nn.Module,
278
- optimizer: torch.optim.Optimizer | None = None,
279
- skip_optimizer: bool = False,
280
- ) -> tuple[torch.nn.Module, torch.optim.Optimizer | None, float, int]:
281
- if not Path(checkpoint_path).is_file():
282
- raise FileNotFoundError(f"File {checkpoint_path} not found")
283
- with Path(checkpoint_path).open("rb") as f:
284
- with warnings.catch_warnings():
285
- warnings.filterwarnings(
286
- "ignore", category=UserWarning, message="TypedStorage is deprecated"
287
- )
288
- checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True)
289
- iteration = checkpoint_dict["iteration"]
290
- learning_rate = checkpoint_dict["learning_rate"]
291
-
292
- # safe load module
293
- if hasattr(model, "module"):
294
- safe_load(model.module, checkpoint_dict["model"])
295
- else:
296
- safe_load(model, checkpoint_dict["model"])
297
- # safe load optim
298
- if (
299
- optimizer is not None
300
- and not skip_optimizer
301
- and checkpoint_dict["optimizer"] is not None
302
- ):
303
- with warnings.catch_warnings():
304
- warnings.simplefilter("ignore")
305
- safe_load(optimizer, checkpoint_dict["optimizer"])
306
-
307
- LOG.info(f"Loaded checkpoint '{checkpoint_path}' (epoch {iteration})")
308
- return model, optimizer, learning_rate, iteration
309
-
310
-
311
- def save_checkpoint(
312
- model: torch.nn.Module,
313
- optimizer: torch.optim.Optimizer,
314
- learning_rate: float,
315
- iteration: int,
316
- checkpoint_path: Path | str,
317
- ) -> None:
318
- LOG.info(
319
- "Saving model and optimizer state at epoch {} to {}".format(
320
- iteration, checkpoint_path
321
- )
322
- )
323
- if hasattr(model, "module"):
324
- state_dict = model.module.state_dict()
325
- else:
326
- state_dict = model.state_dict()
327
- with Path(checkpoint_path).open("wb") as f:
328
- torch.save(
329
- {
330
- "model": state_dict,
331
- "iteration": iteration,
332
- "optimizer": optimizer.state_dict(),
333
- "learning_rate": learning_rate,
334
- },
335
- f,
336
- )
337
-
338
-
339
- def clean_checkpoints(
340
- path_to_models: Path | str, n_ckpts_to_keep: int = 2, sort_by_time: bool = True
341
- ) -> None:
342
- """Freeing up space by deleting saved ckpts
343
-
344
- Arguments:
345
- path_to_models -- Path to the model directory
346
- n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
347
- sort_by_time -- True -> chronologically delete ckpts
348
- False -> lexicographically delete ckpts
349
- """
350
- LOG.info("Cleaning old checkpoints...")
351
- path_to_models = Path(path_to_models)
352
-
353
- # Define sort key functions
354
- name_key = lambda p: int(re.match(r"[GD]_(\d+)", p.stem).group(1))
355
- time_key = lambda p: p.stat().st_mtime
356
- path_key = lambda p: (p.stem[0], time_key(p) if sort_by_time else name_key(p))
357
-
358
- models = list(
359
- filter(
360
- lambda p: (
361
- p.is_file()
362
- and re.match(r"[GD]_\d+", p.stem)
363
- and not p.stem.endswith("_0")
364
- ),
365
- path_to_models.glob("*.pth"),
366
- )
367
- )
368
-
369
- models_sorted = sorted(models, key=path_key)
370
-
371
- models_sorted_grouped = groupby(models_sorted, lambda p: p.stem[0])
372
-
373
- for group_name, group_items in models_sorted_grouped:
374
- to_delete_list = list(group_items)[:-n_ckpts_to_keep]
375
-
376
- for to_delete in to_delete_list:
377
- if to_delete.exists():
378
- LOG.info(f"Removing {to_delete}")
379
- if IS_COLAB:
380
- to_delete.write_text("")
381
- to_delete.unlink()
382
-
383
-
384
- def latest_checkpoint_path(dir_path: Path | str, regex: str = "G_*.pth") -> Path | None:
385
- dir_path = Path(dir_path)
386
- name_key = lambda p: int(re.match(r"._(\d+)\.pth", p.name).group(1))
387
- paths = list(sorted(dir_path.glob(regex), key=name_key))
388
- if len(paths) == 0:
389
- return None
390
- return paths[-1]
391
-
392
-
393
- def plot_spectrogram_to_numpy(spectrogram: ndarray) -> ndarray:
394
- matplotlib.use("Agg")
395
- fig, ax = plt.subplots(figsize=(10, 2))
396
- im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
397
- plt.colorbar(im, ax=ax)
398
- plt.xlabel("Frames")
399
- plt.ylabel("Channels")
400
- plt.tight_layout()
401
-
402
- fig.canvas.draw()
403
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
404
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
405
- plt.close()
406
- return data
407
-
408
-
409
- def get_backup_hparams(
410
- config_path: Path, model_path: Path, init: bool = True
411
- ) -> HParams:
412
- model_path.mkdir(parents=True, exist_ok=True)
413
- config_save_path = model_path / "config.json"
414
- if init:
415
- with config_path.open() as f:
416
- data = f.read()
417
- with config_save_path.open("w") as f:
418
- f.write(data)
419
- else:
420
- with config_save_path.open() as f:
421
- data = f.read()
422
- config = json.loads(data)
423
-
424
- hparams = HParams(**config)
425
- hparams.model_dir = model_path.as_posix()
426
- return hparams
427
-
428
-
429
- def get_hparams(config_path: Path | str) -> HParams:
430
- config = json.loads(Path(config_path).read_text("utf-8"))
431
- hparams = HParams(**config)
432
- return hparams
433
-
434
-
435
- def repeat_expand_2d(content: torch.Tensor, target_len: int) -> torch.Tensor:
436
- # content : [h, t]
437
- src_len = content.shape[-1]
438
- if target_len < src_len:
439
- return content[:, :target_len]
440
- else:
441
- return torch.nn.functional.interpolate(
442
- content.unsqueeze(0), size=target_len, mode="nearest"
443
- ).squeeze(0)
444
-
445
-
446
- def plot_data_to_numpy(x: ndarray, y: ndarray) -> ndarray:
447
- matplotlib.use("Agg")
448
- fig, ax = plt.subplots(figsize=(10, 2))
449
- plt.plot(x)
450
- plt.plot(y)
451
- plt.tight_layout()
452
-
453
- fig.canvas.draw()
454
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
455
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
456
- plt.close()
457
- return data
458
-
459
-
460
- def get_gpu_memory(type_: Literal["total", "free", "used"]) -> Sequence[int] | None:
461
- command = f"nvidia-smi --query-gpu=memory.{type_} --format=csv"
462
- try:
463
- memory_free_info = (
464
- subprocess.check_output(command.split())
465
- .decode("ascii")
466
- .split("\n")[:-1][1:]
467
- )
468
- memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
469
- return memory_free_values
470
- except Exception:
471
- return
472
-
473
-
474
- def get_total_gpu_memory(type_: Literal["total", "free", "used"]) -> int | None:
475
- memories = get_gpu_memory(type_)
476
- if memories is None:
477
- return
478
- return sum(memories)
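
Two of the most commonly reused helpers from this module are device selection and pretrained-checkpoint download; a short sketch (the folder path is a placeholder):

from so_vits_svc_fork.utils import ensure_pretrained_model, get_optimal_device

device = get_optimal_device()  # CUDA if available, else MPS, else XLA, else CPU
paths = ensure_pretrained_model("logs/44k", "hifi-gan")  # fetches D_0.pth / G_0.pth if missing
print(device, paths)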