Spaces:
Running
on
Zero
Running
on
Zero
Upload 211 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- audiotools/__init__.py +10 -0
- audiotools/core/__init__.py +4 -0
- audiotools/core/audio_signal.py +1682 -0
- audiotools/core/display.py +194 -0
- audiotools/core/dsp.py +390 -0
- audiotools/core/effects.py +647 -0
- audiotools/core/ffmpeg.py +204 -0
- audiotools/core/loudness.py +320 -0
- audiotools/core/playback.py +252 -0
- audiotools/core/templates/__init__.py +0 -0
- audiotools/core/templates/headers.html +322 -0
- audiotools/core/templates/pandoc.css +407 -0
- audiotools/core/templates/widget.html +52 -0
- audiotools/core/util.py +671 -0
- audiotools/core/whisper.py +97 -0
- audiotools/data/__init__.py +3 -0
- audiotools/data/datasets.py +517 -0
- audiotools/data/preprocess.py +81 -0
- audiotools/data/transforms.py +1592 -0
- audiotools/metrics/__init__.py +6 -0
- audiotools/metrics/distance.py +131 -0
- audiotools/metrics/quality.py +159 -0
- audiotools/metrics/spectral.py +247 -0
- audiotools/ml/__init__.py +5 -0
- audiotools/ml/accelerator.py +184 -0
- audiotools/ml/decorators.py +440 -0
- audiotools/ml/experiment.py +90 -0
- audiotools/ml/layers/__init__.py +2 -0
- audiotools/ml/layers/base.py +328 -0
- audiotools/ml/layers/spectral_gate.py +127 -0
- audiotools/post.py +140 -0
- audiotools/preference.py +600 -0
- src/inference.py +169 -0
- src/inference_controlnet.py +129 -0
- src/models/.ipynb_checkpoints/blocks-checkpoint.py +325 -0
- src/models/.ipynb_checkpoints/conditioners-checkpoint.py +183 -0
- src/models/.ipynb_checkpoints/controlnet-checkpoint.py +318 -0
- src/models/.ipynb_checkpoints/udit-checkpoint.py +365 -0
- src/models/__pycache__/attention.cpython-311.pyc +0 -0
- src/models/__pycache__/blocks.cpython-310.pyc +0 -0
- src/models/__pycache__/blocks.cpython-311.pyc +0 -0
- src/models/__pycache__/conditioners.cpython-310.pyc +0 -0
- src/models/__pycache__/conditioners.cpython-311.pyc +0 -0
- src/models/__pycache__/controlnet.cpython-311.pyc +0 -0
- src/models/__pycache__/modules.cpython-311.pyc +0 -0
- src/models/__pycache__/rotary.cpython-311.pyc +0 -0
- src/models/__pycache__/timm.cpython-311.pyc +0 -0
- src/models/__pycache__/udit.cpython-310.pyc +0 -0
- src/models/__pycache__/udit.cpython-311.pyc +0 -0
- src/models/blocks.py +325 -0
audiotools/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "0.7.3"
|
2 |
+
from .core import AudioSignal
|
3 |
+
from .core import STFTParams
|
4 |
+
from .core import Meter
|
5 |
+
from .core import util
|
6 |
+
from . import metrics
|
7 |
+
from . import data
|
8 |
+
from . import ml
|
9 |
+
from .data import datasets
|
10 |
+
from .data import transforms
|
audiotools/core/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import util
|
2 |
+
from .audio_signal import AudioSignal
|
3 |
+
from .audio_signal import STFTParams
|
4 |
+
from .loudness import Meter
|
audiotools/core/audio_signal.py
ADDED
@@ -0,0 +1,1682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import functools
|
3 |
+
import hashlib
|
4 |
+
import math
|
5 |
+
import pathlib
|
6 |
+
import tempfile
|
7 |
+
import typing
|
8 |
+
import warnings
|
9 |
+
from collections import namedtuple
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import julius
|
13 |
+
import numpy as np
|
14 |
+
import soundfile
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from . import util
|
18 |
+
from .display import DisplayMixin
|
19 |
+
from .dsp import DSPMixin
|
20 |
+
from .effects import EffectMixin
|
21 |
+
from .effects import ImpulseResponseMixin
|
22 |
+
from .ffmpeg import FFMPEGMixin
|
23 |
+
from .loudness import LoudnessMixin
|
24 |
+
from .playback import PlayMixin
|
25 |
+
from .whisper import WhisperMixin
|
26 |
+
|
27 |
+
|
28 |
+
STFTParams = namedtuple(
|
29 |
+
"STFTParams",
|
30 |
+
["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
|
31 |
+
)
|
32 |
+
"""
|
33 |
+
STFTParams object is a container that holds STFT parameters - window_length,
|
34 |
+
hop_length, and window_type. Not all parameters need to be specified. Ones that
|
35 |
+
are not specified will be inferred by the AudioSignal parameters.
|
36 |
+
|
37 |
+
Parameters
|
38 |
+
----------
|
39 |
+
window_length : int, optional
|
40 |
+
Window length of STFT, by default ``0.032 * self.sample_rate``.
|
41 |
+
hop_length : int, optional
|
42 |
+
Hop length of STFT, by default ``window_length // 4``.
|
43 |
+
window_type : str, optional
|
44 |
+
Type of window to use, by default ``sqrt\_hann``.
|
45 |
+
match_stride : bool, optional
|
46 |
+
Whether to match the stride of convolutional layers, by default False
|
47 |
+
padding_type : str, optional
|
48 |
+
Type of padding to use, by default 'reflect'
|
49 |
+
"""
|
50 |
+
STFTParams.__new__.__defaults__ = (None, None, None, None, None)
|
51 |
+
|
52 |
+
|
53 |
+
class AudioSignal(
|
54 |
+
EffectMixin,
|
55 |
+
LoudnessMixin,
|
56 |
+
PlayMixin,
|
57 |
+
ImpulseResponseMixin,
|
58 |
+
DSPMixin,
|
59 |
+
DisplayMixin,
|
60 |
+
FFMPEGMixin,
|
61 |
+
WhisperMixin,
|
62 |
+
):
|
63 |
+
"""This is the core object of this library. Audio is always
|
64 |
+
loaded into an AudioSignal, which then enables all the features
|
65 |
+
of this library, including audio augmentations, I/O, playback,
|
66 |
+
and more.
|
67 |
+
|
68 |
+
The structure of this object is that the base functionality
|
69 |
+
is defined in ``core/audio_signal.py``, while extensions to
|
70 |
+
that functionality are defined in the other ``core/*.py``
|
71 |
+
files. For example, all the display-based functionality
|
72 |
+
(e.g. plot spectrograms, waveforms, write to tensorboard)
|
73 |
+
are in ``core/display.py``.
|
74 |
+
|
75 |
+
Parameters
|
76 |
+
----------
|
77 |
+
audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
|
78 |
+
Object to create AudioSignal from. Can be a tensor, numpy array,
|
79 |
+
or a path to a file. The file is always reshaped to
|
80 |
+
sample_rate : int, optional
|
81 |
+
Sample rate of the audio. If different from underlying file, resampling is
|
82 |
+
performed. If passing in an array or tensor, this must be defined,
|
83 |
+
by default None
|
84 |
+
stft_params : STFTParams, optional
|
85 |
+
Parameters of STFT to use. , by default None
|
86 |
+
offset : float, optional
|
87 |
+
Offset in seconds to read from file, by default 0
|
88 |
+
duration : float, optional
|
89 |
+
Duration in seconds to read from file, by default None
|
90 |
+
device : str, optional
|
91 |
+
Device to load audio onto, by default None
|
92 |
+
|
93 |
+
Examples
|
94 |
+
--------
|
95 |
+
Loading an AudioSignal from an array, at a sample rate of
|
96 |
+
44100.
|
97 |
+
|
98 |
+
>>> signal = AudioSignal(torch.randn(5*44100), 44100)
|
99 |
+
|
100 |
+
Note, the signal is reshaped to have a batch size, and one
|
101 |
+
audio channel:
|
102 |
+
|
103 |
+
>>> print(signal.shape)
|
104 |
+
(1, 1, 44100)
|
105 |
+
|
106 |
+
You can treat AudioSignals like tensors, and many of the same
|
107 |
+
functions you might use on tensors are defined for AudioSignals
|
108 |
+
as well:
|
109 |
+
|
110 |
+
>>> signal.to("cuda")
|
111 |
+
>>> signal.cuda()
|
112 |
+
>>> signal.clone()
|
113 |
+
>>> signal.detach()
|
114 |
+
|
115 |
+
Indexing AudioSignals returns an AudioSignal:
|
116 |
+
|
117 |
+
>>> signal[..., 3*44100:4*44100]
|
118 |
+
|
119 |
+
The above signal is 1 second long, and is also an AudioSignal.
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray],
|
125 |
+
sample_rate: int = None,
|
126 |
+
stft_params: STFTParams = None,
|
127 |
+
offset: float = 0,
|
128 |
+
duration: float = None,
|
129 |
+
device: str = None,
|
130 |
+
):
|
131 |
+
audio_path = None
|
132 |
+
audio_array = None
|
133 |
+
|
134 |
+
if isinstance(audio_path_or_array, str):
|
135 |
+
audio_path = audio_path_or_array
|
136 |
+
elif isinstance(audio_path_or_array, pathlib.Path):
|
137 |
+
audio_path = audio_path_or_array
|
138 |
+
elif isinstance(audio_path_or_array, np.ndarray):
|
139 |
+
audio_array = audio_path_or_array
|
140 |
+
elif torch.is_tensor(audio_path_or_array):
|
141 |
+
audio_array = audio_path_or_array
|
142 |
+
else:
|
143 |
+
raise ValueError(
|
144 |
+
"audio_path_or_array must be either a Path, "
|
145 |
+
"string, numpy array, or torch Tensor!"
|
146 |
+
)
|
147 |
+
|
148 |
+
self.path_to_file = None
|
149 |
+
|
150 |
+
self.audio_data = None
|
151 |
+
self.sources = None # List of AudioSignal objects.
|
152 |
+
self.stft_data = None
|
153 |
+
if audio_path is not None:
|
154 |
+
self.load_from_file(
|
155 |
+
audio_path, offset=offset, duration=duration, device=device
|
156 |
+
)
|
157 |
+
elif audio_array is not None:
|
158 |
+
assert sample_rate is not None, "Must set sample rate!"
|
159 |
+
self.load_from_array(audio_array, sample_rate, device=device)
|
160 |
+
|
161 |
+
self.window = None
|
162 |
+
self.stft_params = stft_params
|
163 |
+
|
164 |
+
self.metadata = {
|
165 |
+
"offset": offset,
|
166 |
+
"duration": duration,
|
167 |
+
}
|
168 |
+
|
169 |
+
@property
|
170 |
+
def path_to_input_file(
|
171 |
+
self,
|
172 |
+
):
|
173 |
+
"""
|
174 |
+
Path to input file, if it exists.
|
175 |
+
Alias to ``path_to_file`` for backwards compatibility
|
176 |
+
"""
|
177 |
+
return self.path_to_file
|
178 |
+
|
179 |
+
@classmethod
|
180 |
+
def excerpt(
|
181 |
+
cls,
|
182 |
+
audio_path: typing.Union[str, Path],
|
183 |
+
offset: float = None,
|
184 |
+
duration: float = None,
|
185 |
+
state: typing.Union[np.random.RandomState, int] = None,
|
186 |
+
**kwargs,
|
187 |
+
):
|
188 |
+
"""Randomly draw an excerpt of ``duration`` seconds from an
|
189 |
+
audio file specified at ``audio_path``, between ``offset`` seconds
|
190 |
+
and end of file. ``state`` can be used to seed the random draw.
|
191 |
+
|
192 |
+
Parameters
|
193 |
+
----------
|
194 |
+
audio_path : typing.Union[str, Path]
|
195 |
+
Path to audio file to grab excerpt from.
|
196 |
+
offset : float, optional
|
197 |
+
Lower bound for the start time, in seconds drawn from
|
198 |
+
the file, by default None.
|
199 |
+
duration : float, optional
|
200 |
+
Duration of excerpt, in seconds, by default None
|
201 |
+
state : typing.Union[np.random.RandomState, int], optional
|
202 |
+
RandomState or seed of random state, by default None
|
203 |
+
|
204 |
+
Returns
|
205 |
+
-------
|
206 |
+
AudioSignal
|
207 |
+
AudioSignal containing excerpt.
|
208 |
+
|
209 |
+
Examples
|
210 |
+
--------
|
211 |
+
>>> signal = AudioSignal.excerpt("path/to/audio", duration=5)
|
212 |
+
"""
|
213 |
+
info = util.info(audio_path)
|
214 |
+
total_duration = info.duration
|
215 |
+
|
216 |
+
state = util.random_state(state)
|
217 |
+
lower_bound = 0 if offset is None else offset
|
218 |
+
upper_bound = max(total_duration - duration, 0)
|
219 |
+
offset = state.uniform(lower_bound, upper_bound)
|
220 |
+
|
221 |
+
signal = cls(audio_path, offset=offset, duration=duration, **kwargs)
|
222 |
+
signal.metadata["offset"] = offset
|
223 |
+
signal.metadata["duration"] = duration
|
224 |
+
|
225 |
+
return signal
|
226 |
+
|
227 |
+
@classmethod
|
228 |
+
def salient_excerpt(
|
229 |
+
cls,
|
230 |
+
audio_path: typing.Union[str, Path],
|
231 |
+
loudness_cutoff: float = None,
|
232 |
+
num_tries: int = 8,
|
233 |
+
state: typing.Union[np.random.RandomState, int] = None,
|
234 |
+
**kwargs,
|
235 |
+
):
|
236 |
+
"""Similar to AudioSignal.excerpt, except it extracts excerpts only
|
237 |
+
if they are above a specified loudness threshold, which is computed via
|
238 |
+
a fast LUFS routine.
|
239 |
+
|
240 |
+
Parameters
|
241 |
+
----------
|
242 |
+
audio_path : typing.Union[str, Path]
|
243 |
+
Path to audio file to grab excerpt from.
|
244 |
+
loudness_cutoff : float, optional
|
245 |
+
Loudness threshold in dB. Typical values are ``-40, -60``,
|
246 |
+
etc, by default None
|
247 |
+
num_tries : int, optional
|
248 |
+
Number of tries to grab an excerpt above the threshold
|
249 |
+
before giving up, by default 8.
|
250 |
+
state : typing.Union[np.random.RandomState, int], optional
|
251 |
+
RandomState or seed of random state, by default None
|
252 |
+
kwargs : dict
|
253 |
+
Keyword arguments to AudioSignal.excerpt
|
254 |
+
|
255 |
+
Returns
|
256 |
+
-------
|
257 |
+
AudioSignal
|
258 |
+
AudioSignal containing excerpt.
|
259 |
+
|
260 |
+
|
261 |
+
.. warning::
|
262 |
+
if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can
|
263 |
+
result in an infinite loop if ``audio_path`` does not have
|
264 |
+
any loud enough excerpts.
|
265 |
+
|
266 |
+
Examples
|
267 |
+
--------
|
268 |
+
>>> signal = AudioSignal.salient_excerpt(
|
269 |
+
"path/to/audio",
|
270 |
+
loudness_cutoff=-40,
|
271 |
+
duration=5
|
272 |
+
)
|
273 |
+
"""
|
274 |
+
state = util.random_state(state)
|
275 |
+
if loudness_cutoff is None:
|
276 |
+
excerpt = cls.excerpt(audio_path, state=state, **kwargs)
|
277 |
+
else:
|
278 |
+
loudness = -np.inf
|
279 |
+
num_try = 0
|
280 |
+
while loudness <= loudness_cutoff:
|
281 |
+
excerpt = cls.excerpt(audio_path, state=state, **kwargs)
|
282 |
+
loudness = excerpt.loudness()
|
283 |
+
num_try += 1
|
284 |
+
if num_tries is not None and num_try >= num_tries:
|
285 |
+
break
|
286 |
+
return excerpt
|
287 |
+
|
288 |
+
@classmethod
|
289 |
+
def zeros(
|
290 |
+
cls,
|
291 |
+
duration: float,
|
292 |
+
sample_rate: int,
|
293 |
+
num_channels: int = 1,
|
294 |
+
batch_size: int = 1,
|
295 |
+
**kwargs,
|
296 |
+
):
|
297 |
+
"""Helper function create an AudioSignal of all zeros.
|
298 |
+
|
299 |
+
Parameters
|
300 |
+
----------
|
301 |
+
duration : float
|
302 |
+
Duration of AudioSignal
|
303 |
+
sample_rate : int
|
304 |
+
Sample rate of AudioSignal
|
305 |
+
num_channels : int, optional
|
306 |
+
Number of channels, by default 1
|
307 |
+
batch_size : int, optional
|
308 |
+
Batch size, by default 1
|
309 |
+
|
310 |
+
Returns
|
311 |
+
-------
|
312 |
+
AudioSignal
|
313 |
+
AudioSignal containing all zeros.
|
314 |
+
|
315 |
+
Examples
|
316 |
+
--------
|
317 |
+
Generate 5 seconds of all zeros at a sample rate of 44100.
|
318 |
+
|
319 |
+
>>> signal = AudioSignal.zeros(5.0, 44100)
|
320 |
+
"""
|
321 |
+
n_samples = int(duration * sample_rate)
|
322 |
+
return cls(
|
323 |
+
torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs
|
324 |
+
)
|
325 |
+
|
326 |
+
@classmethod
|
327 |
+
def wave(
|
328 |
+
cls,
|
329 |
+
frequency: float,
|
330 |
+
duration: float,
|
331 |
+
sample_rate: int,
|
332 |
+
num_channels: int = 1,
|
333 |
+
shape: str = "sine",
|
334 |
+
**kwargs,
|
335 |
+
):
|
336 |
+
"""
|
337 |
+
Generate a waveform of a given frequency and shape.
|
338 |
+
|
339 |
+
Parameters
|
340 |
+
----------
|
341 |
+
frequency : float
|
342 |
+
Frequency of the waveform
|
343 |
+
duration : float
|
344 |
+
Duration of the waveform
|
345 |
+
sample_rate : int
|
346 |
+
Sample rate of the waveform
|
347 |
+
num_channels : int, optional
|
348 |
+
Number of channels, by default 1
|
349 |
+
shape : str, optional
|
350 |
+
Shape of the waveform, by default "saw"
|
351 |
+
One of "sawtooth", "square", "sine", "triangle"
|
352 |
+
kwargs : dict
|
353 |
+
Keyword arguments to AudioSignal
|
354 |
+
"""
|
355 |
+
n_samples = int(duration * sample_rate)
|
356 |
+
t = torch.linspace(0, duration, n_samples)
|
357 |
+
if shape == "sawtooth":
|
358 |
+
from scipy.signal import sawtooth
|
359 |
+
|
360 |
+
wave_data = sawtooth(2 * np.pi * frequency * t, 0.5)
|
361 |
+
elif shape == "square":
|
362 |
+
from scipy.signal import square
|
363 |
+
|
364 |
+
wave_data = square(2 * np.pi * frequency * t)
|
365 |
+
elif shape == "sine":
|
366 |
+
wave_data = np.sin(2 * np.pi * frequency * t)
|
367 |
+
elif shape == "triangle":
|
368 |
+
from scipy.signal import sawtooth
|
369 |
+
|
370 |
+
# frequency is doubled by the abs call, so omit the 2 in 2pi
|
371 |
+
wave_data = sawtooth(np.pi * frequency * t, 0.5)
|
372 |
+
wave_data = -np.abs(wave_data) * 2 + 1
|
373 |
+
else:
|
374 |
+
raise ValueError(f"Invalid shape {shape}")
|
375 |
+
|
376 |
+
wave_data = torch.tensor(wave_data, dtype=torch.float32)
|
377 |
+
wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1)
|
378 |
+
return cls(wave_data, sample_rate, **kwargs)
|
379 |
+
|
380 |
+
@classmethod
|
381 |
+
def batch(
|
382 |
+
cls,
|
383 |
+
audio_signals: list,
|
384 |
+
pad_signals: bool = False,
|
385 |
+
truncate_signals: bool = False,
|
386 |
+
resample: bool = False,
|
387 |
+
dim: int = 0,
|
388 |
+
):
|
389 |
+
"""Creates a batched AudioSignal from a list of AudioSignals.
|
390 |
+
|
391 |
+
Parameters
|
392 |
+
----------
|
393 |
+
audio_signals : list[AudioSignal]
|
394 |
+
List of AudioSignal objects
|
395 |
+
pad_signals : bool, optional
|
396 |
+
Whether to pad signals to length of the maximum length
|
397 |
+
AudioSignal in the list, by default False
|
398 |
+
truncate_signals : bool, optional
|
399 |
+
Whether to truncate signals to length of shortest length
|
400 |
+
AudioSignal in the list, by default False
|
401 |
+
resample : bool, optional
|
402 |
+
Whether to resample AudioSignal to the sample rate of
|
403 |
+
the first AudioSignal in the list, by default False
|
404 |
+
dim : int, optional
|
405 |
+
Dimension along which to batch the signals.
|
406 |
+
|
407 |
+
Returns
|
408 |
+
-------
|
409 |
+
AudioSignal
|
410 |
+
Batched AudioSignal.
|
411 |
+
|
412 |
+
Raises
|
413 |
+
------
|
414 |
+
RuntimeError
|
415 |
+
If not all AudioSignals are the same sample rate, and
|
416 |
+
``resample=False``, an error is raised.
|
417 |
+
RuntimeError
|
418 |
+
If not all AudioSignals are the same the length, and
|
419 |
+
both ``pad_signals=False`` and ``truncate_signals=False``,
|
420 |
+
an error is raised.
|
421 |
+
|
422 |
+
Examples
|
423 |
+
--------
|
424 |
+
Batching a bunch of random signals:
|
425 |
+
|
426 |
+
>>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
|
427 |
+
>>> signal = AudioSignal.batch(signal_list)
|
428 |
+
>>> print(signal.shape)
|
429 |
+
(10, 1, 44100)
|
430 |
+
|
431 |
+
"""
|
432 |
+
signal_lengths = [x.signal_length for x in audio_signals]
|
433 |
+
sample_rates = [x.sample_rate for x in audio_signals]
|
434 |
+
|
435 |
+
if len(set(sample_rates)) != 1:
|
436 |
+
if resample:
|
437 |
+
for x in audio_signals:
|
438 |
+
x.resample(sample_rates[0])
|
439 |
+
else:
|
440 |
+
raise RuntimeError(
|
441 |
+
f"Not all signals had the same sample rate! Got {sample_rates}. "
|
442 |
+
f"All signals must have the same sample rate, or resample must be True. "
|
443 |
+
)
|
444 |
+
|
445 |
+
if len(set(signal_lengths)) != 1:
|
446 |
+
if pad_signals:
|
447 |
+
max_length = max(signal_lengths)
|
448 |
+
for x in audio_signals:
|
449 |
+
pad_len = max_length - x.signal_length
|
450 |
+
x.zero_pad(0, pad_len)
|
451 |
+
elif truncate_signals:
|
452 |
+
min_length = min(signal_lengths)
|
453 |
+
for x in audio_signals:
|
454 |
+
x.truncate_samples(min_length)
|
455 |
+
else:
|
456 |
+
raise RuntimeError(
|
457 |
+
f"Not all signals had the same length! Got {signal_lengths}. "
|
458 |
+
f"All signals must be the same length, or pad_signals/truncate_signals "
|
459 |
+
f"must be True. "
|
460 |
+
)
|
461 |
+
# Concatenate along the specified dimension (default 0)
|
462 |
+
audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim)
|
463 |
+
audio_paths = [x.path_to_file for x in audio_signals]
|
464 |
+
|
465 |
+
batched_signal = cls(
|
466 |
+
audio_data,
|
467 |
+
sample_rate=audio_signals[0].sample_rate,
|
468 |
+
)
|
469 |
+
batched_signal.path_to_file = audio_paths
|
470 |
+
return batched_signal
|
471 |
+
|
472 |
+
# I/O
|
473 |
+
def load_from_file(
|
474 |
+
self,
|
475 |
+
audio_path: typing.Union[str, Path],
|
476 |
+
offset: float,
|
477 |
+
duration: float,
|
478 |
+
device: str = "cpu",
|
479 |
+
):
|
480 |
+
"""Loads data from file. Used internally when AudioSignal
|
481 |
+
is instantiated with a path to a file.
|
482 |
+
|
483 |
+
Parameters
|
484 |
+
----------
|
485 |
+
audio_path : typing.Union[str, Path]
|
486 |
+
Path to file
|
487 |
+
offset : float
|
488 |
+
Offset in seconds
|
489 |
+
duration : float
|
490 |
+
Duration in seconds
|
491 |
+
device : str, optional
|
492 |
+
Device to put AudioSignal on, by default "cpu"
|
493 |
+
|
494 |
+
Returns
|
495 |
+
-------
|
496 |
+
AudioSignal
|
497 |
+
AudioSignal loaded from file
|
498 |
+
"""
|
499 |
+
import librosa
|
500 |
+
|
501 |
+
data, sample_rate = librosa.load(
|
502 |
+
audio_path,
|
503 |
+
offset=offset,
|
504 |
+
duration=duration,
|
505 |
+
sr=None,
|
506 |
+
mono=False,
|
507 |
+
)
|
508 |
+
data = util.ensure_tensor(data)
|
509 |
+
if data.shape[-1] == 0:
|
510 |
+
raise RuntimeError(
|
511 |
+
f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!"
|
512 |
+
)
|
513 |
+
|
514 |
+
if data.ndim < 2:
|
515 |
+
data = data.unsqueeze(0)
|
516 |
+
if data.ndim < 3:
|
517 |
+
data = data.unsqueeze(0)
|
518 |
+
self.audio_data = data
|
519 |
+
|
520 |
+
self.original_signal_length = self.signal_length
|
521 |
+
|
522 |
+
self.sample_rate = sample_rate
|
523 |
+
self.path_to_file = audio_path
|
524 |
+
return self.to(device)
|
525 |
+
|
526 |
+
def load_from_array(
|
527 |
+
self,
|
528 |
+
audio_array: typing.Union[torch.Tensor, np.ndarray],
|
529 |
+
sample_rate: int,
|
530 |
+
device: str = "cpu",
|
531 |
+
):
|
532 |
+
"""Loads data from array, reshaping it to be exactly 3
|
533 |
+
dimensions. Used internally when AudioSignal is called
|
534 |
+
with a tensor or an array.
|
535 |
+
|
536 |
+
Parameters
|
537 |
+
----------
|
538 |
+
audio_array : typing.Union[torch.Tensor, np.ndarray]
|
539 |
+
Array/tensor of audio of samples.
|
540 |
+
sample_rate : int
|
541 |
+
Sample rate of audio
|
542 |
+
device : str, optional
|
543 |
+
Device to move audio onto, by default "cpu"
|
544 |
+
|
545 |
+
Returns
|
546 |
+
-------
|
547 |
+
AudioSignal
|
548 |
+
AudioSignal loaded from array
|
549 |
+
"""
|
550 |
+
audio_data = util.ensure_tensor(audio_array)
|
551 |
+
|
552 |
+
if audio_data.dtype == torch.double:
|
553 |
+
audio_data = audio_data.float()
|
554 |
+
|
555 |
+
if audio_data.ndim < 2:
|
556 |
+
audio_data = audio_data.unsqueeze(0)
|
557 |
+
if audio_data.ndim < 3:
|
558 |
+
audio_data = audio_data.unsqueeze(0)
|
559 |
+
self.audio_data = audio_data
|
560 |
+
|
561 |
+
self.original_signal_length = self.signal_length
|
562 |
+
|
563 |
+
self.sample_rate = sample_rate
|
564 |
+
return self.to(device)
|
565 |
+
|
566 |
+
def write(self, audio_path: typing.Union[str, Path]):
|
567 |
+
"""Writes audio to a file. Only writes the audio
|
568 |
+
that is in the very first item of the batch. To write other items
|
569 |
+
in the batch, index the signal along the batch dimension
|
570 |
+
before writing. After writing, the signal's ``path_to_file``
|
571 |
+
attribute is updated to the new path.
|
572 |
+
|
573 |
+
Parameters
|
574 |
+
----------
|
575 |
+
audio_path : typing.Union[str, Path]
|
576 |
+
Path to write audio to.
|
577 |
+
|
578 |
+
Returns
|
579 |
+
-------
|
580 |
+
AudioSignal
|
581 |
+
Returns original AudioSignal, so you can use this in a fluent
|
582 |
+
interface.
|
583 |
+
|
584 |
+
Examples
|
585 |
+
--------
|
586 |
+
Creating and writing a signal to disk:
|
587 |
+
|
588 |
+
>>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100)
|
589 |
+
>>> signal.write("/tmp/out.wav")
|
590 |
+
|
591 |
+
Writing a different element of the batch:
|
592 |
+
|
593 |
+
>>> signal[5].write("/tmp/out.wav")
|
594 |
+
|
595 |
+
Using this in a fluent interface:
|
596 |
+
|
597 |
+
>>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav")
|
598 |
+
|
599 |
+
"""
|
600 |
+
if self.audio_data[0].abs().max() > 1:
|
601 |
+
warnings.warn("Audio amplitude > 1 clipped when saving")
|
602 |
+
soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate)
|
603 |
+
|
604 |
+
self.path_to_file = audio_path
|
605 |
+
return self
|
606 |
+
|
607 |
+
def deepcopy(self):
|
608 |
+
"""Copies the signal and all of its attributes.
|
609 |
+
|
610 |
+
Returns
|
611 |
+
-------
|
612 |
+
AudioSignal
|
613 |
+
Deep copy of the audio signal.
|
614 |
+
"""
|
615 |
+
return copy.deepcopy(self)
|
616 |
+
|
617 |
+
def copy(self):
|
618 |
+
"""Shallow copy of signal.
|
619 |
+
|
620 |
+
Returns
|
621 |
+
-------
|
622 |
+
AudioSignal
|
623 |
+
Shallow copy of the audio signal.
|
624 |
+
"""
|
625 |
+
return copy.copy(self)
|
626 |
+
|
627 |
+
def clone(self):
|
628 |
+
"""Clones all tensors contained in the AudioSignal,
|
629 |
+
and returns a copy of the signal with everything
|
630 |
+
cloned. Useful when using AudioSignal within autograd
|
631 |
+
computation graphs.
|
632 |
+
|
633 |
+
Relevant attributes are the stft data, the audio data,
|
634 |
+
and the loudness of the file.
|
635 |
+
|
636 |
+
Returns
|
637 |
+
-------
|
638 |
+
AudioSignal
|
639 |
+
Clone of AudioSignal.
|
640 |
+
"""
|
641 |
+
clone = type(self)(
|
642 |
+
self.audio_data.clone(),
|
643 |
+
self.sample_rate,
|
644 |
+
stft_params=self.stft_params,
|
645 |
+
)
|
646 |
+
if self.stft_data is not None:
|
647 |
+
clone.stft_data = self.stft_data.clone()
|
648 |
+
if self._loudness is not None:
|
649 |
+
clone._loudness = self._loudness.clone()
|
650 |
+
clone.path_to_file = copy.deepcopy(self.path_to_file)
|
651 |
+
clone.metadata = copy.deepcopy(self.metadata)
|
652 |
+
return clone
|
653 |
+
|
654 |
+
def detach(self):
|
655 |
+
"""Detaches tensors contained in AudioSignal.
|
656 |
+
|
657 |
+
Relevant attributes are the stft data, the audio data,
|
658 |
+
and the loudness of the file.
|
659 |
+
|
660 |
+
Returns
|
661 |
+
-------
|
662 |
+
AudioSignal
|
663 |
+
Same signal, but with all tensors detached.
|
664 |
+
"""
|
665 |
+
if self._loudness is not None:
|
666 |
+
self._loudness = self._loudness.detach()
|
667 |
+
if self.stft_data is not None:
|
668 |
+
self.stft_data = self.stft_data.detach()
|
669 |
+
|
670 |
+
self.audio_data = self.audio_data.detach()
|
671 |
+
return self
|
672 |
+
|
673 |
+
def hash(self):
|
674 |
+
"""Writes the audio data to a temporary file, and then
|
675 |
+
hashes it using hashlib. Useful for creating a file
|
676 |
+
name based on the audio content.
|
677 |
+
|
678 |
+
Returns
|
679 |
+
-------
|
680 |
+
str
|
681 |
+
Hash of audio data.
|
682 |
+
|
683 |
+
Examples
|
684 |
+
--------
|
685 |
+
Creating a signal, and writing it to a unique file name:
|
686 |
+
|
687 |
+
>>> signal = AudioSignal(torch.randn(44100), 44100)
|
688 |
+
>>> hash = signal.hash()
|
689 |
+
>>> signal.write(f"{hash}.wav")
|
690 |
+
|
691 |
+
"""
|
692 |
+
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
693 |
+
self.write(f.name)
|
694 |
+
h = hashlib.sha256()
|
695 |
+
b = bytearray(128 * 1024)
|
696 |
+
mv = memoryview(b)
|
697 |
+
with open(f.name, "rb", buffering=0) as f:
|
698 |
+
for n in iter(lambda: f.readinto(mv), 0):
|
699 |
+
h.update(mv[:n])
|
700 |
+
file_hash = h.hexdigest()
|
701 |
+
return file_hash
|
702 |
+
|
703 |
+
# Signal operations
|
704 |
+
def to_mono(self):
|
705 |
+
"""Converts audio data to mono audio, by taking the mean
|
706 |
+
along the channels dimension.
|
707 |
+
|
708 |
+
Returns
|
709 |
+
-------
|
710 |
+
AudioSignal
|
711 |
+
AudioSignal with mean of channels.
|
712 |
+
"""
|
713 |
+
self.audio_data = self.audio_data.mean(1, keepdim=True)
|
714 |
+
return self
|
715 |
+
|
716 |
+
def resample(self, sample_rate: int):
|
717 |
+
"""Resamples the audio, using sinc interpolation. This works on both
|
718 |
+
cpu and gpu, and is much faster on gpu.
|
719 |
+
|
720 |
+
Parameters
|
721 |
+
----------
|
722 |
+
sample_rate : int
|
723 |
+
Sample rate to resample to.
|
724 |
+
|
725 |
+
Returns
|
726 |
+
-------
|
727 |
+
AudioSignal
|
728 |
+
Resampled AudioSignal
|
729 |
+
"""
|
730 |
+
if sample_rate == self.sample_rate:
|
731 |
+
return self
|
732 |
+
self.audio_data = julius.resample_frac(
|
733 |
+
self.audio_data, self.sample_rate, sample_rate
|
734 |
+
)
|
735 |
+
self.sample_rate = sample_rate
|
736 |
+
return self
|
737 |
+
|
738 |
+
# Tensor operations
|
739 |
+
def to(self, device: str):
|
740 |
+
"""Moves all tensors contained in signal to the specified device.
|
741 |
+
|
742 |
+
Parameters
|
743 |
+
----------
|
744 |
+
device : str
|
745 |
+
Device to move AudioSignal onto. Typical values are
|
746 |
+
"cuda", "cpu", or "cuda:n" to specify the nth gpu.
|
747 |
+
|
748 |
+
Returns
|
749 |
+
-------
|
750 |
+
AudioSignal
|
751 |
+
AudioSignal with all tensors moved to specified device.
|
752 |
+
"""
|
753 |
+
if self._loudness is not None:
|
754 |
+
self._loudness = self._loudness.to(device)
|
755 |
+
if self.stft_data is not None:
|
756 |
+
self.stft_data = self.stft_data.to(device)
|
757 |
+
if self.audio_data is not None:
|
758 |
+
self.audio_data = self.audio_data.to(device)
|
759 |
+
return self
|
760 |
+
|
761 |
+
def float(self):
|
762 |
+
"""Calls ``.float()`` on ``self.audio_data``.
|
763 |
+
|
764 |
+
Returns
|
765 |
+
-------
|
766 |
+
AudioSignal
|
767 |
+
"""
|
768 |
+
self.audio_data = self.audio_data.float()
|
769 |
+
return self
|
770 |
+
|
771 |
+
def cpu(self):
|
772 |
+
"""Moves AudioSignal to cpu.
|
773 |
+
|
774 |
+
Returns
|
775 |
+
-------
|
776 |
+
AudioSignal
|
777 |
+
"""
|
778 |
+
return self.to("cpu")
|
779 |
+
|
780 |
+
def cuda(self): # pragma: no cover
|
781 |
+
"""Moves AudioSignal to cuda.
|
782 |
+
|
783 |
+
Returns
|
784 |
+
-------
|
785 |
+
AudioSignal
|
786 |
+
"""
|
787 |
+
return self.to("cuda")
|
788 |
+
|
789 |
+
def numpy(self):
|
790 |
+
"""Detaches ``self.audio_data``, moves to cpu, and converts to numpy.
|
791 |
+
|
792 |
+
Returns
|
793 |
+
-------
|
794 |
+
np.ndarray
|
795 |
+
Audio data as a numpy array.
|
796 |
+
"""
|
797 |
+
return self.audio_data.detach().cpu().numpy()
|
798 |
+
|
799 |
+
def zero_pad(self, before: int, after: int):
|
800 |
+
"""Zero pads the audio_data tensor before and after.
|
801 |
+
|
802 |
+
Parameters
|
803 |
+
----------
|
804 |
+
before : int
|
805 |
+
How many zeros to prepend to audio.
|
806 |
+
after : int
|
807 |
+
How many zeros to append to audio.
|
808 |
+
|
809 |
+
Returns
|
810 |
+
-------
|
811 |
+
AudioSignal
|
812 |
+
AudioSignal with padding applied.
|
813 |
+
"""
|
814 |
+
self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after))
|
815 |
+
return self
|
816 |
+
|
817 |
+
def zero_pad_to(self, length: int, mode: str = "after"):
|
818 |
+
"""Pad with zeros to a specified length, either before or after
|
819 |
+
the audio data.
|
820 |
+
|
821 |
+
Parameters
|
822 |
+
----------
|
823 |
+
length : int
|
824 |
+
Length to pad to
|
825 |
+
mode : str, optional
|
826 |
+
Whether to prepend or append zeros to signal, by default "after"
|
827 |
+
|
828 |
+
Returns
|
829 |
+
-------
|
830 |
+
AudioSignal
|
831 |
+
AudioSignal with padding applied.
|
832 |
+
"""
|
833 |
+
if mode == "before":
|
834 |
+
self.zero_pad(max(length - self.signal_length, 0), 0)
|
835 |
+
elif mode == "after":
|
836 |
+
self.zero_pad(0, max(length - self.signal_length, 0))
|
837 |
+
return self
|
838 |
+
|
839 |
+
def trim(self, before: int, after: int):
|
840 |
+
"""Trims the audio_data tensor before and after.
|
841 |
+
|
842 |
+
Parameters
|
843 |
+
----------
|
844 |
+
before : int
|
845 |
+
How many samples to trim from beginning.
|
846 |
+
after : int
|
847 |
+
How many samples to trim from end.
|
848 |
+
|
849 |
+
Returns
|
850 |
+
-------
|
851 |
+
AudioSignal
|
852 |
+
AudioSignal with trimming applied.
|
853 |
+
"""
|
854 |
+
if after == 0:
|
855 |
+
self.audio_data = self.audio_data[..., before:]
|
856 |
+
else:
|
857 |
+
self.audio_data = self.audio_data[..., before:-after]
|
858 |
+
return self
|
859 |
+
|
860 |
+
def truncate_samples(self, length_in_samples: int):
|
861 |
+
"""Truncate signal to specified length.
|
862 |
+
|
863 |
+
Parameters
|
864 |
+
----------
|
865 |
+
length_in_samples : int
|
866 |
+
Truncate to this many samples.
|
867 |
+
|
868 |
+
Returns
|
869 |
+
-------
|
870 |
+
AudioSignal
|
871 |
+
AudioSignal with truncation applied.
|
872 |
+
"""
|
873 |
+
self.audio_data = self.audio_data[..., :length_in_samples]
|
874 |
+
return self
|
875 |
+
|
876 |
+
@property
|
877 |
+
def device(self):
|
878 |
+
"""Get device that AudioSignal is on.
|
879 |
+
|
880 |
+
Returns
|
881 |
+
-------
|
882 |
+
torch.device
|
883 |
+
Device that AudioSignal is on.
|
884 |
+
"""
|
885 |
+
if self.audio_data is not None:
|
886 |
+
device = self.audio_data.device
|
887 |
+
elif self.stft_data is not None:
|
888 |
+
device = self.stft_data.device
|
889 |
+
return device
|
890 |
+
|
891 |
+
# Properties
|
892 |
+
@property
|
893 |
+
def audio_data(self):
|
894 |
+
"""Returns the audio data tensor in the object.
|
895 |
+
|
896 |
+
Audio data is always of the shape
|
897 |
+
(batch_size, num_channels, num_samples). If value has less
|
898 |
+
than 3 dims (e.g. is (num_channels, num_samples)), then it will
|
899 |
+
be reshaped to (1, num_channels, num_samples) - a batch size of 1.
|
900 |
+
|
901 |
+
Parameters
|
902 |
+
----------
|
903 |
+
data : typing.Union[torch.Tensor, np.ndarray]
|
904 |
+
Audio data to set.
|
905 |
+
|
906 |
+
Returns
|
907 |
+
-------
|
908 |
+
torch.Tensor
|
909 |
+
Audio samples.
|
910 |
+
"""
|
911 |
+
return self._audio_data
|
912 |
+
|
913 |
+
@audio_data.setter
|
914 |
+
def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
|
915 |
+
if data is not None:
|
916 |
+
assert torch.is_tensor(data), "audio_data should be torch.Tensor"
|
917 |
+
assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)"
|
918 |
+
self._audio_data = data
|
919 |
+
# Old loudness value not guaranteed to be right, reset it.
|
920 |
+
self._loudness = None
|
921 |
+
return
|
922 |
+
|
923 |
+
# alias for audio_data
|
924 |
+
samples = audio_data
|
925 |
+
|
926 |
+
@property
|
927 |
+
def stft_data(self):
|
928 |
+
"""Returns the STFT data inside the signal. Shape is
|
929 |
+
(batch, channels, frequencies, time).
|
930 |
+
|
931 |
+
Returns
|
932 |
+
-------
|
933 |
+
torch.Tensor
|
934 |
+
Complex spectrogram data.
|
935 |
+
"""
|
936 |
+
return self._stft_data
|
937 |
+
|
938 |
+
@stft_data.setter
|
939 |
+
def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
|
940 |
+
if data is not None:
|
941 |
+
assert torch.is_tensor(data) and torch.is_complex(data)
|
942 |
+
if self.stft_data is not None and self.stft_data.shape != data.shape:
|
943 |
+
warnings.warn("stft_data changed shape")
|
944 |
+
self._stft_data = data
|
945 |
+
return
|
946 |
+
|
947 |
+
@property
|
948 |
+
def batch_size(self):
|
949 |
+
"""Batch size of audio signal.
|
950 |
+
|
951 |
+
Returns
|
952 |
+
-------
|
953 |
+
int
|
954 |
+
Batch size of signal.
|
955 |
+
"""
|
956 |
+
return self.audio_data.shape[0]
|
957 |
+
|
958 |
+
@property
|
959 |
+
def signal_length(self):
|
960 |
+
"""Length of audio signal.
|
961 |
+
|
962 |
+
Returns
|
963 |
+
-------
|
964 |
+
int
|
965 |
+
Length of signal in samples.
|
966 |
+
"""
|
967 |
+
return self.audio_data.shape[-1]
|
968 |
+
|
969 |
+
# alias for signal_length
|
970 |
+
length = signal_length
|
971 |
+
|
972 |
+
@property
|
973 |
+
def shape(self):
|
974 |
+
"""Shape of audio data.
|
975 |
+
|
976 |
+
Returns
|
977 |
+
-------
|
978 |
+
tuple
|
979 |
+
Shape of audio data.
|
980 |
+
"""
|
981 |
+
return self.audio_data.shape
|
982 |
+
|
983 |
+
@property
|
984 |
+
def signal_duration(self):
|
985 |
+
"""Length of audio signal in seconds.
|
986 |
+
|
987 |
+
Returns
|
988 |
+
-------
|
989 |
+
float
|
990 |
+
Length of signal in seconds.
|
991 |
+
"""
|
992 |
+
return self.signal_length / self.sample_rate
|
993 |
+
|
994 |
+
# alias for signal_duration
|
995 |
+
duration = signal_duration
|
996 |
+
|
997 |
+
@property
|
998 |
+
def num_channels(self):
|
999 |
+
"""Number of audio channels.
|
1000 |
+
|
1001 |
+
Returns
|
1002 |
+
-------
|
1003 |
+
int
|
1004 |
+
Number of audio channels.
|
1005 |
+
"""
|
1006 |
+
return self.audio_data.shape[1]
|
1007 |
+
|
1008 |
+
# STFT
|
1009 |
+
@staticmethod
|
1010 |
+
@functools.lru_cache(None)
|
1011 |
+
def get_window(window_type: str, window_length: int, device: str):
|
1012 |
+
"""Wrapper around scipy.signal.get_window so one can also get the
|
1013 |
+
popular sqrt-hann window. This function caches for efficiency
|
1014 |
+
using functools.lru\_cache.
|
1015 |
+
|
1016 |
+
Parameters
|
1017 |
+
----------
|
1018 |
+
window_type : str
|
1019 |
+
Type of window to get
|
1020 |
+
window_length : int
|
1021 |
+
Length of the window
|
1022 |
+
device : str
|
1023 |
+
Device to put window onto.
|
1024 |
+
|
1025 |
+
Returns
|
1026 |
+
-------
|
1027 |
+
torch.Tensor
|
1028 |
+
Window returned by scipy.signal.get_window, as a tensor.
|
1029 |
+
"""
|
1030 |
+
from scipy import signal
|
1031 |
+
|
1032 |
+
if window_type == "average":
|
1033 |
+
window = np.ones(window_length) / window_length
|
1034 |
+
elif window_type == "sqrt_hann":
|
1035 |
+
window = np.sqrt(signal.get_window("hann", window_length))
|
1036 |
+
else:
|
1037 |
+
window = signal.get_window(window_type, window_length)
|
1038 |
+
window = torch.from_numpy(window).to(device).float()
|
1039 |
+
return window
|
1040 |
+
|
1041 |
+
@property
|
1042 |
+
def stft_params(self):
|
1043 |
+
"""Returns STFTParams object, which can be re-used to other
|
1044 |
+
AudioSignals.
|
1045 |
+
|
1046 |
+
This property can be set as well. If values are not defined in STFTParams,
|
1047 |
+
they are inferred automatically from the signal properties. The default is to use
|
1048 |
+
32ms windows, with 8ms hop length, and the square root of the hann window.
|
1049 |
+
|
1050 |
+
Returns
|
1051 |
+
-------
|
1052 |
+
STFTParams
|
1053 |
+
STFT parameters for the AudioSignal.
|
1054 |
+
|
1055 |
+
Examples
|
1056 |
+
--------
|
1057 |
+
>>> stft_params = STFTParams(128, 32)
|
1058 |
+
>>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)
|
1059 |
+
>>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params)
|
1060 |
+
>>> signal1.stft_params = STFTParams() # Defaults
|
1061 |
+
"""
|
1062 |
+
return self._stft_params
|
1063 |
+
|
1064 |
+
@stft_params.setter
|
1065 |
+
def stft_params(self, value: STFTParams):
|
1066 |
+
default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate))))
|
1067 |
+
default_hop_len = default_win_len // 4
|
1068 |
+
default_win_type = "hann"
|
1069 |
+
default_match_stride = False
|
1070 |
+
default_padding_type = "reflect"
|
1071 |
+
|
1072 |
+
default_stft_params = STFTParams(
|
1073 |
+
window_length=default_win_len,
|
1074 |
+
hop_length=default_hop_len,
|
1075 |
+
window_type=default_win_type,
|
1076 |
+
match_stride=default_match_stride,
|
1077 |
+
padding_type=default_padding_type,
|
1078 |
+
)._asdict()
|
1079 |
+
|
1080 |
+
value = value._asdict() if value else default_stft_params
|
1081 |
+
|
1082 |
+
for key in default_stft_params:
|
1083 |
+
if value[key] is None:
|
1084 |
+
value[key] = default_stft_params[key]
|
1085 |
+
|
1086 |
+
self._stft_params = STFTParams(**value)
|
1087 |
+
self.stft_data = None
|
1088 |
+
|
1089 |
+
def compute_stft_padding(
|
1090 |
+
self, window_length: int, hop_length: int, match_stride: bool
|
1091 |
+
):
|
1092 |
+
"""Compute how the STFT should be padded, based on match\_stride.
|
1093 |
+
|
1094 |
+
Parameters
|
1095 |
+
----------
|
1096 |
+
window_length : int
|
1097 |
+
Window length of STFT.
|
1098 |
+
hop_length : int
|
1099 |
+
Hop length of STFT.
|
1100 |
+
match_stride : bool
|
1101 |
+
Whether or not to match stride, making the STFT have the same alignment as
|
1102 |
+
convolutional layers.
|
1103 |
+
|
1104 |
+
Returns
|
1105 |
+
-------
|
1106 |
+
tuple
|
1107 |
+
Amount to pad on either side of audio.
|
1108 |
+
"""
|
1109 |
+
length = self.signal_length
|
1110 |
+
|
1111 |
+
if match_stride:
|
1112 |
+
assert (
|
1113 |
+
hop_length == window_length // 4
|
1114 |
+
), "For match_stride, hop must equal n_fft // 4"
|
1115 |
+
right_pad = math.ceil(length / hop_length) * hop_length - length
|
1116 |
+
pad = (window_length - hop_length) // 2
|
1117 |
+
else:
|
1118 |
+
right_pad = 0
|
1119 |
+
pad = 0
|
1120 |
+
|
1121 |
+
return right_pad, pad
|
1122 |
+
|
1123 |
+
def stft(
|
1124 |
+
self,
|
1125 |
+
window_length: int = None,
|
1126 |
+
hop_length: int = None,
|
1127 |
+
window_type: str = None,
|
1128 |
+
match_stride: bool = None,
|
1129 |
+
padding_type: str = None,
|
1130 |
+
):
|
1131 |
+
"""Computes the short-time Fourier transform of the audio data,
|
1132 |
+
with specified STFT parameters.
|
1133 |
+
|
1134 |
+
Parameters
|
1135 |
+
----------
|
1136 |
+
window_length : int, optional
|
1137 |
+
Window length of STFT, by default ``0.032 * self.sample_rate``.
|
1138 |
+
hop_length : int, optional
|
1139 |
+
Hop length of STFT, by default ``window_length // 4``.
|
1140 |
+
window_type : str, optional
|
1141 |
+
Type of window to use, by default ``sqrt\_hann``.
|
1142 |
+
match_stride : bool, optional
|
1143 |
+
Whether to match the stride of convolutional layers, by default False
|
1144 |
+
padding_type : str, optional
|
1145 |
+
Type of padding to use, by default 'reflect'
|
1146 |
+
|
1147 |
+
Returns
|
1148 |
+
-------
|
1149 |
+
torch.Tensor
|
1150 |
+
STFT of audio data.
|
1151 |
+
|
1152 |
+
Examples
|
1153 |
+
--------
|
1154 |
+
Compute the STFT of an AudioSignal:
|
1155 |
+
|
1156 |
+
>>> signal = AudioSignal(torch.randn(44100), 44100)
|
1157 |
+
>>> signal.stft()
|
1158 |
+
|
1159 |
+
Vary the window and hop length:
|
1160 |
+
|
1161 |
+
>>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
|
1162 |
+
>>> for stft_param in stft_params:
|
1163 |
+
>>> signal.stft_params = stft_params
|
1164 |
+
>>> signal.stft()
|
1165 |
+
|
1166 |
+
"""
|
1167 |
+
window_length = (
|
1168 |
+
self.stft_params.window_length
|
1169 |
+
if window_length is None
|
1170 |
+
else int(window_length)
|
1171 |
+
)
|
1172 |
+
hop_length = (
|
1173 |
+
self.stft_params.hop_length if hop_length is None else int(hop_length)
|
1174 |
+
)
|
1175 |
+
window_type = (
|
1176 |
+
self.stft_params.window_type if window_type is None else window_type
|
1177 |
+
)
|
1178 |
+
match_stride = (
|
1179 |
+
self.stft_params.match_stride if match_stride is None else match_stride
|
1180 |
+
)
|
1181 |
+
padding_type = (
|
1182 |
+
self.stft_params.padding_type if padding_type is None else padding_type
|
1183 |
+
)
|
1184 |
+
|
1185 |
+
window = self.get_window(window_type, window_length, self.audio_data.device)
|
1186 |
+
window = window.to(self.audio_data.device)
|
1187 |
+
|
1188 |
+
audio_data = self.audio_data
|
1189 |
+
right_pad, pad = self.compute_stft_padding(
|
1190 |
+
window_length, hop_length, match_stride
|
1191 |
+
)
|
1192 |
+
audio_data = torch.nn.functional.pad(
|
1193 |
+
audio_data, (pad, pad + right_pad), padding_type
|
1194 |
+
)
|
1195 |
+
stft_data = torch.stft(
|
1196 |
+
audio_data.reshape(-1, audio_data.shape[-1]),
|
1197 |
+
n_fft=window_length,
|
1198 |
+
hop_length=hop_length,
|
1199 |
+
window=window,
|
1200 |
+
return_complex=True,
|
1201 |
+
center=True,
|
1202 |
+
)
|
1203 |
+
_, nf, nt = stft_data.shape
|
1204 |
+
stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)
|
1205 |
+
|
1206 |
+
if match_stride:
|
1207 |
+
# Drop first two and last two frames, which are added
|
1208 |
+
# because of padding. Now num_frames * hop_length = num_samples.
|
1209 |
+
stft_data = stft_data[..., 2:-2]
|
1210 |
+
self.stft_data = stft_data
|
1211 |
+
|
1212 |
+
return stft_data
|
1213 |
+
|
1214 |
+
def istft(
|
1215 |
+
self,
|
1216 |
+
window_length: int = None,
|
1217 |
+
hop_length: int = None,
|
1218 |
+
window_type: str = None,
|
1219 |
+
match_stride: bool = None,
|
1220 |
+
length: int = None,
|
1221 |
+
):
|
1222 |
+
"""Computes inverse STFT and sets it to audio\_data.
|
1223 |
+
|
1224 |
+
Parameters
|
1225 |
+
----------
|
1226 |
+
window_length : int, optional
|
1227 |
+
Window length of STFT, by default ``0.032 * self.sample_rate``.
|
1228 |
+
hop_length : int, optional
|
1229 |
+
Hop length of STFT, by default ``window_length // 4``.
|
1230 |
+
window_type : str, optional
|
1231 |
+
Type of window to use, by default ``sqrt\_hann``.
|
1232 |
+
match_stride : bool, optional
|
1233 |
+
Whether to match the stride of convolutional layers, by default False
|
1234 |
+
length : int, optional
|
1235 |
+
Original length of signal, by default None
|
1236 |
+
|
1237 |
+
Returns
|
1238 |
+
-------
|
1239 |
+
AudioSignal
|
1240 |
+
AudioSignal with istft applied.
|
1241 |
+
|
1242 |
+
Raises
|
1243 |
+
------
|
1244 |
+
RuntimeError
|
1245 |
+
Raises an error if stft was not called prior to istft on the signal,
|
1246 |
+
or if stft_data is not set.
|
1247 |
+
"""
|
1248 |
+
if self.stft_data is None:
|
1249 |
+
raise RuntimeError("Cannot do inverse STFT without self.stft_data!")
|
1250 |
+
|
1251 |
+
window_length = (
|
1252 |
+
self.stft_params.window_length
|
1253 |
+
if window_length is None
|
1254 |
+
else int(window_length)
|
1255 |
+
)
|
1256 |
+
hop_length = (
|
1257 |
+
self.stft_params.hop_length if hop_length is None else int(hop_length)
|
1258 |
+
)
|
1259 |
+
window_type = (
|
1260 |
+
self.stft_params.window_type if window_type is None else window_type
|
1261 |
+
)
|
1262 |
+
match_stride = (
|
1263 |
+
self.stft_params.match_stride if match_stride is None else match_stride
|
1264 |
+
)
|
1265 |
+
|
1266 |
+
window = self.get_window(window_type, window_length, self.stft_data.device)
|
1267 |
+
|
1268 |
+
nb, nch, nf, nt = self.stft_data.shape
|
1269 |
+
stft_data = self.stft_data.reshape(nb * nch, nf, nt)
|
1270 |
+
right_pad, pad = self.compute_stft_padding(
|
1271 |
+
window_length, hop_length, match_stride
|
1272 |
+
)
|
1273 |
+
|
1274 |
+
if length is None:
|
1275 |
+
length = self.original_signal_length
|
1276 |
+
length = length + 2 * pad + right_pad
|
1277 |
+
|
1278 |
+
if match_stride:
|
1279 |
+
# Zero-pad the STFT on either side, putting back the frames that were
|
1280 |
+
# dropped in stft().
|
1281 |
+
stft_data = torch.nn.functional.pad(stft_data, (2, 2))
|
1282 |
+
|
1283 |
+
audio_data = torch.istft(
|
1284 |
+
stft_data,
|
1285 |
+
n_fft=window_length,
|
1286 |
+
hop_length=hop_length,
|
1287 |
+
window=window,
|
1288 |
+
length=length,
|
1289 |
+
center=True,
|
1290 |
+
)
|
1291 |
+
audio_data = audio_data.reshape(nb, nch, -1)
|
1292 |
+
if match_stride:
|
1293 |
+
audio_data = audio_data[..., pad : -(pad + right_pad)]
|
1294 |
+
self.audio_data = audio_data
|
1295 |
+
|
1296 |
+
return self
|
1297 |
+
|
1298 |
+
@staticmethod
|
1299 |
+
@functools.lru_cache(None)
|
1300 |
+
def get_mel_filters(
|
1301 |
+
sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None
|
1302 |
+
):
|
1303 |
+
"""Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.
|
1304 |
+
|
1305 |
+
Parameters
|
1306 |
+
----------
|
1307 |
+
sr : int
|
1308 |
+
Sample rate of audio
|
1309 |
+
n_fft : int
|
1310 |
+
Number of FFT bins
|
1311 |
+
n_mels : int
|
1312 |
+
Number of mels
|
1313 |
+
fmin : float, optional
|
1314 |
+
Lowest frequency, in Hz, by default 0.0
|
1315 |
+
fmax : float, optional
|
1316 |
+
Highest frequency, by default None
|
1317 |
+
|
1318 |
+
Returns
|
1319 |
+
-------
|
1320 |
+
np.ndarray [shape=(n_mels, 1 + n_fft/2)]
|
1321 |
+
Mel transform matrix
|
1322 |
+
"""
|
1323 |
+
from librosa.filters import mel as librosa_mel_fn
|
1324 |
+
|
1325 |
+
return librosa_mel_fn(
|
1326 |
+
sr=sr,
|
1327 |
+
n_fft=n_fft,
|
1328 |
+
n_mels=n_mels,
|
1329 |
+
fmin=fmin,
|
1330 |
+
fmax=fmax,
|
1331 |
+
)
|
1332 |
+
|
1333 |
+
def mel_spectrogram(
|
1334 |
+
self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs
|
1335 |
+
):
|
1336 |
+
"""Computes a Mel spectrogram.
|
1337 |
+
|
1338 |
+
Parameters
|
1339 |
+
----------
|
1340 |
+
n_mels : int, optional
|
1341 |
+
Number of mels, by default 80
|
1342 |
+
mel_fmin : float, optional
|
1343 |
+
Lowest frequency, in Hz, by default 0.0
|
1344 |
+
mel_fmax : float, optional
|
1345 |
+
Highest frequency, by default None
|
1346 |
+
kwargs : dict, optional
|
1347 |
+
Keyword arguments to self.stft().
|
1348 |
+
|
1349 |
+
Returns
|
1350 |
+
-------
|
1351 |
+
torch.Tensor [shape=(batch, channels, mels, time)]
|
1352 |
+
Mel spectrogram.
|
1353 |
+
"""
|
1354 |
+
stft = self.stft(**kwargs)
|
1355 |
+
magnitude = torch.abs(stft)
|
1356 |
+
|
1357 |
+
nf = magnitude.shape[2]
|
1358 |
+
mel_basis = self.get_mel_filters(
|
1359 |
+
sr=self.sample_rate,
|
1360 |
+
n_fft=2 * (nf - 1),
|
1361 |
+
n_mels=n_mels,
|
1362 |
+
fmin=mel_fmin,
|
1363 |
+
fmax=mel_fmax,
|
1364 |
+
)
|
1365 |
+
mel_basis = torch.from_numpy(mel_basis).to(self.device)
|
1366 |
+
|
1367 |
+
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
|
1368 |
+
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
|
1369 |
+
return mel_spectrogram
|
1370 |
+
|
1371 |
+
@staticmethod
|
1372 |
+
@functools.lru_cache(None)
|
1373 |
+
def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None):
|
1374 |
+
"""Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``),
|
1375 |
+
it can be normalized depending on norm. For more information about dct:
|
1376 |
+
http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
|
1377 |
+
|
1378 |
+
Parameters
|
1379 |
+
----------
|
1380 |
+
n_mfcc : int
|
1381 |
+
Number of mfccs
|
1382 |
+
n_mels : int
|
1383 |
+
Number of mels
|
1384 |
+
norm : str
|
1385 |
+
Use "ortho" to get a orthogonal matrix or None, by default "ortho"
|
1386 |
+
device : str, optional
|
1387 |
+
Device to load the transformation matrix on, by default None
|
1388 |
+
|
1389 |
+
Returns
|
1390 |
+
-------
|
1391 |
+
torch.Tensor [shape=(n_mels, n_mfcc)] T
|
1392 |
+
The dct transformation matrix.
|
1393 |
+
"""
|
1394 |
+
from torchaudio.functional import create_dct
|
1395 |
+
|
1396 |
+
return create_dct(n_mfcc, n_mels, norm).to(device)
|
1397 |
+
|
1398 |
+
def mfcc(
|
1399 |
+
self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
|
1400 |
+
):
|
1401 |
+
"""Computes mel-frequency cepstral coefficients (MFCCs).
|
1402 |
+
|
1403 |
+
Parameters
|
1404 |
+
----------
|
1405 |
+
n_mfcc : int, optional
|
1406 |
+
Number of mels, by default 40
|
1407 |
+
n_mels : int, optional
|
1408 |
+
Number of mels, by default 80
|
1409 |
+
log_offset: float, optional
|
1410 |
+
Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
|
1411 |
+
kwargs : dict, optional
|
1412 |
+
Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft()
|
1413 |
+
|
1414 |
+
Returns
|
1415 |
+
-------
|
1416 |
+
torch.Tensor [shape=(batch, channels, mfccs, time)]
|
1417 |
+
MFCCs.
|
1418 |
+
"""
|
1419 |
+
|
1420 |
+
mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs)
|
1421 |
+
mel_spectrogram = torch.log(mel_spectrogram + log_offset)
|
1422 |
+
dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device)
|
1423 |
+
|
1424 |
+
mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat
|
1425 |
+
mfcc = mfcc.transpose(-1, -2)
|
1426 |
+
return mfcc
|
1427 |
+
|
1428 |
+
@property
|
1429 |
+
def magnitude(self):
|
1430 |
+
"""Computes and returns the absolute value of the STFT, which
|
1431 |
+
is the magnitude. This value can also be set to some tensor.
|
1432 |
+
When set, ``self.stft_data`` is manipulated so that its magnitude
|
1433 |
+
matches what this is set to, and modulated by the phase.
|
1434 |
+
|
1435 |
+
Returns
|
1436 |
+
-------
|
1437 |
+
torch.Tensor
|
1438 |
+
Magnitude of STFT.
|
1439 |
+
|
1440 |
+
Examples
|
1441 |
+
--------
|
1442 |
+
>>> signal = AudioSignal(torch.randn(44100), 44100)
|
1443 |
+
>>> magnitude = signal.magnitude # Computes stft if not computed
|
1444 |
+
>>> magnitude[magnitude < magnitude.mean()] = 0
|
1445 |
+
>>> signal.magnitude = magnitude
|
1446 |
+
>>> signal.istft()
|
1447 |
+
"""
|
1448 |
+
if self.stft_data is None:
|
1449 |
+
self.stft()
|
1450 |
+
return torch.abs(self.stft_data)
|
1451 |
+
|
1452 |
+
@magnitude.setter
|
1453 |
+
def magnitude(self, value):
|
1454 |
+
self.stft_data = value * torch.exp(1j * self.phase)
|
1455 |
+
return
|
1456 |
+
|
1457 |
+
def log_magnitude(
|
1458 |
+
self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
|
1459 |
+
):
|
1460 |
+
"""Computes the log-magnitude of the spectrogram.
|
1461 |
+
|
1462 |
+
Parameters
|
1463 |
+
----------
|
1464 |
+
ref_value : float, optional
|
1465 |
+
The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
|
1466 |
+
Zeros in the output correspond to positions where ``S == ref``,
|
1467 |
+
by default 1.0
|
1468 |
+
amin : float, optional
|
1469 |
+
Minimum threshold for ``S`` and ``ref``, by default 1e-5
|
1470 |
+
top_db : float, optional
|
1471 |
+
Threshold the output at ``top_db`` below the peak:
|
1472 |
+
``max(10 * log10(S/ref)) - top_db``, by default -80.0
|
1473 |
+
|
1474 |
+
Returns
|
1475 |
+
-------
|
1476 |
+
torch.Tensor
|
1477 |
+
Log-magnitude spectrogram
|
1478 |
+
"""
|
1479 |
+
magnitude = self.magnitude
|
1480 |
+
|
1481 |
+
amin = amin**2
|
1482 |
+
log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin))
|
1483 |
+
log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
|
1484 |
+
|
1485 |
+
if top_db is not None:
|
1486 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
|
1487 |
+
return log_spec
|
1488 |
+
|
1489 |
+
@property
|
1490 |
+
def phase(self):
|
1491 |
+
"""Computes and returns the phase of the STFT.
|
1492 |
+
This value can also be set to some tensor.
|
1493 |
+
When set, ``self.stft_data`` is manipulated so that its phase
|
1494 |
+
matches what this is set to, we original magnitudeith th.
|
1495 |
+
|
1496 |
+
Returns
|
1497 |
+
-------
|
1498 |
+
torch.Tensor
|
1499 |
+
Phase of STFT.
|
1500 |
+
|
1501 |
+
Examples
|
1502 |
+
--------
|
1503 |
+
>>> signal = AudioSignal(torch.randn(44100), 44100)
|
1504 |
+
>>> phase = signal.phase # Computes stft if not computed
|
1505 |
+
>>> phase[phase < phase.mean()] = 0
|
1506 |
+
>>> signal.phase = phase
|
1507 |
+
>>> signal.istft()
|
1508 |
+
"""
|
1509 |
+
if self.stft_data is None:
|
1510 |
+
self.stft()
|
1511 |
+
return torch.angle(self.stft_data)
|
1512 |
+
|
1513 |
+
@phase.setter
|
1514 |
+
def phase(self, value):
|
1515 |
+
self.stft_data = self.magnitude * torch.exp(1j * value)
|
1516 |
+
return
|
1517 |
+
|
1518 |
+
# Operator overloading
|
1519 |
+
def __add__(self, other):
|
1520 |
+
new_signal = self.clone()
|
1521 |
+
new_signal.audio_data += util._get_value(other)
|
1522 |
+
return new_signal
|
1523 |
+
|
1524 |
+
def __iadd__(self, other):
|
1525 |
+
self.audio_data += util._get_value(other)
|
1526 |
+
return self
|
1527 |
+
|
1528 |
+
def __radd__(self, other):
|
1529 |
+
return self + other
|
1530 |
+
|
1531 |
+
def __sub__(self, other):
|
1532 |
+
new_signal = self.clone()
|
1533 |
+
new_signal.audio_data -= util._get_value(other)
|
1534 |
+
return new_signal
|
1535 |
+
|
1536 |
+
def __isub__(self, other):
|
1537 |
+
self.audio_data -= util._get_value(other)
|
1538 |
+
return self
|
1539 |
+
|
1540 |
+
def __mul__(self, other):
|
1541 |
+
new_signal = self.clone()
|
1542 |
+
new_signal.audio_data *= util._get_value(other)
|
1543 |
+
return new_signal
|
1544 |
+
|
1545 |
+
def __imul__(self, other):
|
1546 |
+
self.audio_data *= util._get_value(other)
|
1547 |
+
return self
|
1548 |
+
|
1549 |
+
def __rmul__(self, other):
|
1550 |
+
return self * other
|
1551 |
+
|
1552 |
+
# Representation
|
1553 |
+
def _info(self):
|
1554 |
+
dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]"
|
1555 |
+
info = {
|
1556 |
+
"duration": f"{dur} seconds",
|
1557 |
+
"batch_size": self.batch_size,
|
1558 |
+
"path": self.path_to_file if self.path_to_file else "path unknown",
|
1559 |
+
"sample_rate": self.sample_rate,
|
1560 |
+
"num_channels": self.num_channels if self.num_channels else "[unknown]",
|
1561 |
+
"audio_data.shape": self.audio_data.shape,
|
1562 |
+
"stft_params": self.stft_params,
|
1563 |
+
"device": self.device,
|
1564 |
+
}
|
1565 |
+
|
1566 |
+
return info
|
1567 |
+
|
1568 |
+
def markdown(self):
|
1569 |
+
"""Produces a markdown representation of AudioSignal, in a markdown table.
|
1570 |
+
|
1571 |
+
Returns
|
1572 |
+
-------
|
1573 |
+
str
|
1574 |
+
Markdown representation of AudioSignal.
|
1575 |
+
|
1576 |
+
Examples
|
1577 |
+
--------
|
1578 |
+
>>> signal = AudioSignal(torch.randn(44100), 44100)
|
1579 |
+
>>> print(signal.markdown())
|
1580 |
+
| Key | Value
|
1581 |
+
|---|---
|
1582 |
+
| duration | 1.000 seconds |
|
1583 |
+
| batch_size | 1 |
|
1584 |
+
| path | path unknown |
|
1585 |
+
| sample_rate | 44100 |
|
1586 |
+
| num_channels | 1 |
|
1587 |
+
| audio_data.shape | torch.Size([1, 1, 44100]) |
|
1588 |
+
| stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False) |
|
1589 |
+
| device | cpu |
|
1590 |
+
"""
|
1591 |
+
info = self._info()
|
1592 |
+
|
1593 |
+
FORMAT = "| Key | Value \n" "|---|--- \n"
|
1594 |
+
for k, v in info.items():
|
1595 |
+
row = f"| {k} | {v} |\n"
|
1596 |
+
FORMAT += row
|
1597 |
+
return FORMAT
|
1598 |
+
|
1599 |
+
def __str__(self):
|
1600 |
+
info = self._info()
|
1601 |
+
|
1602 |
+
desc = ""
|
1603 |
+
for k, v in info.items():
|
1604 |
+
desc += f"{k}: {v}\n"
|
1605 |
+
return desc
|
1606 |
+
|
1607 |
+
def __rich__(self):
|
1608 |
+
from rich.table import Table
|
1609 |
+
|
1610 |
+
info = self._info()
|
1611 |
+
|
1612 |
+
table = Table(title=f"{self.__class__.__name__}")
|
1613 |
+
table.add_column("Key", style="green")
|
1614 |
+
table.add_column("Value", style="cyan")
|
1615 |
+
|
1616 |
+
for k, v in info.items():
|
1617 |
+
table.add_row(k, str(v))
|
1618 |
+
return table
|
1619 |
+
|
1620 |
+
# Comparison
|
1621 |
+
def __eq__(self, other):
|
1622 |
+
for k, v in list(self.__dict__.items()):
|
1623 |
+
if torch.is_tensor(v):
|
1624 |
+
if not torch.allclose(v, other.__dict__[k], atol=1e-6):
|
1625 |
+
max_error = (v - other.__dict__[k]).abs().max()
|
1626 |
+
print(f"Max abs error for {k}: {max_error}")
|
1627 |
+
return False
|
1628 |
+
return True
|
1629 |
+
|
1630 |
+
# Indexing
|
1631 |
+
def __getitem__(self, key):
|
1632 |
+
if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
|
1633 |
+
assert self.batch_size == 1
|
1634 |
+
audio_data = self.audio_data
|
1635 |
+
_loudness = self._loudness
|
1636 |
+
stft_data = self.stft_data
|
1637 |
+
|
1638 |
+
elif isinstance(key, (bool, int, list, slice, tuple)) or (
|
1639 |
+
torch.is_tensor(key) and key.ndim <= 1
|
1640 |
+
):
|
1641 |
+
# Indexing only on the batch dimension.
|
1642 |
+
# Then let's copy over relevant stuff.
|
1643 |
+
# Future work: make this work for time-indexing
|
1644 |
+
# as well, using the hop length.
|
1645 |
+
audio_data = self.audio_data[key]
|
1646 |
+
_loudness = self._loudness[key] if self._loudness is not None else None
|
1647 |
+
stft_data = self.stft_data[key] if self.stft_data is not None else None
|
1648 |
+
|
1649 |
+
sources = None
|
1650 |
+
|
1651 |
+
copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params)
|
1652 |
+
copy._loudness = _loudness
|
1653 |
+
copy._stft_data = stft_data
|
1654 |
+
copy.sources = sources
|
1655 |
+
|
1656 |
+
return copy
|
1657 |
+
|
1658 |
+
def __setitem__(self, key, value):
|
1659 |
+
if not isinstance(value, type(self)):
|
1660 |
+
self.audio_data[key] = value
|
1661 |
+
return
|
1662 |
+
|
1663 |
+
if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
|
1664 |
+
assert self.batch_size == 1
|
1665 |
+
self.audio_data = value.audio_data
|
1666 |
+
self._loudness = value._loudness
|
1667 |
+
self.stft_data = value.stft_data
|
1668 |
+
return
|
1669 |
+
|
1670 |
+
elif isinstance(key, (bool, int, list, slice, tuple)) or (
|
1671 |
+
torch.is_tensor(key) and key.ndim <= 1
|
1672 |
+
):
|
1673 |
+
if self.audio_data is not None and value.audio_data is not None:
|
1674 |
+
self.audio_data[key] = value.audio_data
|
1675 |
+
if self._loudness is not None and value._loudness is not None:
|
1676 |
+
self._loudness[key] = value._loudness
|
1677 |
+
if self.stft_data is not None and value.stft_data is not None:
|
1678 |
+
self.stft_data[key] = value.stft_data
|
1679 |
+
return
|
1680 |
+
|
1681 |
+
def __ne__(self, other):
|
1682 |
+
return not self == other
|
audiotools/core/display.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import typing
|
3 |
+
from functools import wraps
|
4 |
+
|
5 |
+
from . import util
|
6 |
+
|
7 |
+
|
8 |
+
def format_figure(func):
|
9 |
+
"""Decorator for formatting figures produced by the code below.
|
10 |
+
See :py:func:`audiotools.core.util.format_figure` for more.
|
11 |
+
|
12 |
+
Parameters
|
13 |
+
----------
|
14 |
+
func : Callable
|
15 |
+
Plotting function that is decorated by this function.
|
16 |
+
|
17 |
+
"""
|
18 |
+
|
19 |
+
@wraps(func)
|
20 |
+
def wrapper(*args, **kwargs):
|
21 |
+
f_keys = inspect.signature(util.format_figure).parameters.keys()
|
22 |
+
f_kwargs = {}
|
23 |
+
for k, v in list(kwargs.items()):
|
24 |
+
if k in f_keys:
|
25 |
+
kwargs.pop(k)
|
26 |
+
f_kwargs[k] = v
|
27 |
+
func(*args, **kwargs)
|
28 |
+
util.format_figure(**f_kwargs)
|
29 |
+
|
30 |
+
return wrapper
|
31 |
+
|
32 |
+
|
33 |
+
class DisplayMixin:
|
34 |
+
@format_figure
|
35 |
+
def specshow(
|
36 |
+
self,
|
37 |
+
preemphasis: bool = False,
|
38 |
+
x_axis: str = "time",
|
39 |
+
y_axis: str = "linear",
|
40 |
+
n_mels: int = 128,
|
41 |
+
**kwargs,
|
42 |
+
):
|
43 |
+
"""Displays a spectrogram, using ``librosa.display.specshow``.
|
44 |
+
|
45 |
+
Parameters
|
46 |
+
----------
|
47 |
+
preemphasis : bool, optional
|
48 |
+
Whether or not to apply preemphasis, which makes high
|
49 |
+
frequency detail easier to see, by default False
|
50 |
+
x_axis : str, optional
|
51 |
+
How to label the x axis, by default "time"
|
52 |
+
y_axis : str, optional
|
53 |
+
How to label the y axis, by default "linear"
|
54 |
+
n_mels : int, optional
|
55 |
+
If displaying a mel spectrogram with ``y_axis = "mel"``,
|
56 |
+
this controls the number of mels, by default 128.
|
57 |
+
kwargs : dict, optional
|
58 |
+
Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
|
59 |
+
"""
|
60 |
+
import librosa
|
61 |
+
import librosa.display
|
62 |
+
|
63 |
+
# Always re-compute the STFT data before showing it, in case
|
64 |
+
# it changed.
|
65 |
+
signal = self.clone()
|
66 |
+
signal.stft_data = None
|
67 |
+
|
68 |
+
if preemphasis:
|
69 |
+
signal.preemphasis()
|
70 |
+
|
71 |
+
ref = signal.magnitude.max()
|
72 |
+
log_mag = signal.log_magnitude(ref_value=ref)
|
73 |
+
|
74 |
+
if y_axis == "mel":
|
75 |
+
log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
|
76 |
+
log_mag -= log_mag.max()
|
77 |
+
|
78 |
+
librosa.display.specshow(
|
79 |
+
log_mag.numpy()[0].mean(axis=0),
|
80 |
+
x_axis=x_axis,
|
81 |
+
y_axis=y_axis,
|
82 |
+
sr=signal.sample_rate,
|
83 |
+
**kwargs,
|
84 |
+
)
|
85 |
+
|
86 |
+
@format_figure
|
87 |
+
def waveplot(self, x_axis: str = "time", **kwargs):
|
88 |
+
"""Displays a waveform plot, using ``librosa.display.waveshow``.
|
89 |
+
|
90 |
+
Parameters
|
91 |
+
----------
|
92 |
+
x_axis : str, optional
|
93 |
+
How to label the x axis, by default "time"
|
94 |
+
kwargs : dict, optional
|
95 |
+
Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
|
96 |
+
"""
|
97 |
+
import librosa
|
98 |
+
import librosa.display
|
99 |
+
|
100 |
+
audio_data = self.audio_data[0].mean(dim=0)
|
101 |
+
audio_data = audio_data.cpu().numpy()
|
102 |
+
|
103 |
+
plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
|
104 |
+
wave_plot_fn = getattr(librosa.display, plot_fn)
|
105 |
+
wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)
|
106 |
+
|
107 |
+
@format_figure
|
108 |
+
def wavespec(self, x_axis: str = "time", **kwargs):
|
109 |
+
"""Displays a waveform plot, using ``librosa.display.waveshow``.
|
110 |
+
|
111 |
+
Parameters
|
112 |
+
----------
|
113 |
+
x_axis : str, optional
|
114 |
+
How to label the x axis, by default "time"
|
115 |
+
kwargs : dict, optional
|
116 |
+
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
|
117 |
+
"""
|
118 |
+
import matplotlib.pyplot as plt
|
119 |
+
from matplotlib.gridspec import GridSpec
|
120 |
+
|
121 |
+
gs = GridSpec(6, 1)
|
122 |
+
plt.subplot(gs[0, :])
|
123 |
+
self.waveplot(x_axis=x_axis)
|
124 |
+
plt.subplot(gs[1:, :])
|
125 |
+
self.specshow(x_axis=x_axis, **kwargs)
|
126 |
+
|
127 |
+
def write_audio_to_tb(
|
128 |
+
self,
|
129 |
+
tag: str,
|
130 |
+
writer,
|
131 |
+
step: int = None,
|
132 |
+
plot_fn: typing.Union[typing.Callable, str] = "specshow",
|
133 |
+
**kwargs,
|
134 |
+
):
|
135 |
+
"""Writes a signal and its spectrogram to Tensorboard. Will show up
|
136 |
+
under the Audio and Images tab in Tensorboard.
|
137 |
+
|
138 |
+
Parameters
|
139 |
+
----------
|
140 |
+
tag : str
|
141 |
+
Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
|
142 |
+
written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
|
143 |
+
writer : SummaryWriter
|
144 |
+
A SummaryWriter object from PyTorch library.
|
145 |
+
step : int, optional
|
146 |
+
The step to write the signal to, by default None
|
147 |
+
plot_fn : typing.Union[typing.Callable, str], optional
|
148 |
+
How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
|
149 |
+
kwargs : dict, optional
|
150 |
+
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
|
151 |
+
whatever ``plot_fn`` is set to.
|
152 |
+
"""
|
153 |
+
import matplotlib.pyplot as plt
|
154 |
+
|
155 |
+
audio_data = self.audio_data[0, 0].detach().cpu()
|
156 |
+
sample_rate = self.sample_rate
|
157 |
+
writer.add_audio(tag, audio_data, step, sample_rate)
|
158 |
+
|
159 |
+
if plot_fn is not None:
|
160 |
+
if isinstance(plot_fn, str):
|
161 |
+
plot_fn = getattr(self, plot_fn)
|
162 |
+
fig = plt.figure()
|
163 |
+
plt.clf()
|
164 |
+
plot_fn(**kwargs)
|
165 |
+
writer.add_figure(tag.replace("wav", "png"), fig, step)
|
166 |
+
|
167 |
+
def save_image(
|
168 |
+
self,
|
169 |
+
image_path: str,
|
170 |
+
plot_fn: typing.Union[typing.Callable, str] = "specshow",
|
171 |
+
**kwargs,
|
172 |
+
):
|
173 |
+
"""Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
|
174 |
+
a specified file.
|
175 |
+
|
176 |
+
Parameters
|
177 |
+
----------
|
178 |
+
image_path : str
|
179 |
+
Where to save the file to.
|
180 |
+
plot_fn : typing.Union[typing.Callable, str], optional
|
181 |
+
How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
|
182 |
+
kwargs : dict, optional
|
183 |
+
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
|
184 |
+
whatever ``plot_fn`` is set to.
|
185 |
+
"""
|
186 |
+
import matplotlib.pyplot as plt
|
187 |
+
|
188 |
+
if isinstance(plot_fn, str):
|
189 |
+
plot_fn = getattr(self, plot_fn)
|
190 |
+
|
191 |
+
plt.clf()
|
192 |
+
plot_fn(**kwargs)
|
193 |
+
plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
|
194 |
+
plt.close()
|
audiotools/core/dsp.py
ADDED
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
|
3 |
+
import julius
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from . import util
|
8 |
+
|
9 |
+
|
10 |
+
class DSPMixin:
|
11 |
+
_original_batch_size = None
|
12 |
+
_original_num_channels = None
|
13 |
+
_padded_signal_length = None
|
14 |
+
|
15 |
+
def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
|
16 |
+
self._original_batch_size = self.batch_size
|
17 |
+
self._original_num_channels = self.num_channels
|
18 |
+
|
19 |
+
window_length = int(window_duration * self.sample_rate)
|
20 |
+
hop_length = int(hop_duration * self.sample_rate)
|
21 |
+
|
22 |
+
if window_length % hop_length != 0:
|
23 |
+
factor = window_length // hop_length
|
24 |
+
window_length = factor * hop_length
|
25 |
+
|
26 |
+
self.zero_pad(hop_length, hop_length)
|
27 |
+
self._padded_signal_length = self.signal_length
|
28 |
+
|
29 |
+
return window_length, hop_length
|
30 |
+
|
31 |
+
def windows(
|
32 |
+
self, window_duration: float, hop_duration: float, preprocess: bool = True
|
33 |
+
):
|
34 |
+
"""Generator which yields windows of specified duration from signal with a specified
|
35 |
+
hop length.
|
36 |
+
|
37 |
+
Parameters
|
38 |
+
----------
|
39 |
+
window_duration : float
|
40 |
+
Duration of every window in seconds.
|
41 |
+
hop_duration : float
|
42 |
+
Hop between windows in seconds.
|
43 |
+
preprocess : bool, optional
|
44 |
+
Whether to preprocess the signal, so that the first sample is in
|
45 |
+
the middle of the first window, by default True
|
46 |
+
|
47 |
+
Yields
|
48 |
+
------
|
49 |
+
AudioSignal
|
50 |
+
Each window is returned as an AudioSignal.
|
51 |
+
"""
|
52 |
+
if preprocess:
|
53 |
+
window_length, hop_length = self._preprocess_signal_for_windowing(
|
54 |
+
window_duration, hop_duration
|
55 |
+
)
|
56 |
+
|
57 |
+
self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)
|
58 |
+
|
59 |
+
for b in range(self.batch_size):
|
60 |
+
i = 0
|
61 |
+
start_idx = i * hop_length
|
62 |
+
while True:
|
63 |
+
start_idx = i * hop_length
|
64 |
+
i += 1
|
65 |
+
end_idx = start_idx + window_length
|
66 |
+
if end_idx > self.signal_length:
|
67 |
+
break
|
68 |
+
yield self[b, ..., start_idx:end_idx]
|
69 |
+
|
70 |
+
def collect_windows(
|
71 |
+
self, window_duration: float, hop_duration: float, preprocess: bool = True
|
72 |
+
):
|
73 |
+
"""Reshapes signal into windows of specified duration from signal with a specified
|
74 |
+
hop length. Window are placed along the batch dimension. Use with
|
75 |
+
:py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
|
76 |
+
original signal.
|
77 |
+
|
78 |
+
Parameters
|
79 |
+
----------
|
80 |
+
window_duration : float
|
81 |
+
Duration of every window in seconds.
|
82 |
+
hop_duration : float
|
83 |
+
Hop between windows in seconds.
|
84 |
+
preprocess : bool, optional
|
85 |
+
Whether to preprocess the signal, so that the first sample is in
|
86 |
+
the middle of the first window, by default True
|
87 |
+
|
88 |
+
Returns
|
89 |
+
-------
|
90 |
+
AudioSignal
|
91 |
+
AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
|
92 |
+
"""
|
93 |
+
if preprocess:
|
94 |
+
window_length, hop_length = self._preprocess_signal_for_windowing(
|
95 |
+
window_duration, hop_duration
|
96 |
+
)
|
97 |
+
|
98 |
+
# self.audio_data: (nb, nch, nt).
|
99 |
+
unfolded = torch.nn.functional.unfold(
|
100 |
+
self.audio_data.reshape(-1, 1, 1, self.signal_length),
|
101 |
+
kernel_size=(1, window_length),
|
102 |
+
stride=(1, hop_length),
|
103 |
+
)
|
104 |
+
# unfolded: (nb * nch, window_length, num_windows).
|
105 |
+
# -> (nb * nch * num_windows, 1, window_length)
|
106 |
+
unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length)
|
107 |
+
self.audio_data = unfolded
|
108 |
+
return self
|
109 |
+
|
110 |
+
def overlap_and_add(self, hop_duration: float):
|
111 |
+
"""Function which takes a list of windows and overlap adds them into a
|
112 |
+
signal the same length as ``audio_signal``.
|
113 |
+
|
114 |
+
Parameters
|
115 |
+
----------
|
116 |
+
hop_duration : float
|
117 |
+
How much to shift for each window
|
118 |
+
(overlap is window_duration - hop_duration) in seconds.
|
119 |
+
|
120 |
+
Returns
|
121 |
+
-------
|
122 |
+
AudioSignal
|
123 |
+
overlap-and-added signal.
|
124 |
+
"""
|
125 |
+
hop_length = int(hop_duration * self.sample_rate)
|
126 |
+
window_length = self.signal_length
|
127 |
+
|
128 |
+
nb, nch = self._original_batch_size, self._original_num_channels
|
129 |
+
|
130 |
+
unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1)
|
131 |
+
folded = torch.nn.functional.fold(
|
132 |
+
unfolded,
|
133 |
+
output_size=(1, self._padded_signal_length),
|
134 |
+
kernel_size=(1, window_length),
|
135 |
+
stride=(1, hop_length),
|
136 |
+
)
|
137 |
+
|
138 |
+
norm = torch.ones_like(unfolded, device=unfolded.device)
|
139 |
+
norm = torch.nn.functional.fold(
|
140 |
+
norm,
|
141 |
+
output_size=(1, self._padded_signal_length),
|
142 |
+
kernel_size=(1, window_length),
|
143 |
+
stride=(1, hop_length),
|
144 |
+
)
|
145 |
+
|
146 |
+
folded = folded / norm
|
147 |
+
|
148 |
+
folded = folded.reshape(nb, nch, -1)
|
149 |
+
self.audio_data = folded
|
150 |
+
self.trim(hop_length, hop_length)
|
151 |
+
return self
|
152 |
+
|
153 |
+
def low_pass(
|
154 |
+
self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
|
155 |
+
):
|
156 |
+
"""Low-passes the signal in-place. Each item in the batch
|
157 |
+
can have a different low-pass cutoff, if the input
|
158 |
+
to this signal is an array or tensor. If a float, all
|
159 |
+
items are given the same low-pass filter.
|
160 |
+
|
161 |
+
Parameters
|
162 |
+
----------
|
163 |
+
cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
|
164 |
+
Cutoff in Hz of low-pass filter.
|
165 |
+
zeros : int, optional
|
166 |
+
Number of taps to use in low-pass filter, by default 51
|
167 |
+
|
168 |
+
Returns
|
169 |
+
-------
|
170 |
+
AudioSignal
|
171 |
+
Low-passed AudioSignal.
|
172 |
+
"""
|
173 |
+
cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
|
174 |
+
cutoffs = cutoffs / self.sample_rate
|
175 |
+
filtered = torch.empty_like(self.audio_data)
|
176 |
+
|
177 |
+
for i, cutoff in enumerate(cutoffs):
|
178 |
+
lp_filter = julius.LowPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
|
179 |
+
filtered[i] = lp_filter(self.audio_data[i])
|
180 |
+
|
181 |
+
self.audio_data = filtered
|
182 |
+
self.stft_data = None
|
183 |
+
return self
|
184 |
+
|
185 |
+
def high_pass(
|
186 |
+
self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
|
187 |
+
):
|
188 |
+
"""High-passes the signal in-place. Each item in the batch
|
189 |
+
can have a different high-pass cutoff, if the input
|
190 |
+
to this signal is an array or tensor. If a float, all
|
191 |
+
items are given the same high-pass filter.
|
192 |
+
|
193 |
+
Parameters
|
194 |
+
----------
|
195 |
+
cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
|
196 |
+
Cutoff in Hz of high-pass filter.
|
197 |
+
zeros : int, optional
|
198 |
+
Number of taps to use in high-pass filter, by default 51
|
199 |
+
|
200 |
+
Returns
|
201 |
+
-------
|
202 |
+
AudioSignal
|
203 |
+
High-passed AudioSignal.
|
204 |
+
"""
|
205 |
+
cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
|
206 |
+
cutoffs = cutoffs / self.sample_rate
|
207 |
+
filtered = torch.empty_like(self.audio_data)
|
208 |
+
|
209 |
+
for i, cutoff in enumerate(cutoffs):
|
210 |
+
hp_filter = julius.HighPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
|
211 |
+
filtered[i] = hp_filter(self.audio_data[i])
|
212 |
+
|
213 |
+
self.audio_data = filtered
|
214 |
+
self.stft_data = None
|
215 |
+
return self
|
216 |
+
|
217 |
+
def mask_frequencies(
|
218 |
+
self,
|
219 |
+
fmin_hz: typing.Union[torch.Tensor, np.ndarray, float],
|
220 |
+
fmax_hz: typing.Union[torch.Tensor, np.ndarray, float],
|
221 |
+
val: float = 0.0,
|
222 |
+
):
|
223 |
+
"""Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
|
224 |
+
with the value specified by ``val``. Useful for implementing SpecAug.
|
225 |
+
The min and max can be different for every item in the batch.
|
226 |
+
|
227 |
+
Parameters
|
228 |
+
----------
|
229 |
+
fmin_hz : typing.Union[torch.Tensor, np.ndarray, float]
|
230 |
+
Lower end of band to mask out.
|
231 |
+
fmax_hz : typing.Union[torch.Tensor, np.ndarray, float]
|
232 |
+
Upper end of band to mask out.
|
233 |
+
val : float, optional
|
234 |
+
Value to fill in, by default 0.0
|
235 |
+
|
236 |
+
Returns
|
237 |
+
-------
|
238 |
+
AudioSignal
|
239 |
+
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
240 |
+
masked audio data.
|
241 |
+
"""
|
242 |
+
# SpecAug
|
243 |
+
mag, phase = self.magnitude, self.phase
|
244 |
+
fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim)
|
245 |
+
fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim)
|
246 |
+
assert torch.all(fmin_hz < fmax_hz)
|
247 |
+
|
248 |
+
# build mask
|
249 |
+
nbins = mag.shape[-2]
|
250 |
+
bins_hz = torch.linspace(0, self.sample_rate / 2, nbins, device=self.device)
|
251 |
+
bins_hz = bins_hz[None, None, :, None].repeat(
|
252 |
+
self.batch_size, 1, 1, mag.shape[-1]
|
253 |
+
)
|
254 |
+
mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
|
255 |
+
mask = mask.to(self.device)
|
256 |
+
|
257 |
+
mag = mag.masked_fill(mask, val)
|
258 |
+
phase = phase.masked_fill(mask, val)
|
259 |
+
self.stft_data = mag * torch.exp(1j * phase)
|
260 |
+
return self
|
261 |
+
|
262 |
+
def mask_timesteps(
|
263 |
+
self,
|
264 |
+
tmin_s: typing.Union[torch.Tensor, np.ndarray, float],
|
265 |
+
tmax_s: typing.Union[torch.Tensor, np.ndarray, float],
|
266 |
+
val: float = 0.0,
|
267 |
+
):
|
268 |
+
"""Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
|
269 |
+
with the value specified by ``val``. Useful for implementing SpecAug.
|
270 |
+
The min and max can be different for every item in the batch.
|
271 |
+
|
272 |
+
Parameters
|
273 |
+
----------
|
274 |
+
tmin_s : typing.Union[torch.Tensor, np.ndarray, float]
|
275 |
+
Lower end of timesteps to mask out.
|
276 |
+
tmax_s : typing.Union[torch.Tensor, np.ndarray, float]
|
277 |
+
Upper end of timesteps to mask out.
|
278 |
+
val : float, optional
|
279 |
+
Value to fill in, by default 0.0
|
280 |
+
|
281 |
+
Returns
|
282 |
+
-------
|
283 |
+
AudioSignal
|
284 |
+
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
285 |
+
masked audio data.
|
286 |
+
"""
|
287 |
+
# SpecAug
|
288 |
+
mag, phase = self.magnitude, self.phase
|
289 |
+
tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
|
290 |
+
tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
|
291 |
+
|
292 |
+
assert torch.all(tmin_s < tmax_s)
|
293 |
+
|
294 |
+
# build mask
|
295 |
+
nt = mag.shape[-1]
|
296 |
+
bins_t = torch.linspace(0, self.signal_duration, nt, device=self.device)
|
297 |
+
bins_t = bins_t[None, None, None, :].repeat(
|
298 |
+
self.batch_size, 1, mag.shape[-2], 1
|
299 |
+
)
|
300 |
+
mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
|
301 |
+
|
302 |
+
mag = mag.masked_fill(mask, val)
|
303 |
+
phase = phase.masked_fill(mask, val)
|
304 |
+
self.stft_data = mag * torch.exp(1j * phase)
|
305 |
+
return self
|
306 |
+
|
307 |
+
def mask_low_magnitudes(
|
308 |
+
self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0
|
309 |
+
):
|
310 |
+
"""Mask away magnitudes below a specified threshold, which
|
311 |
+
can be different for every item in the batch.
|
312 |
+
|
313 |
+
Parameters
|
314 |
+
----------
|
315 |
+
db_cutoff : typing.Union[torch.Tensor, np.ndarray, float]
|
316 |
+
Decibel value for which things below it will be masked away.
|
317 |
+
val : float, optional
|
318 |
+
Value to fill in for masked portions, by default 0.0
|
319 |
+
|
320 |
+
Returns
|
321 |
+
-------
|
322 |
+
AudioSignal
|
323 |
+
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
324 |
+
masked audio data.
|
325 |
+
"""
|
326 |
+
mag = self.magnitude
|
327 |
+
log_mag = self.log_magnitude()
|
328 |
+
|
329 |
+
db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
|
330 |
+
mask = log_mag < db_cutoff
|
331 |
+
mag = mag.masked_fill(mask, val)
|
332 |
+
|
333 |
+
self.magnitude = mag
|
334 |
+
return self
|
335 |
+
|
336 |
+
def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]):
|
337 |
+
"""Shifts the phase by a constant value.
|
338 |
+
|
339 |
+
Parameters
|
340 |
+
----------
|
341 |
+
shift : typing.Union[torch.Tensor, np.ndarray, float]
|
342 |
+
What to shift the phase by.
|
343 |
+
|
344 |
+
Returns
|
345 |
+
-------
|
346 |
+
AudioSignal
|
347 |
+
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
348 |
+
masked audio data.
|
349 |
+
"""
|
350 |
+
shift = util.ensure_tensor(shift, ndim=self.phase.ndim)
|
351 |
+
self.phase = self.phase + shift
|
352 |
+
return self
|
353 |
+
|
354 |
+
def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]):
|
355 |
+
"""Corrupts the phase randomly by some scaled value.
|
356 |
+
|
357 |
+
Parameters
|
358 |
+
----------
|
359 |
+
scale : typing.Union[torch.Tensor, np.ndarray, float]
|
360 |
+
Standard deviation of noise to add to the phase.
|
361 |
+
|
362 |
+
Returns
|
363 |
+
-------
|
364 |
+
AudioSignal
|
365 |
+
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
366 |
+
masked audio data.
|
367 |
+
"""
|
368 |
+
scale = util.ensure_tensor(scale, ndim=self.phase.ndim)
|
369 |
+
self.phase = self.phase + scale * torch.randn_like(self.phase)
|
370 |
+
return self
|
371 |
+
|
372 |
+
def preemphasis(self, coef: float = 0.85):
|
373 |
+
"""Applies pre-emphasis to audio signal.
|
374 |
+
|
375 |
+
Parameters
|
376 |
+
----------
|
377 |
+
coef : float, optional
|
378 |
+
How much pre-emphasis to apply, lower values do less. 0 does nothing.
|
379 |
+
by default 0.85
|
380 |
+
|
381 |
+
Returns
|
382 |
+
-------
|
383 |
+
AudioSignal
|
384 |
+
Pre-emphasized signal.
|
385 |
+
"""
|
386 |
+
kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device)
|
387 |
+
x = self.audio_data.reshape(-1, 1, self.signal_length)
|
388 |
+
x = torch.nn.functional.conv1d(x, kernel, padding=1)
|
389 |
+
self.audio_data = x.reshape(*self.audio_data.shape)
|
390 |
+
return self
|
audiotools/core/effects.py
ADDED
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
|
3 |
+
import julius
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchaudio
|
7 |
+
|
8 |
+
from . import util
|
9 |
+
|
10 |
+
|
11 |
+
class EffectMixin:
|
12 |
+
GAIN_FACTOR = np.log(10) / 20
|
13 |
+
"""Gain factor for converting between amplitude and decibels."""
|
14 |
+
CODEC_PRESETS = {
|
15 |
+
"8-bit": {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8},
|
16 |
+
"GSM-FR": {"format": "gsm"},
|
17 |
+
"MP3": {"format": "mp3", "compression": -9},
|
18 |
+
"Vorbis": {"format": "vorbis", "compression": -1},
|
19 |
+
"Ogg": {
|
20 |
+
"format": "ogg",
|
21 |
+
"compression": -1,
|
22 |
+
},
|
23 |
+
"Amr-nb": {"format": "amr-nb"},
|
24 |
+
}
|
25 |
+
"""Presets for applying codecs via torchaudio."""
|
26 |
+
|
27 |
+
def mix(
|
28 |
+
self,
|
29 |
+
other,
|
30 |
+
snr: typing.Union[torch.Tensor, np.ndarray, float] = 10,
|
31 |
+
other_eq: typing.Union[torch.Tensor, np.ndarray] = None,
|
32 |
+
):
|
33 |
+
"""Mixes noise with signal at specified
|
34 |
+
signal-to-noise ratio. Optionally, the
|
35 |
+
other signal can be equalized in-place.
|
36 |
+
|
37 |
+
|
38 |
+
Parameters
|
39 |
+
----------
|
40 |
+
other : AudioSignal
|
41 |
+
AudioSignal object to mix with.
|
42 |
+
snr : typing.Union[torch.Tensor, np.ndarray, float], optional
|
43 |
+
Signal to noise ratio, by default 10
|
44 |
+
other_eq : typing.Union[torch.Tensor, np.ndarray], optional
|
45 |
+
EQ curve to apply to other signal, if any, by default None
|
46 |
+
|
47 |
+
Returns
|
48 |
+
-------
|
49 |
+
AudioSignal
|
50 |
+
In-place modification of AudioSignal.
|
51 |
+
"""
|
52 |
+
snr = util.ensure_tensor(snr).to(self.device)
|
53 |
+
|
54 |
+
pad_len = max(0, self.signal_length - other.signal_length)
|
55 |
+
other.zero_pad(0, pad_len)
|
56 |
+
other.truncate_samples(self.signal_length)
|
57 |
+
if other_eq is not None:
|
58 |
+
other = other.equalizer(other_eq)
|
59 |
+
|
60 |
+
tgt_loudness = self.loudness() - snr
|
61 |
+
other = other.normalize(tgt_loudness)
|
62 |
+
|
63 |
+
self.audio_data = self.audio_data + other.audio_data
|
64 |
+
return self
|
65 |
+
|
66 |
+
def convolve(self, other, start_at_max: bool = True):
|
67 |
+
"""Convolves self with other.
|
68 |
+
This function uses FFTs to do the convolution.
|
69 |
+
|
70 |
+
Parameters
|
71 |
+
----------
|
72 |
+
other : AudioSignal
|
73 |
+
Signal to convolve with.
|
74 |
+
start_at_max : bool, optional
|
75 |
+
Whether to start at the max value of other signal, to
|
76 |
+
avoid inducing delays, by default True
|
77 |
+
|
78 |
+
Returns
|
79 |
+
-------
|
80 |
+
AudioSignal
|
81 |
+
Convolved signal, in-place.
|
82 |
+
"""
|
83 |
+
from . import AudioSignal
|
84 |
+
|
85 |
+
pad_len = self.signal_length - other.signal_length
|
86 |
+
|
87 |
+
if pad_len > 0:
|
88 |
+
other.zero_pad(0, pad_len)
|
89 |
+
else:
|
90 |
+
other.truncate_samples(self.signal_length)
|
91 |
+
|
92 |
+
if start_at_max:
|
93 |
+
# Use roll to rotate over the max for every item
|
94 |
+
# so that the impulse responses don't induce any
|
95 |
+
# delay.
|
96 |
+
idx = other.audio_data.abs().argmax(axis=-1)
|
97 |
+
irs = torch.zeros_like(other.audio_data)
|
98 |
+
for i in range(other.batch_size):
|
99 |
+
irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1)
|
100 |
+
other = AudioSignal(irs, other.sample_rate)
|
101 |
+
|
102 |
+
delta = torch.zeros_like(other.audio_data)
|
103 |
+
delta[..., 0] = 1
|
104 |
+
|
105 |
+
length = self.signal_length
|
106 |
+
delta_fft = torch.fft.rfft(delta, length)
|
107 |
+
other_fft = torch.fft.rfft(other.audio_data, length)
|
108 |
+
self_fft = torch.fft.rfft(self.audio_data, length)
|
109 |
+
|
110 |
+
convolved_fft = other_fft * self_fft
|
111 |
+
convolved_audio = torch.fft.irfft(convolved_fft, length)
|
112 |
+
|
113 |
+
delta_convolved_fft = other_fft * delta_fft
|
114 |
+
delta_audio = torch.fft.irfft(delta_convolved_fft, length)
|
115 |
+
|
116 |
+
# Use the delta to rescale the audio exactly as needed.
|
117 |
+
delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0]
|
118 |
+
scale = 1 / delta_max.clamp(1e-5)
|
119 |
+
convolved_audio = convolved_audio * scale
|
120 |
+
|
121 |
+
self.audio_data = convolved_audio
|
122 |
+
|
123 |
+
return self
|
124 |
+
|
125 |
+
def apply_ir(
|
126 |
+
self,
|
127 |
+
ir,
|
128 |
+
drr: typing.Union[torch.Tensor, np.ndarray, float] = None,
|
129 |
+
ir_eq: typing.Union[torch.Tensor, np.ndarray] = None,
|
130 |
+
use_original_phase: bool = False,
|
131 |
+
):
|
132 |
+
"""Applies an impulse response to the signal. If ` is`ir_eq``
|
133 |
+
is specified, the impulse response is equalized before
|
134 |
+
it is applied, using the given curve.
|
135 |
+
|
136 |
+
Parameters
|
137 |
+
----------
|
138 |
+
ir : AudioSignal
|
139 |
+
Impulse response to convolve with.
|
140 |
+
drr : typing.Union[torch.Tensor, np.ndarray, float], optional
|
141 |
+
Direct-to-reverberant ratio that impulse response will be
|
142 |
+
altered to, if specified, by default None
|
143 |
+
ir_eq : typing.Union[torch.Tensor, np.ndarray], optional
|
144 |
+
Equalization that will be applied to impulse response
|
145 |
+
if specified, by default None
|
146 |
+
use_original_phase : bool, optional
|
147 |
+
Whether to use the original phase, instead of the convolved
|
148 |
+
phase, by default False
|
149 |
+
|
150 |
+
Returns
|
151 |
+
-------
|
152 |
+
AudioSignal
|
153 |
+
Signal with impulse response applied to it
|
154 |
+
"""
|
155 |
+
if ir_eq is not None:
|
156 |
+
ir = ir.equalizer(ir_eq)
|
157 |
+
if drr is not None:
|
158 |
+
ir = ir.alter_drr(drr)
|
159 |
+
|
160 |
+
# Save the peak before
|
161 |
+
max_spk = self.audio_data.abs().max(dim=-1, keepdims=True).values
|
162 |
+
|
163 |
+
# Augment the impulse response to simulate microphone effects
|
164 |
+
# and with varying direct-to-reverberant ratio.
|
165 |
+
phase = self.phase
|
166 |
+
self.convolve(ir)
|
167 |
+
|
168 |
+
# Use the input phase
|
169 |
+
if use_original_phase:
|
170 |
+
self.stft()
|
171 |
+
self.stft_data = self.magnitude * torch.exp(1j * phase)
|
172 |
+
self.istft()
|
173 |
+
|
174 |
+
# Rescale to the input's amplitude
|
175 |
+
max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values
|
176 |
+
scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8)
|
177 |
+
self = self * scale_factor
|
178 |
+
|
179 |
+
return self
|
180 |
+
|
181 |
+
def ensure_max_of_audio(self, max: float = 1.0):
|
182 |
+
"""Ensures that ``abs(audio_data) <= max``.
|
183 |
+
|
184 |
+
Parameters
|
185 |
+
----------
|
186 |
+
max : float, optional
|
187 |
+
Max absolute value of signal, by default 1.0
|
188 |
+
|
189 |
+
Returns
|
190 |
+
-------
|
191 |
+
AudioSignal
|
192 |
+
Signal with values scaled between -max and max.
|
193 |
+
"""
|
194 |
+
peak = self.audio_data.abs().max(dim=-1, keepdims=True)[0]
|
195 |
+
peak_gain = torch.ones_like(peak)
|
196 |
+
peak_gain[peak > max] = max / peak[peak > max]
|
197 |
+
self.audio_data = self.audio_data * peak_gain
|
198 |
+
return self
|
199 |
+
|
200 |
+
def normalize(self, db: typing.Union[torch.Tensor, np.ndarray, float] = -24.0):
|
201 |
+
"""Normalizes the signal's volume to the specified db, in LUFS.
|
202 |
+
This is GPU-compatible, making for very fast loudness normalization.
|
203 |
+
|
204 |
+
Parameters
|
205 |
+
----------
|
206 |
+
db : typing.Union[torch.Tensor, np.ndarray, float], optional
|
207 |
+
Loudness to normalize to, by default -24.0
|
208 |
+
|
209 |
+
Returns
|
210 |
+
-------
|
211 |
+
AudioSignal
|
212 |
+
Normalized audio signal.
|
213 |
+
"""
|
214 |
+
db = util.ensure_tensor(db).to(self.device)
|
215 |
+
ref_db = self.loudness()
|
216 |
+
gain = db - ref_db
|
217 |
+
gain = torch.exp(gain * self.GAIN_FACTOR)
|
218 |
+
|
219 |
+
self.audio_data = self.audio_data * gain[:, None, None]
|
220 |
+
return self
|
221 |
+
|
222 |
+
def volume_change(self, db: typing.Union[torch.Tensor, np.ndarray, float]):
|
223 |
+
"""Change volume of signal by some amount, in dB.
|
224 |
+
|
225 |
+
Parameters
|
226 |
+
----------
|
227 |
+
db : typing.Union[torch.Tensor, np.ndarray, float]
|
228 |
+
Amount to change volume by.
|
229 |
+
|
230 |
+
Returns
|
231 |
+
-------
|
232 |
+
AudioSignal
|
233 |
+
Signal at new volume.
|
234 |
+
"""
|
235 |
+
db = util.ensure_tensor(db, ndim=1).to(self.device)
|
236 |
+
gain = torch.exp(db * self.GAIN_FACTOR)
|
237 |
+
self.audio_data = self.audio_data * gain[:, None, None]
|
238 |
+
return self
|
239 |
+
|
240 |
+
def _to_2d(self):
|
241 |
+
waveform = self.audio_data.reshape(-1, self.signal_length)
|
242 |
+
return waveform
|
243 |
+
|
244 |
+
def _to_3d(self, waveform):
|
245 |
+
return waveform.reshape(self.batch_size, self.num_channels, -1)
|
246 |
+
|
247 |
+
def pitch_shift(self, n_semitones: int, quick: bool = True):
|
248 |
+
"""Pitch shift the signal. All items in the batch
|
249 |
+
get the same pitch shift.
|
250 |
+
|
251 |
+
Parameters
|
252 |
+
----------
|
253 |
+
n_semitones : int
|
254 |
+
How many semitones to shift the signal by.
|
255 |
+
quick : bool, optional
|
256 |
+
Using quick pitch shifting, by default True
|
257 |
+
|
258 |
+
Returns
|
259 |
+
-------
|
260 |
+
AudioSignal
|
261 |
+
Pitch shifted audio signal.
|
262 |
+
"""
|
263 |
+
device = self.device
|
264 |
+
effects = [
|
265 |
+
["pitch", str(n_semitones * 100)],
|
266 |
+
["rate", str(self.sample_rate)],
|
267 |
+
]
|
268 |
+
if quick:
|
269 |
+
effects[0].insert(1, "-q")
|
270 |
+
|
271 |
+
waveform = self._to_2d().cpu()
|
272 |
+
waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
|
273 |
+
waveform, self.sample_rate, effects, channels_first=True
|
274 |
+
)
|
275 |
+
self.sample_rate = sample_rate
|
276 |
+
self.audio_data = self._to_3d(waveform)
|
277 |
+
return self.to(device)
|
278 |
+
|
279 |
+
def time_stretch(self, factor: float, quick: bool = True):
|
280 |
+
"""Time stretch the audio signal.
|
281 |
+
|
282 |
+
Parameters
|
283 |
+
----------
|
284 |
+
factor : float
|
285 |
+
Factor by which to stretch the AudioSignal. Typically
|
286 |
+
between 0.8 and 1.2.
|
287 |
+
quick : bool, optional
|
288 |
+
Whether to use quick time stretching, by default True
|
289 |
+
|
290 |
+
Returns
|
291 |
+
-------
|
292 |
+
AudioSignal
|
293 |
+
Time-stretched AudioSignal.
|
294 |
+
"""
|
295 |
+
device = self.device
|
296 |
+
effects = [
|
297 |
+
["tempo", str(factor)],
|
298 |
+
["rate", str(self.sample_rate)],
|
299 |
+
]
|
300 |
+
if quick:
|
301 |
+
effects[0].insert(1, "-q")
|
302 |
+
|
303 |
+
waveform = self._to_2d().cpu()
|
304 |
+
waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
|
305 |
+
waveform, self.sample_rate, effects, channels_first=True
|
306 |
+
)
|
307 |
+
self.sample_rate = sample_rate
|
308 |
+
self.audio_data = self._to_3d(waveform)
|
309 |
+
return self.to(device)
|
310 |
+
|
311 |
+
def apply_codec(
|
312 |
+
self,
|
313 |
+
preset: str = None,
|
314 |
+
format: str = "wav",
|
315 |
+
encoding: str = None,
|
316 |
+
bits_per_sample: int = None,
|
317 |
+
compression: int = None,
|
318 |
+
): # pragma: no cover
|
319 |
+
"""Applies an audio codec to the signal.
|
320 |
+
|
321 |
+
Parameters
|
322 |
+
----------
|
323 |
+
preset : str, optional
|
324 |
+
One of the keys in ``self.CODEC_PRESETS``, by default None
|
325 |
+
format : str, optional
|
326 |
+
Format for audio codec, by default "wav"
|
327 |
+
encoding : str, optional
|
328 |
+
Encoding to use, by default None
|
329 |
+
bits_per_sample : int, optional
|
330 |
+
How many bits per sample, by default None
|
331 |
+
compression : int, optional
|
332 |
+
Compression amount of codec, by default None
|
333 |
+
|
334 |
+
Returns
|
335 |
+
-------
|
336 |
+
AudioSignal
|
337 |
+
AudioSignal with codec applied.
|
338 |
+
|
339 |
+
Raises
|
340 |
+
------
|
341 |
+
ValueError
|
342 |
+
If preset is not in ``self.CODEC_PRESETS``, an error
|
343 |
+
is thrown.
|
344 |
+
"""
|
345 |
+
torchaudio_version_070 = "0.7" in torchaudio.__version__
|
346 |
+
if torchaudio_version_070:
|
347 |
+
return self
|
348 |
+
|
349 |
+
kwargs = {
|
350 |
+
"format": format,
|
351 |
+
"encoding": encoding,
|
352 |
+
"bits_per_sample": bits_per_sample,
|
353 |
+
"compression": compression,
|
354 |
+
}
|
355 |
+
|
356 |
+
if preset is not None:
|
357 |
+
if preset in self.CODEC_PRESETS:
|
358 |
+
kwargs = self.CODEC_PRESETS[preset]
|
359 |
+
else:
|
360 |
+
raise ValueError(
|
361 |
+
f"Unknown preset: {preset}. "
|
362 |
+
f"Known presets: {list(self.CODEC_PRESETS.keys())}"
|
363 |
+
)
|
364 |
+
|
365 |
+
waveform = self._to_2d()
|
366 |
+
if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]:
|
367 |
+
# Apply it in a for loop
|
368 |
+
augmented = torch.cat(
|
369 |
+
[
|
370 |
+
torchaudio.functional.apply_codec(
|
371 |
+
waveform[i][None, :], self.sample_rate, **kwargs
|
372 |
+
)
|
373 |
+
for i in range(waveform.shape[0])
|
374 |
+
],
|
375 |
+
dim=0,
|
376 |
+
)
|
377 |
+
else:
|
378 |
+
augmented = torchaudio.functional.apply_codec(
|
379 |
+
waveform, self.sample_rate, **kwargs
|
380 |
+
)
|
381 |
+
augmented = self._to_3d(augmented)
|
382 |
+
|
383 |
+
self.audio_data = augmented
|
384 |
+
return self
|
385 |
+
|
386 |
+
def mel_filterbank(self, n_bands: int):
|
387 |
+
"""Breaks signal into mel bands.
|
388 |
+
|
389 |
+
Parameters
|
390 |
+
----------
|
391 |
+
n_bands : int
|
392 |
+
Number of mel bands to use.
|
393 |
+
|
394 |
+
Returns
|
395 |
+
-------
|
396 |
+
torch.Tensor
|
397 |
+
Mel-filtered bands, with last axis being the band index.
|
398 |
+
"""
|
399 |
+
filterbank = (
|
400 |
+
julius.SplitBands(self.sample_rate, n_bands).float().to(self.device)
|
401 |
+
)
|
402 |
+
filtered = filterbank(self.audio_data)
|
403 |
+
return filtered.permute(1, 2, 3, 0)
|
404 |
+
|
405 |
+
def equalizer(self, db: typing.Union[torch.Tensor, np.ndarray]):
|
406 |
+
"""Applies a mel-spaced equalizer to the audio signal.
|
407 |
+
|
408 |
+
Parameters
|
409 |
+
----------
|
410 |
+
db : typing.Union[torch.Tensor, np.ndarray]
|
411 |
+
EQ curve to apply.
|
412 |
+
|
413 |
+
Returns
|
414 |
+
-------
|
415 |
+
AudioSignal
|
416 |
+
AudioSignal with equalization applied.
|
417 |
+
"""
|
418 |
+
db = util.ensure_tensor(db)
|
419 |
+
n_bands = db.shape[-1]
|
420 |
+
fbank = self.mel_filterbank(n_bands)
|
421 |
+
|
422 |
+
# If there's a batch dimension, make sure it's the same.
|
423 |
+
if db.ndim == 2:
|
424 |
+
if db.shape[0] != 1:
|
425 |
+
assert db.shape[0] == fbank.shape[0]
|
426 |
+
else:
|
427 |
+
db = db.unsqueeze(0)
|
428 |
+
|
429 |
+
weights = (10**db).to(self.device).float()
|
430 |
+
fbank = fbank * weights[:, None, None, :]
|
431 |
+
eq_audio_data = fbank.sum(-1)
|
432 |
+
self.audio_data = eq_audio_data
|
433 |
+
return self
|
434 |
+
|
435 |
+
def clip_distortion(
|
436 |
+
self, clip_percentile: typing.Union[torch.Tensor, np.ndarray, float]
|
437 |
+
):
|
438 |
+
"""Clips the signal at a given percentile. The higher it is,
|
439 |
+
the lower the threshold for clipping.
|
440 |
+
|
441 |
+
Parameters
|
442 |
+
----------
|
443 |
+
clip_percentile : typing.Union[torch.Tensor, np.ndarray, float]
|
444 |
+
Values are between 0.0 to 1.0. Typical values are 0.1 or below.
|
445 |
+
|
446 |
+
Returns
|
447 |
+
-------
|
448 |
+
AudioSignal
|
449 |
+
Audio signal with clipped audio data.
|
450 |
+
"""
|
451 |
+
clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
|
452 |
+
min_thresh = torch.quantile(self.audio_data, clip_percentile / 2, dim=-1)
|
453 |
+
max_thresh = torch.quantile(self.audio_data, 1 - (clip_percentile / 2), dim=-1)
|
454 |
+
|
455 |
+
nc = self.audio_data.shape[1]
|
456 |
+
min_thresh = min_thresh[:, :nc, :]
|
457 |
+
max_thresh = max_thresh[:, :nc, :]
|
458 |
+
|
459 |
+
self.audio_data = self.audio_data.clamp(min_thresh, max_thresh)
|
460 |
+
|
461 |
+
return self
|
462 |
+
|
463 |
+
def quantization(
|
464 |
+
self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
|
465 |
+
):
|
466 |
+
"""Applies quantization to the input waveform.
|
467 |
+
|
468 |
+
Parameters
|
469 |
+
----------
|
470 |
+
quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
|
471 |
+
Number of evenly spaced quantization channels to quantize
|
472 |
+
to.
|
473 |
+
|
474 |
+
Returns
|
475 |
+
-------
|
476 |
+
AudioSignal
|
477 |
+
Quantized AudioSignal.
|
478 |
+
"""
|
479 |
+
quantization_channels = util.ensure_tensor(quantization_channels, ndim=3)
|
480 |
+
|
481 |
+
x = self.audio_data
|
482 |
+
x = (x + 1) / 2
|
483 |
+
x = x * quantization_channels
|
484 |
+
x = x.floor()
|
485 |
+
x = x / quantization_channels
|
486 |
+
x = 2 * x - 1
|
487 |
+
|
488 |
+
residual = (self.audio_data - x).detach()
|
489 |
+
self.audio_data = self.audio_data - residual
|
490 |
+
return self
|
491 |
+
|
492 |
+
def mulaw_quantization(
|
493 |
+
self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
|
494 |
+
):
|
495 |
+
"""Applies mu-law quantization to the input waveform.
|
496 |
+
|
497 |
+
Parameters
|
498 |
+
----------
|
499 |
+
quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
|
500 |
+
Number of mu-law spaced quantization channels to quantize
|
501 |
+
to.
|
502 |
+
|
503 |
+
Returns
|
504 |
+
-------
|
505 |
+
AudioSignal
|
506 |
+
Quantized AudioSignal.
|
507 |
+
"""
|
508 |
+
mu = quantization_channels - 1.0
|
509 |
+
mu = util.ensure_tensor(mu, ndim=3)
|
510 |
+
|
511 |
+
x = self.audio_data
|
512 |
+
|
513 |
+
# quantize
|
514 |
+
x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
|
515 |
+
x = ((x + 1) / 2 * mu + 0.5).to(torch.int64)
|
516 |
+
|
517 |
+
# unquantize
|
518 |
+
x = (x / mu) * 2 - 1.0
|
519 |
+
x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
|
520 |
+
|
521 |
+
residual = (self.audio_data - x).detach()
|
522 |
+
self.audio_data = self.audio_data - residual
|
523 |
+
return self
|
524 |
+
|
525 |
+
def __matmul__(self, other):
|
526 |
+
return self.convolve(other)
|
527 |
+
|
528 |
+
|
529 |
+
class ImpulseResponseMixin:
|
530 |
+
"""These functions are generally only used with AudioSignals that are derived
|
531 |
+
from impulse responses, not other sources like music or speech. These methods
|
532 |
+
are used to replicate the data augmentation described in [1].
|
533 |
+
|
534 |
+
1. Bryan, Nicholas J. "Impulse response data augmentation and deep
|
535 |
+
neural networks for blind room acoustic parameter estimation."
|
536 |
+
ICASSP 2020-2020 IEEE International Conference on Acoustics,
|
537 |
+
Speech and Signal Processing (ICASSP). IEEE, 2020.
|
538 |
+
"""
|
539 |
+
|
540 |
+
def decompose_ir(self):
|
541 |
+
"""Decomposes an impulse response into early and late
|
542 |
+
field responses.
|
543 |
+
"""
|
544 |
+
# Equations 1 and 2
|
545 |
+
# -----------------
|
546 |
+
# Breaking up into early
|
547 |
+
# response + late field response.
|
548 |
+
|
549 |
+
td = torch.argmax(self.audio_data, dim=-1, keepdim=True)
|
550 |
+
t0 = int(self.sample_rate * 0.0025)
|
551 |
+
|
552 |
+
idx = torch.arange(self.audio_data.shape[-1], device=self.device)[None, None, :]
|
553 |
+
idx = idx.expand(self.batch_size, -1, -1)
|
554 |
+
early_idx = (idx >= td - t0) * (idx <= td + t0)
|
555 |
+
|
556 |
+
early_response = torch.zeros_like(self.audio_data, device=self.device)
|
557 |
+
early_response[early_idx] = self.audio_data[early_idx]
|
558 |
+
|
559 |
+
late_idx = ~early_idx
|
560 |
+
late_field = torch.zeros_like(self.audio_data, device=self.device)
|
561 |
+
late_field[late_idx] = self.audio_data[late_idx]
|
562 |
+
|
563 |
+
# Equation 4
|
564 |
+
# ----------
|
565 |
+
# Decompose early response into windowed
|
566 |
+
# direct path and windowed residual.
|
567 |
+
|
568 |
+
window = torch.zeros_like(self.audio_data, device=self.device)
|
569 |
+
for idx in range(self.batch_size):
|
570 |
+
window_idx = early_idx[idx, 0].nonzero()
|
571 |
+
window[idx, ..., window_idx] = self.get_window(
|
572 |
+
"hann", window_idx.shape[-1], self.device
|
573 |
+
)
|
574 |
+
return early_response, late_field, window
|
575 |
+
|
576 |
+
def measure_drr(self):
|
577 |
+
"""Measures the direct-to-reverberant ratio of the impulse
|
578 |
+
response.
|
579 |
+
|
580 |
+
Returns
|
581 |
+
-------
|
582 |
+
float
|
583 |
+
Direct-to-reverberant ratio
|
584 |
+
"""
|
585 |
+
early_response, late_field, _ = self.decompose_ir()
|
586 |
+
num = (early_response**2).sum(dim=-1)
|
587 |
+
den = (late_field**2).sum(dim=-1)
|
588 |
+
drr = 10 * torch.log10(num / den)
|
589 |
+
return drr
|
590 |
+
|
591 |
+
@staticmethod
|
592 |
+
def solve_alpha(early_response, late_field, wd, target_drr):
|
593 |
+
"""Used to solve for the alpha value, which is used
|
594 |
+
to alter the drr.
|
595 |
+
"""
|
596 |
+
# Equation 5
|
597 |
+
# ----------
|
598 |
+
# Apply the good ol' quadratic formula.
|
599 |
+
|
600 |
+
wd_sq = wd**2
|
601 |
+
wd_sq_1 = (1 - wd) ** 2
|
602 |
+
e_sq = early_response**2
|
603 |
+
l_sq = late_field**2
|
604 |
+
a = (wd_sq * e_sq).sum(dim=-1)
|
605 |
+
b = (2 * (1 - wd) * wd * e_sq).sum(dim=-1)
|
606 |
+
c = (wd_sq_1 * e_sq).sum(dim=-1) - torch.pow(10, target_drr / 10) * l_sq.sum(
|
607 |
+
dim=-1
|
608 |
+
)
|
609 |
+
|
610 |
+
expr = ((b**2) - 4 * a * c).sqrt()
|
611 |
+
alpha = torch.maximum(
|
612 |
+
(-b - expr) / (2 * a),
|
613 |
+
(-b + expr) / (2 * a),
|
614 |
+
)
|
615 |
+
return alpha
|
616 |
+
|
617 |
+
def alter_drr(self, drr: typing.Union[torch.Tensor, np.ndarray, float]):
|
618 |
+
"""Alters the direct-to-reverberant ratio of the impulse response.
|
619 |
+
|
620 |
+
Parameters
|
621 |
+
----------
|
622 |
+
drr : typing.Union[torch.Tensor, np.ndarray, float]
|
623 |
+
Direct-to-reverberant ratio that impulse response will be
|
624 |
+
altered to, if specified, by default None
|
625 |
+
|
626 |
+
Returns
|
627 |
+
-------
|
628 |
+
AudioSignal
|
629 |
+
Altered impulse response.
|
630 |
+
"""
|
631 |
+
drr = util.ensure_tensor(drr, 2, self.batch_size).to(self.device)
|
632 |
+
|
633 |
+
early_response, late_field, window = self.decompose_ir()
|
634 |
+
alpha = self.solve_alpha(early_response, late_field, window, drr)
|
635 |
+
min_alpha = (
|
636 |
+
late_field.abs().max(dim=-1)[0] / early_response.abs().max(dim=-1)[0]
|
637 |
+
)
|
638 |
+
alpha = torch.maximum(alpha, min_alpha)[..., None]
|
639 |
+
|
640 |
+
aug_ir_data = (
|
641 |
+
alpha * window * early_response
|
642 |
+
+ ((1 - window) * early_response)
|
643 |
+
+ late_field
|
644 |
+
)
|
645 |
+
self.audio_data = aug_ir_data
|
646 |
+
self.ensure_max_of_audio()
|
647 |
+
return self
|
audiotools/core/ffmpeg.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import shlex
|
3 |
+
import subprocess
|
4 |
+
import tempfile
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Tuple
|
7 |
+
|
8 |
+
import ffmpy
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
def r128stats(filepath: str, quiet: bool):
|
14 |
+
"""Takes a path to an audio file, returns a dict with the loudness
|
15 |
+
stats computed by the ffmpeg ebur128 filter.
|
16 |
+
|
17 |
+
Parameters
|
18 |
+
----------
|
19 |
+
filepath : str
|
20 |
+
Path to compute loudness stats on.
|
21 |
+
quiet : bool
|
22 |
+
Whether to show FFMPEG output during computation.
|
23 |
+
|
24 |
+
Returns
|
25 |
+
-------
|
26 |
+
dict
|
27 |
+
Dictionary containing loudness stats.
|
28 |
+
"""
|
29 |
+
ffargs = [
|
30 |
+
"ffmpeg",
|
31 |
+
"-nostats",
|
32 |
+
"-i",
|
33 |
+
filepath,
|
34 |
+
"-filter_complex",
|
35 |
+
"ebur128",
|
36 |
+
"-f",
|
37 |
+
"null",
|
38 |
+
"-",
|
39 |
+
]
|
40 |
+
if quiet:
|
41 |
+
ffargs += ["-hide_banner"]
|
42 |
+
proc = subprocess.Popen(ffargs, stderr=subprocess.PIPE, universal_newlines=True)
|
43 |
+
stats = proc.communicate()[1]
|
44 |
+
summary_index = stats.rfind("Summary:")
|
45 |
+
|
46 |
+
summary_list = stats[summary_index:].split()
|
47 |
+
i_lufs = float(summary_list[summary_list.index("I:") + 1])
|
48 |
+
i_thresh = float(summary_list[summary_list.index("I:") + 4])
|
49 |
+
lra = float(summary_list[summary_list.index("LRA:") + 1])
|
50 |
+
lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
|
51 |
+
lra_low = float(summary_list[summary_list.index("low:") + 1])
|
52 |
+
lra_high = float(summary_list[summary_list.index("high:") + 1])
|
53 |
+
stats_dict = {
|
54 |
+
"I": i_lufs,
|
55 |
+
"I Threshold": i_thresh,
|
56 |
+
"LRA": lra,
|
57 |
+
"LRA Threshold": lra_thresh,
|
58 |
+
"LRA Low": lra_low,
|
59 |
+
"LRA High": lra_high,
|
60 |
+
}
|
61 |
+
|
62 |
+
return stats_dict
|
63 |
+
|
64 |
+
|
65 |
+
def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
|
66 |
+
"""Given a path to a file, returns the start time offset and codec of
|
67 |
+
the first audio stream.
|
68 |
+
"""
|
69 |
+
ff = ffmpy.FFprobe(
|
70 |
+
inputs={path: None},
|
71 |
+
global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
|
72 |
+
)
|
73 |
+
streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
|
74 |
+
seconds_offset = 0.0
|
75 |
+
codec = None
|
76 |
+
|
77 |
+
# Get the offset and codec of the first audio stream we find
|
78 |
+
# and return its start time, if it has one.
|
79 |
+
for stream in streams:
|
80 |
+
if stream["codec_type"] == "audio":
|
81 |
+
seconds_offset = stream.get("start_time", 0.0)
|
82 |
+
codec = stream.get("codec_name")
|
83 |
+
break
|
84 |
+
return float(seconds_offset), codec
|
85 |
+
|
86 |
+
|
87 |
+
class FFMPEGMixin:
|
88 |
+
_loudness = None
|
89 |
+
|
90 |
+
def ffmpeg_loudness(self, quiet: bool = True):
|
91 |
+
"""Computes loudness of audio file using FFMPEG.
|
92 |
+
|
93 |
+
Parameters
|
94 |
+
----------
|
95 |
+
quiet : bool, optional
|
96 |
+
Whether to show FFMPEG output during computation,
|
97 |
+
by default True
|
98 |
+
|
99 |
+
Returns
|
100 |
+
-------
|
101 |
+
torch.Tensor
|
102 |
+
Loudness of every item in the batch, computed via
|
103 |
+
FFMPEG.
|
104 |
+
"""
|
105 |
+
loudness = []
|
106 |
+
|
107 |
+
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
108 |
+
for i in range(self.batch_size):
|
109 |
+
self[i].write(f.name)
|
110 |
+
loudness_stats = r128stats(f.name, quiet=quiet)
|
111 |
+
loudness.append(loudness_stats["I"])
|
112 |
+
|
113 |
+
self._loudness = torch.from_numpy(np.array(loudness)).float()
|
114 |
+
return self.loudness()
|
115 |
+
|
116 |
+
def ffmpeg_resample(self, sample_rate: int, quiet: bool = True):
|
117 |
+
"""Resamples AudioSignal using FFMPEG. More memory-efficient
|
118 |
+
than using julius.resample for long audio files.
|
119 |
+
|
120 |
+
Parameters
|
121 |
+
----------
|
122 |
+
sample_rate : int
|
123 |
+
Sample rate to resample to.
|
124 |
+
quiet : bool, optional
|
125 |
+
Whether to show FFMPEG output during computation,
|
126 |
+
by default True
|
127 |
+
|
128 |
+
Returns
|
129 |
+
-------
|
130 |
+
AudioSignal
|
131 |
+
Resampled AudioSignal.
|
132 |
+
"""
|
133 |
+
from audiotools import AudioSignal
|
134 |
+
|
135 |
+
if sample_rate == self.sample_rate:
|
136 |
+
return self
|
137 |
+
|
138 |
+
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
139 |
+
self.write(f.name)
|
140 |
+
f_out = f.name.replace("wav", "rs.wav")
|
141 |
+
command = f"ffmpeg -i {f.name} -ar {sample_rate} {f_out}"
|
142 |
+
if quiet:
|
143 |
+
command += " -hide_banner -loglevel error"
|
144 |
+
subprocess.check_call(shlex.split(command))
|
145 |
+
resampled = AudioSignal(f_out)
|
146 |
+
Path.unlink(Path(f_out))
|
147 |
+
return resampled
|
148 |
+
|
149 |
+
@classmethod
|
150 |
+
def load_from_file_with_ffmpeg(cls, audio_path: str, quiet: bool = True, **kwargs):
|
151 |
+
"""Loads AudioSignal object after decoding it to a wav file using FFMPEG.
|
152 |
+
Useful for loading audio that isn't covered by librosa's loading mechanism. Also
|
153 |
+
useful for loading mp3 files, without any offset.
|
154 |
+
|
155 |
+
Parameters
|
156 |
+
----------
|
157 |
+
audio_path : str
|
158 |
+
Path to load AudioSignal from.
|
159 |
+
quiet : bool, optional
|
160 |
+
Whether to show FFMPEG output during computation,
|
161 |
+
by default True
|
162 |
+
|
163 |
+
Returns
|
164 |
+
-------
|
165 |
+
AudioSignal
|
166 |
+
AudioSignal loaded from file with FFMPEG.
|
167 |
+
"""
|
168 |
+
audio_path = str(audio_path)
|
169 |
+
with tempfile.TemporaryDirectory() as d:
|
170 |
+
wav_file = str(Path(d) / "extracted.wav")
|
171 |
+
padded_wav = str(Path(d) / "padded.wav")
|
172 |
+
|
173 |
+
global_options = "-y"
|
174 |
+
if quiet:
|
175 |
+
global_options += " -loglevel error"
|
176 |
+
|
177 |
+
ff = ffmpy.FFmpeg(
|
178 |
+
inputs={audio_path: None},
|
179 |
+
outputs={wav_file: None},
|
180 |
+
global_options=global_options,
|
181 |
+
)
|
182 |
+
ff.run()
|
183 |
+
|
184 |
+
# We pad the file using the start time offset in case it's an audio
|
185 |
+
# stream starting at some offset in a video container.
|
186 |
+
pad, codec = ffprobe_offset_and_codec(audio_path)
|
187 |
+
|
188 |
+
# For mp3s, don't pad files with discrepancies less than 0.027s -
|
189 |
+
# it's likely due to codec latency. The amount of latency introduced
|
190 |
+
# by mp3 is 1152, which is 0.0261 44khz. So we set the threshold
|
191 |
+
# here slightly above that.
|
192 |
+
# Source: https://lame.sourceforge.io/tech-FAQ.txt.
|
193 |
+
if codec == "mp3" and pad < 0.027:
|
194 |
+
pad = 0.0
|
195 |
+
ff = ffmpy.FFmpeg(
|
196 |
+
inputs={wav_file: None},
|
197 |
+
outputs={padded_wav: f"-af 'adelay={pad*1000}:all=true'"},
|
198 |
+
global_options=global_options,
|
199 |
+
)
|
200 |
+
ff.run()
|
201 |
+
|
202 |
+
signal = cls(padded_wav, **kwargs)
|
203 |
+
|
204 |
+
return signal
|
audiotools/core/loudness.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
import julius
|
4 |
+
import numpy as np
|
5 |
+
import scipy
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchaudio
|
9 |
+
|
10 |
+
|
11 |
+
class Meter(torch.nn.Module):
|
12 |
+
"""Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.
|
13 |
+
|
14 |
+
Parameters
|
15 |
+
----------
|
16 |
+
rate : int
|
17 |
+
Sample rate of audio.
|
18 |
+
filter_class : str, optional
|
19 |
+
Class of weighting filter used.
|
20 |
+
K-weighting' (default), 'Fenton/Lee 1'
|
21 |
+
'Fenton/Lee 2', 'Dash et al.'
|
22 |
+
by default "K-weighting"
|
23 |
+
block_size : float, optional
|
24 |
+
Gating block size in seconds, by default 0.400
|
25 |
+
zeros : int, optional
|
26 |
+
Number of zeros to use in FIR approximation of
|
27 |
+
IIR filters, by default 512
|
28 |
+
use_fir : bool, optional
|
29 |
+
Whether to use FIR approximation or exact IIR formulation.
|
30 |
+
If computing on GPU, ``use_fir=True`` will be used, as its
|
31 |
+
much faster, by default False
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
rate: int,
|
37 |
+
filter_class: str = "K-weighting",
|
38 |
+
block_size: float = 0.400,
|
39 |
+
zeros: int = 512,
|
40 |
+
use_fir: bool = False,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
|
44 |
+
self.rate = rate
|
45 |
+
self.filter_class = filter_class
|
46 |
+
self.block_size = block_size
|
47 |
+
self.use_fir = use_fir
|
48 |
+
|
49 |
+
G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41]))
|
50 |
+
self.register_buffer("G", G)
|
51 |
+
|
52 |
+
# Compute impulse responses so that filtering is fast via
|
53 |
+
# a convolution at runtime, on GPU, unlike lfilter.
|
54 |
+
impulse = np.zeros((zeros,))
|
55 |
+
impulse[..., 0] = 1.0
|
56 |
+
|
57 |
+
firs = np.zeros((len(self._filters), 1, zeros))
|
58 |
+
passband_gain = torch.zeros(len(self._filters))
|
59 |
+
|
60 |
+
for i, (_, filter_stage) in enumerate(self._filters.items()):
|
61 |
+
firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, impulse)
|
62 |
+
passband_gain[i] = filter_stage.passband_gain
|
63 |
+
|
64 |
+
firs = torch.from_numpy(firs[..., ::-1].copy()).float()
|
65 |
+
|
66 |
+
self.register_buffer("firs", firs)
|
67 |
+
self.register_buffer("passband_gain", passband_gain)
|
68 |
+
|
69 |
+
def apply_filter_gpu(self, data: torch.Tensor):
|
70 |
+
"""Performs FIR approximation of loudness computation.
|
71 |
+
|
72 |
+
Parameters
|
73 |
+
----------
|
74 |
+
data : torch.Tensor
|
75 |
+
Audio data of shape (nb, nch, nt).
|
76 |
+
|
77 |
+
Returns
|
78 |
+
-------
|
79 |
+
torch.Tensor
|
80 |
+
Filtered audio data.
|
81 |
+
"""
|
82 |
+
# Data is of shape (nb, nch, nt)
|
83 |
+
# Reshape to (nb*nch, 1, nt)
|
84 |
+
nb, nt, nch = data.shape
|
85 |
+
data = data.permute(0, 2, 1)
|
86 |
+
data = data.reshape(nb * nch, 1, nt)
|
87 |
+
|
88 |
+
# Apply padding
|
89 |
+
pad_length = self.firs.shape[-1]
|
90 |
+
|
91 |
+
# Apply filtering in sequence
|
92 |
+
for i in range(self.firs.shape[0]):
|
93 |
+
data = F.pad(data, (pad_length, pad_length))
|
94 |
+
data = julius.fftconv.fft_conv1d(data, self.firs[i, None, ...])
|
95 |
+
data = self.passband_gain[i] * data
|
96 |
+
data = data[..., 1 : nt + 1]
|
97 |
+
|
98 |
+
data = data.permute(0, 2, 1)
|
99 |
+
data = data[:, :nt, :]
|
100 |
+
return data
|
101 |
+
|
102 |
+
def apply_filter_cpu(self, data: torch.Tensor):
|
103 |
+
"""Performs IIR formulation of loudness computation.
|
104 |
+
|
105 |
+
Parameters
|
106 |
+
----------
|
107 |
+
data : torch.Tensor
|
108 |
+
Audio data of shape (nb, nch, nt).
|
109 |
+
|
110 |
+
Returns
|
111 |
+
-------
|
112 |
+
torch.Tensor
|
113 |
+
Filtered audio data.
|
114 |
+
"""
|
115 |
+
for _, filter_stage in self._filters.items():
|
116 |
+
passband_gain = filter_stage.passband_gain
|
117 |
+
|
118 |
+
a_coeffs = torch.from_numpy(filter_stage.a).float().to(data.device)
|
119 |
+
b_coeffs = torch.from_numpy(filter_stage.b).float().to(data.device)
|
120 |
+
|
121 |
+
_data = data.permute(0, 2, 1)
|
122 |
+
filtered = torchaudio.functional.lfilter(
|
123 |
+
_data, a_coeffs, b_coeffs, clamp=False
|
124 |
+
)
|
125 |
+
data = passband_gain * filtered.permute(0, 2, 1)
|
126 |
+
return data
|
127 |
+
|
128 |
+
def apply_filter(self, data: torch.Tensor):
|
129 |
+
"""Applies filter on either CPU or GPU, depending
|
130 |
+
on if the audio is on GPU or is on CPU, or if
|
131 |
+
``self.use_fir`` is True.
|
132 |
+
|
133 |
+
Parameters
|
134 |
+
----------
|
135 |
+
data : torch.Tensor
|
136 |
+
Audio data of shape (nb, nch, nt).
|
137 |
+
|
138 |
+
Returns
|
139 |
+
-------
|
140 |
+
torch.Tensor
|
141 |
+
Filtered audio data.
|
142 |
+
"""
|
143 |
+
if data.is_cuda or self.use_fir:
|
144 |
+
data = self.apply_filter_gpu(data)
|
145 |
+
else:
|
146 |
+
data = self.apply_filter_cpu(data)
|
147 |
+
return data
|
148 |
+
|
149 |
+
def forward(self, data: torch.Tensor):
|
150 |
+
"""Computes integrated loudness of data.
|
151 |
+
|
152 |
+
Parameters
|
153 |
+
----------
|
154 |
+
data : torch.Tensor
|
155 |
+
Audio data of shape (nb, nch, nt).
|
156 |
+
|
157 |
+
Returns
|
158 |
+
-------
|
159 |
+
torch.Tensor
|
160 |
+
Filtered audio data.
|
161 |
+
"""
|
162 |
+
return self.integrated_loudness(data)
|
163 |
+
|
164 |
+
def _unfold(self, input_data):
|
165 |
+
T_g = self.block_size
|
166 |
+
overlap = 0.75 # overlap of 75% of the block duration
|
167 |
+
step = 1.0 - overlap # step size by percentage
|
168 |
+
|
169 |
+
kernel_size = int(T_g * self.rate)
|
170 |
+
stride = int(T_g * self.rate * step)
|
171 |
+
unfolded = julius.core.unfold(input_data.permute(0, 2, 1), kernel_size, stride)
|
172 |
+
unfolded = unfolded.transpose(-1, -2)
|
173 |
+
|
174 |
+
return unfolded
|
175 |
+
|
176 |
+
def integrated_loudness(self, data: torch.Tensor):
|
177 |
+
"""Computes integrated loudness of data.
|
178 |
+
|
179 |
+
Parameters
|
180 |
+
----------
|
181 |
+
data : torch.Tensor
|
182 |
+
Audio data of shape (nb, nch, nt).
|
183 |
+
|
184 |
+
Returns
|
185 |
+
-------
|
186 |
+
torch.Tensor
|
187 |
+
Filtered audio data.
|
188 |
+
"""
|
189 |
+
if not torch.is_tensor(data):
|
190 |
+
data = torch.from_numpy(data).float()
|
191 |
+
else:
|
192 |
+
data = data.float()
|
193 |
+
|
194 |
+
input_data = copy.copy(data)
|
195 |
+
# Data always has a batch and channel dimension.
|
196 |
+
# Is of shape (nb, nt, nch)
|
197 |
+
if input_data.ndim < 2:
|
198 |
+
input_data = input_data.unsqueeze(-1)
|
199 |
+
if input_data.ndim < 3:
|
200 |
+
input_data = input_data.unsqueeze(0)
|
201 |
+
|
202 |
+
nb, nt, nch = input_data.shape
|
203 |
+
|
204 |
+
# Apply frequency weighting filters - account
|
205 |
+
# for the acoustic respose of the head and auditory system
|
206 |
+
input_data = self.apply_filter(input_data)
|
207 |
+
|
208 |
+
G = self.G # channel gains
|
209 |
+
T_g = self.block_size # 400 ms gating block standard
|
210 |
+
Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold
|
211 |
+
|
212 |
+
unfolded = self._unfold(input_data)
|
213 |
+
|
214 |
+
z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
|
215 |
+
l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True))
|
216 |
+
l = l.expand_as(z)
|
217 |
+
|
218 |
+
# find gating block indices above absolute threshold
|
219 |
+
z_avg_gated = z
|
220 |
+
z_avg_gated[l <= Gamma_a] = 0
|
221 |
+
masked = l > Gamma_a
|
222 |
+
z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
|
223 |
+
|
224 |
+
# calculate the relative threshold value (see eq. 6)
|
225 |
+
Gamma_r = (
|
226 |
+
-0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
|
227 |
+
)
|
228 |
+
Gamma_r = Gamma_r[:, None, None]
|
229 |
+
Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1])
|
230 |
+
|
231 |
+
# find gating block indices above relative and absolute thresholds (end of eq. 7)
|
232 |
+
z_avg_gated = z
|
233 |
+
z_avg_gated[l <= Gamma_a] = 0
|
234 |
+
z_avg_gated[l <= Gamma_r] = 0
|
235 |
+
masked = (l > Gamma_a) * (l > Gamma_r)
|
236 |
+
z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
|
237 |
+
|
238 |
+
# # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version)
|
239 |
+
# z_avg_gated = torch.nan_to_num(z_avg_gated)
|
240 |
+
z_avg_gated = torch.where(
|
241 |
+
z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated
|
242 |
+
)
|
243 |
+
z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max)
|
244 |
+
z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min)
|
245 |
+
|
246 |
+
LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1))
|
247 |
+
return LUFS.float()
|
248 |
+
|
249 |
+
@property
|
250 |
+
def filter_class(self):
|
251 |
+
return self._filter_class
|
252 |
+
|
253 |
+
@filter_class.setter
|
254 |
+
def filter_class(self, value):
|
255 |
+
from pyloudnorm import Meter
|
256 |
+
|
257 |
+
meter = Meter(self.rate)
|
258 |
+
meter.filter_class = value
|
259 |
+
self._filter_class = value
|
260 |
+
self._filters = meter._filters
|
261 |
+
|
262 |
+
|
263 |
+
class LoudnessMixin:
|
264 |
+
_loudness = None
|
265 |
+
MIN_LOUDNESS = -70
|
266 |
+
"""Minimum loudness possible."""
|
267 |
+
|
268 |
+
def loudness(
|
269 |
+
self, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
|
270 |
+
):
|
271 |
+
"""Calculates loudness using an implementation of ITU-R BS.1770-4.
|
272 |
+
Allows control over gating block size and frequency weighting filters for
|
273 |
+
additional control. Measure the integrated gated loudness of a signal.
|
274 |
+
|
275 |
+
API is derived from PyLoudnorm, but this implementation is ported to PyTorch
|
276 |
+
and is tensorized across batches. When on GPU, an FIR approximation of the IIR
|
277 |
+
filters is used to compute loudness for speed.
|
278 |
+
|
279 |
+
Uses the weighting filters and block size defined by the meter
|
280 |
+
the integrated loudness is measured based upon the gating algorithm
|
281 |
+
defined in the ITU-R BS.1770-4 specification.
|
282 |
+
|
283 |
+
Parameters
|
284 |
+
----------
|
285 |
+
filter_class : str, optional
|
286 |
+
Class of weighting filter used.
|
287 |
+
K-weighting' (default), 'Fenton/Lee 1'
|
288 |
+
'Fenton/Lee 2', 'Dash et al.'
|
289 |
+
by default "K-weighting"
|
290 |
+
block_size : float, optional
|
291 |
+
Gating block size in seconds, by default 0.400
|
292 |
+
kwargs : dict, optional
|
293 |
+
Keyword arguments to :py:func:`audiotools.core.loudness.Meter`.
|
294 |
+
|
295 |
+
Returns
|
296 |
+
-------
|
297 |
+
torch.Tensor
|
298 |
+
Loudness of audio data.
|
299 |
+
"""
|
300 |
+
if self._loudness is not None:
|
301 |
+
return self._loudness.to(self.device)
|
302 |
+
original_length = self.signal_length
|
303 |
+
if self.signal_duration < 0.5:
|
304 |
+
pad_len = int((0.5 - self.signal_duration) * self.sample_rate)
|
305 |
+
self.zero_pad(0, pad_len)
|
306 |
+
|
307 |
+
# create BS.1770 meter
|
308 |
+
meter = Meter(
|
309 |
+
self.sample_rate, filter_class=filter_class, block_size=block_size, **kwargs
|
310 |
+
)
|
311 |
+
meter = meter.to(self.device)
|
312 |
+
# measure loudness
|
313 |
+
loudness = meter.integrated_loudness(self.audio_data.permute(0, 2, 1))
|
314 |
+
self.truncate_samples(original_length)
|
315 |
+
min_loudness = (
|
316 |
+
torch.ones_like(loudness, device=loudness.device) * self.MIN_LOUDNESS
|
317 |
+
)
|
318 |
+
self._loudness = torch.maximum(loudness, min_loudness)
|
319 |
+
|
320 |
+
return self._loudness.to(self.device)
|
audiotools/core/playback.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
These are utilities that allow one to embed an AudioSignal
|
3 |
+
as a playable object in a Jupyter notebook, or to play audio from
|
4 |
+
the terminal, etc.
|
5 |
+
""" # fmt: skip
|
6 |
+
import base64
|
7 |
+
import io
|
8 |
+
import random
|
9 |
+
import string
|
10 |
+
import subprocess
|
11 |
+
from tempfile import NamedTemporaryFile
|
12 |
+
|
13 |
+
import importlib_resources as pkg_resources
|
14 |
+
|
15 |
+
from . import templates
|
16 |
+
from .util import _close_temp_files
|
17 |
+
from .util import format_figure
|
18 |
+
|
19 |
+
headers = pkg_resources.files(templates).joinpath("headers.html").read_text()
|
20 |
+
widget = pkg_resources.files(templates).joinpath("widget.html").read_text()
|
21 |
+
|
22 |
+
DEFAULT_EXTENSION = ".wav"
|
23 |
+
|
24 |
+
|
25 |
+
def _check_imports(): # pragma: no cover
|
26 |
+
try:
|
27 |
+
import ffmpy
|
28 |
+
except:
|
29 |
+
ffmpy = False
|
30 |
+
|
31 |
+
try:
|
32 |
+
import IPython
|
33 |
+
except:
|
34 |
+
raise ImportError("IPython must be installed in order to use this function!")
|
35 |
+
return ffmpy, IPython
|
36 |
+
|
37 |
+
|
38 |
+
class PlayMixin:
|
39 |
+
def embed(self, ext: str = None, display: bool = True, return_html: bool = False):
|
40 |
+
"""Embeds audio as a playable audio embed in a notebook, or HTML
|
41 |
+
document, etc.
|
42 |
+
|
43 |
+
Parameters
|
44 |
+
----------
|
45 |
+
ext : str, optional
|
46 |
+
Extension to use when saving the audio, by default ".wav"
|
47 |
+
display : bool, optional
|
48 |
+
This controls whether or not to display the audio when called. This
|
49 |
+
is used when the embed is the last line in a Jupyter cell, to prevent
|
50 |
+
the audio from being embedded twice, by default True
|
51 |
+
return_html : bool, optional
|
52 |
+
Whether to return the data wrapped in an HTML audio element, by default False
|
53 |
+
|
54 |
+
Returns
|
55 |
+
-------
|
56 |
+
str
|
57 |
+
Either the element for display, or the HTML string of it.
|
58 |
+
"""
|
59 |
+
if ext is None:
|
60 |
+
ext = DEFAULT_EXTENSION
|
61 |
+
ext = f".{ext}" if not ext.startswith(".") else ext
|
62 |
+
ffmpy, IPython = _check_imports()
|
63 |
+
sr = self.sample_rate
|
64 |
+
tmpfiles = []
|
65 |
+
|
66 |
+
with _close_temp_files(tmpfiles):
|
67 |
+
tmp_wav = NamedTemporaryFile(mode="w+", suffix=".wav", delete=False)
|
68 |
+
tmpfiles.append(tmp_wav)
|
69 |
+
self.write(tmp_wav.name)
|
70 |
+
if ext != ".wav" and ffmpy:
|
71 |
+
tmp_converted = NamedTemporaryFile(mode="w+", suffix=ext, delete=False)
|
72 |
+
tmpfiles.append(tmp_wav)
|
73 |
+
ff = ffmpy.FFmpeg(
|
74 |
+
inputs={tmp_wav.name: None},
|
75 |
+
outputs={
|
76 |
+
tmp_converted.name: "-write_xing 0 -codec:a libmp3lame -b:a 128k -y -hide_banner -loglevel error"
|
77 |
+
},
|
78 |
+
)
|
79 |
+
ff.run()
|
80 |
+
else:
|
81 |
+
tmp_converted = tmp_wav
|
82 |
+
|
83 |
+
audio_element = IPython.display.Audio(data=tmp_converted.name, rate=sr)
|
84 |
+
if display:
|
85 |
+
IPython.display.display(audio_element)
|
86 |
+
|
87 |
+
if return_html:
|
88 |
+
audio_element = (
|
89 |
+
f"<audio "
|
90 |
+
f" controls "
|
91 |
+
f" src='{audio_element.src_attr()}'> "
|
92 |
+
f"</audio> "
|
93 |
+
)
|
94 |
+
return audio_element
|
95 |
+
|
96 |
+
def widget(
|
97 |
+
self,
|
98 |
+
title: str = None,
|
99 |
+
ext: str = ".wav",
|
100 |
+
add_headers: bool = True,
|
101 |
+
player_width: str = "100%",
|
102 |
+
margin: str = "10px",
|
103 |
+
plot_fn: str = "specshow",
|
104 |
+
return_html: bool = False,
|
105 |
+
**kwargs,
|
106 |
+
):
|
107 |
+
"""Creates a playable widget with spectrogram. Inspired (heavily) by
|
108 |
+
https://sjvasquez.github.io/blog/melnet/.
|
109 |
+
|
110 |
+
Parameters
|
111 |
+
----------
|
112 |
+
title : str, optional
|
113 |
+
Title of plot, placed in upper right of top-most axis.
|
114 |
+
ext : str, optional
|
115 |
+
Extension for embedding, by default ".mp3"
|
116 |
+
add_headers : bool, optional
|
117 |
+
Whether or not to add headers (use for first embed, False for later embeds), by default True
|
118 |
+
player_width : str, optional
|
119 |
+
Width of the player, as a string in a CSS rule, by default "100%"
|
120 |
+
margin : str, optional
|
121 |
+
Margin on all sides of player, by default "10px"
|
122 |
+
plot_fn : function, optional
|
123 |
+
Plotting function to use (by default self.specshow).
|
124 |
+
return_html : bool, optional
|
125 |
+
Whether to return the data wrapped in an HTML audio element, by default False
|
126 |
+
kwargs : dict, optional
|
127 |
+
Keyword arguments to plot_fn (by default self.specshow).
|
128 |
+
|
129 |
+
Returns
|
130 |
+
-------
|
131 |
+
HTML
|
132 |
+
HTML object.
|
133 |
+
"""
|
134 |
+
import matplotlib.pyplot as plt
|
135 |
+
|
136 |
+
def _save_fig_to_tag():
|
137 |
+
buffer = io.BytesIO()
|
138 |
+
|
139 |
+
plt.savefig(buffer, bbox_inches="tight", pad_inches=0)
|
140 |
+
plt.close()
|
141 |
+
|
142 |
+
buffer.seek(0)
|
143 |
+
data_uri = base64.b64encode(buffer.read()).decode("ascii")
|
144 |
+
tag = "data:image/png;base64,{0}".format(data_uri)
|
145 |
+
|
146 |
+
return tag
|
147 |
+
|
148 |
+
_, IPython = _check_imports()
|
149 |
+
|
150 |
+
header_html = ""
|
151 |
+
|
152 |
+
if add_headers:
|
153 |
+
header_html = headers.replace("PLAYER_WIDTH", str(player_width))
|
154 |
+
header_html = header_html.replace("MARGIN", str(margin))
|
155 |
+
IPython.display.display(IPython.display.HTML(header_html))
|
156 |
+
|
157 |
+
widget_html = widget
|
158 |
+
if isinstance(plot_fn, str):
|
159 |
+
plot_fn = getattr(self, plot_fn)
|
160 |
+
kwargs["title"] = title
|
161 |
+
plot_fn(**kwargs)
|
162 |
+
|
163 |
+
fig = plt.gcf()
|
164 |
+
pixels = fig.get_size_inches() * fig.dpi
|
165 |
+
|
166 |
+
tag = _save_fig_to_tag()
|
167 |
+
|
168 |
+
# Make the source image for the levels
|
169 |
+
self.specshow()
|
170 |
+
format_figure((12, 1.5))
|
171 |
+
levels_tag = _save_fig_to_tag()
|
172 |
+
|
173 |
+
player_id = "".join(random.choice(string.ascii_uppercase) for _ in range(10))
|
174 |
+
|
175 |
+
audio_elem = self.embed(ext=ext, display=False)
|
176 |
+
widget_html = widget_html.replace("AUDIO_SRC", audio_elem.src_attr())
|
177 |
+
widget_html = widget_html.replace("IMAGE_SRC", tag)
|
178 |
+
widget_html = widget_html.replace("LEVELS_SRC", levels_tag)
|
179 |
+
widget_html = widget_html.replace("PLAYER_ID", player_id)
|
180 |
+
|
181 |
+
# Calculate width/height of figure based on figure size.
|
182 |
+
widget_html = widget_html.replace("PADDING_AMOUNT", f"{int(pixels[1])}px")
|
183 |
+
widget_html = widget_html.replace("MAX_WIDTH", f"{int(pixels[0])}px")
|
184 |
+
|
185 |
+
IPython.display.display(IPython.display.HTML(widget_html))
|
186 |
+
|
187 |
+
if return_html:
|
188 |
+
html = header_html if add_headers else ""
|
189 |
+
html += widget_html
|
190 |
+
return html
|
191 |
+
|
192 |
+
def play(self):
|
193 |
+
"""
|
194 |
+
Plays an audio signal if ffplay from the ffmpeg suite of tools is installed.
|
195 |
+
Otherwise, will fail. The audio signal is written to a temporary file
|
196 |
+
and then played with ffplay.
|
197 |
+
"""
|
198 |
+
tmpfiles = []
|
199 |
+
with _close_temp_files(tmpfiles):
|
200 |
+
tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False)
|
201 |
+
tmpfiles.append(tmp_wav)
|
202 |
+
self.write(tmp_wav.name)
|
203 |
+
print(self)
|
204 |
+
subprocess.call(
|
205 |
+
[
|
206 |
+
"ffplay",
|
207 |
+
"-nodisp",
|
208 |
+
"-autoexit",
|
209 |
+
"-hide_banner",
|
210 |
+
"-loglevel",
|
211 |
+
"error",
|
212 |
+
tmp_wav.name,
|
213 |
+
]
|
214 |
+
)
|
215 |
+
return self
|
216 |
+
|
217 |
+
|
218 |
+
if __name__ == "__main__": # pragma: no cover
|
219 |
+
from audiotools import AudioSignal
|
220 |
+
|
221 |
+
signal = AudioSignal(
|
222 |
+
"tests/audio/spk/f10_script4_produced.mp3", offset=5, duration=5
|
223 |
+
)
|
224 |
+
|
225 |
+
wave_html = signal.widget(
|
226 |
+
"Waveform",
|
227 |
+
plot_fn="waveplot",
|
228 |
+
return_html=True,
|
229 |
+
)
|
230 |
+
|
231 |
+
spec_html = signal.widget("Spectrogram", return_html=True, add_headers=False)
|
232 |
+
|
233 |
+
combined_html = signal.widget(
|
234 |
+
"Waveform + spectrogram",
|
235 |
+
plot_fn="wavespec",
|
236 |
+
return_html=True,
|
237 |
+
add_headers=False,
|
238 |
+
)
|
239 |
+
|
240 |
+
signal.low_pass(8000)
|
241 |
+
lowpass_html = signal.widget(
|
242 |
+
"Lowpassed audio",
|
243 |
+
plot_fn="wavespec",
|
244 |
+
return_html=True,
|
245 |
+
add_headers=False,
|
246 |
+
)
|
247 |
+
|
248 |
+
with open("/tmp/index.html", "w") as f:
|
249 |
+
f.write(wave_html)
|
250 |
+
f.write(spec_html)
|
251 |
+
f.write(combined_html)
|
252 |
+
f.write(lowpass_html)
|
audiotools/core/templates/__init__.py
ADDED
File without changes
|
audiotools/core/templates/headers.html
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<style>
|
2 |
+
.player {
|
3 |
+
width: 100%;
|
4 |
+
/*border: 1px solid black;*/
|
5 |
+
margin: 10px;
|
6 |
+
}
|
7 |
+
|
8 |
+
.underlay img {
|
9 |
+
width: 100%;
|
10 |
+
height: 100%;
|
11 |
+
}
|
12 |
+
|
13 |
+
.spectrogram {
|
14 |
+
height: 0;
|
15 |
+
width: 100%;
|
16 |
+
position: relative;
|
17 |
+
}
|
18 |
+
|
19 |
+
.audio-controls {
|
20 |
+
width: 100%;
|
21 |
+
height: 54px;
|
22 |
+
display: flex;
|
23 |
+
/*border-top: 1px solid black;*/
|
24 |
+
/*background-color: rgb(241, 243, 244);*/
|
25 |
+
background-color: rgb(248, 248, 248);
|
26 |
+
background-color: rgb(253, 253, 254);
|
27 |
+
border: 1px solid rgb(205, 208, 211);
|
28 |
+
margin-top: 20px;
|
29 |
+
/*border: 1px solid black;*/
|
30 |
+
border-radius: 30px;
|
31 |
+
|
32 |
+
}
|
33 |
+
|
34 |
+
.play-img {
|
35 |
+
margin: auto;
|
36 |
+
height: 45%;
|
37 |
+
width: 45%;
|
38 |
+
display: block;
|
39 |
+
}
|
40 |
+
|
41 |
+
.download-img {
|
42 |
+
margin: auto;
|
43 |
+
height: 100%;
|
44 |
+
width: 100%;
|
45 |
+
display: block;
|
46 |
+
}
|
47 |
+
|
48 |
+
.pause-img {
|
49 |
+
margin: auto;
|
50 |
+
height: 45%;
|
51 |
+
width: 45%;
|
52 |
+
display: none
|
53 |
+
}
|
54 |
+
|
55 |
+
.playpause {
|
56 |
+
margin:11px 11px 11px 11px;
|
57 |
+
width: 32px;
|
58 |
+
min-width: 32px;
|
59 |
+
height: 32px;
|
60 |
+
/*background-color: rgb(241, 243, 244);*/
|
61 |
+
background-color: rgba(0, 0, 0, 0.0);
|
62 |
+
/*border-right: 1px solid black;*/
|
63 |
+
/*border: 1px solid red;*/
|
64 |
+
border-radius: 16px;
|
65 |
+
color: black;
|
66 |
+
transition: 0.25s;
|
67 |
+
box-sizing: border-box !important;
|
68 |
+
}
|
69 |
+
|
70 |
+
.download {
|
71 |
+
margin:11px 11px 11px 11px;
|
72 |
+
width: 32px;
|
73 |
+
min-width: 32px;
|
74 |
+
height: 32px;
|
75 |
+
/*background-color: rgb(241, 243, 244);*/
|
76 |
+
background-color: rgba(0, 0, 0, 0.0);
|
77 |
+
/*border-right: 1px solid black;*/
|
78 |
+
/*border: 1px solid red;*/
|
79 |
+
border-radius: 16px;
|
80 |
+
color: black;
|
81 |
+
transition: 0.25s;
|
82 |
+
box-sizing: border-box !important;
|
83 |
+
}
|
84 |
+
|
85 |
+
/*.playpause:disabled {
|
86 |
+
background-color: red;
|
87 |
+
}*/
|
88 |
+
|
89 |
+
.playpause:hover {
|
90 |
+
background-color: rgba(10, 20, 30, 0.03);
|
91 |
+
}
|
92 |
+
|
93 |
+
.playpause:focus {
|
94 |
+
outline:none;
|
95 |
+
}
|
96 |
+
|
97 |
+
.response {
|
98 |
+
padding:0px 20px 0px 0px;
|
99 |
+
width: calc(100% - 132px);
|
100 |
+
height: 100%;
|
101 |
+
|
102 |
+
/*border: 1px solid red;*/
|
103 |
+
/*border-bottom: 1px solid rgb(89, 89, 89);*/
|
104 |
+
}
|
105 |
+
|
106 |
+
.response-canvas {
|
107 |
+
height: 100%;
|
108 |
+
width: 100%;
|
109 |
+
}
|
110 |
+
|
111 |
+
|
112 |
+
.underlay {
|
113 |
+
height: 100%;
|
114 |
+
width: 100%;
|
115 |
+
position: absolute;
|
116 |
+
top: 0;
|
117 |
+
left: 0;
|
118 |
+
}
|
119 |
+
|
120 |
+
.overlay{
|
121 |
+
width: 0%;
|
122 |
+
height:100%;
|
123 |
+
top: 0;
|
124 |
+
right: 0px;
|
125 |
+
|
126 |
+
background:rgba(255, 255, 255, 0.15);
|
127 |
+
overflow:hidden;
|
128 |
+
position: absolute;
|
129 |
+
z-index: 10;
|
130 |
+
border-left: solid 1px rgba(0, 0, 0, 0.664);
|
131 |
+
|
132 |
+
position: absolute;
|
133 |
+
pointer-events: none;
|
134 |
+
}
|
135 |
+
</style>
|
136 |
+
|
137 |
+
<script>
|
138 |
+
!function(t){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=t();else if("function"==typeof define&&define.amd)define([],t);else{("undefined"!=typeof window?window:"undefined"!=typeof global?global:"undefined"!=typeof self?self:this).pako=t()}}(function(){return function(){return function t(e,a,i){function n(s,o){if(!a[s]){if(!e[s]){var l="function"==typeof require&&require;if(!o&&l)return l(s,!0);if(r)return r(s,!0);var h=new Error("Cannot find module '"+s+"'");throw h.code="MODULE_NOT_FOUND",h}var d=a[s]={exports:{}};e[s][0].call(d.exports,function(t){return n(e[s][1][t]||t)},d,d.exports,t,e,a,i)}return a[s].exports}for(var r="function"==typeof require&&require,s=0;s<i.length;s++)n(i[s]);return n}}()({1:[function(t,e,a){"use strict";var i=t("./zlib/deflate"),n=t("./utils/common"),r=t("./utils/strings"),s=t("./zlib/messages"),o=t("./zlib/zstream"),l=Object.prototype.toString,h=0,d=-1,f=0,_=8;function u(t){if(!(this instanceof u))return new u(t);this.options=n.assign({level:d,method:_,chunkSize:16384,windowBits:15,memLevel:8,strategy:f,to:""},t||{});var e=this.options;e.raw&&e.windowBits>0?e.windowBits=-e.windowBits:e.gzip&&e.windowBits>0&&e.windowBits<16&&(e.windowBits+=16),this.err=0,this.msg="",this.ended=!1,this.chunks=[],this.strm=new o,this.strm.avail_out=0;var a=i.deflateInit2(this.strm,e.level,e.method,e.windowBits,e.memLevel,e.strategy);if(a!==h)throw new Error(s[a]);if(e.header&&i.deflateSetHeader(this.strm,e.header),e.dictionary){var c;if(c="string"==typeof e.dictionary?r.string2buf(e.dictionary):"[object ArrayBuffer]"===l.call(e.dictionary)?new Uint8Array(e.dictionary):e.dictionary,(a=i.deflateSetDictionary(this.strm,c))!==h)throw new Error(s[a]);this._dict_set=!0}}function c(t,e){var a=new u(e);if(a.push(t,!0),a.err)throw a.msg||s[a.err];return a.result}u.prototype.push=function(t,e){var a,s,o=this.strm,d=this.options.chunkSize;if(this.ended)return!1;s=e===~~e?e:!0===e?4:0,"string"==typeof t?o.input=r.string2buf(t):"[object ArrayBuffer]"===l.call(t)?o.input=new Uint8Array(t):o.input=t,o.next_in=0,o.avail_in=o.input.length;do{if(0===o.avail_out&&(o.output=new n.Buf8(d),o.next_out=0,o.avail_out=d),1!==(a=i.deflate(o,s))&&a!==h)return this.onEnd(a),this.ended=!0,!1;0!==o.avail_out&&(0!==o.avail_in||4!==s&&2!==s)||("string"===this.options.to?this.onData(r.buf2binstring(n.shrinkBuf(o.output,o.next_out))):this.onData(n.shrinkBuf(o.output,o.next_out)))}while((o.avail_in>0||0===o.avail_out)&&1!==a);return 4===s?(a=i.deflateEnd(this.strm),this.onEnd(a),this.ended=!0,a===h):2!==s||(this.onEnd(h),o.avail_out=0,!0)},u.prototype.onData=function(t){this.chunks.push(t)},u.prototype.onEnd=function(t){t===h&&("string"===this.options.to?this.result=this.chunks.join(""):this.result=n.flattenChunks(this.chunks)),this.chunks=[],this.err=t,this.msg=this.strm.msg},a.Deflate=u,a.deflate=c,a.deflateRaw=function(t,e){return(e=e||{}).raw=!0,c(t,e)},a.gzip=function(t,e){return(e=e||{}).gzip=!0,c(t,e)}},{"./utils/common":3,"./utils/strings":4,"./zlib/deflate":8,"./zlib/messages":13,"./zlib/zstream":15}],2:[function(t,e,a){"use strict";var i=t("./zlib/inflate"),n=t("./utils/common"),r=t("./utils/strings"),s=t("./zlib/constants"),o=t("./zlib/messages"),l=t("./zlib/zstream"),h=t("./zlib/gzheader"),d=Object.prototype.toString;function f(t){if(!(this instanceof f))return new f(t);this.options=n.assign({chunkSize:16384,windowBits:0,to:""},t||{});var e=this.options;e.raw&&e.windowBits>=0&&e.windowBits<16&&(e.windowBits=-e.windowBits,0===e.windowBits&&(e.windowBits=-15)),!(e.windowBits>=0&&e.windowBits<16)||t&&t.windowBits||(e.windowBits+=32),e.windowBits>15&&e.windowBits<48&&0==(15&e.windowBits)&&(e.windowBits|=15),this.err=0,this.msg="",this.ended=!1,this.chunks=[],this.strm=new l,this.strm.avail_out=0;var a=i.inflateInit2(this.strm,e.windowBits);if(a!==s.Z_OK)throw new Error(o[a]);if(this.header=new h,i.inflateGetHeader(this.strm,this.header),e.dictionary&&("string"==typeof e.dictionary?e.dictionary=r.string2buf(e.dictionary):"[object ArrayBuffer]"===d.call(e.dictionary)&&(e.dictionary=new Uint8Array(e.dictionary)),e.raw&&(a=i.inflateSetDictionary(this.strm,e.dictionary))!==s.Z_OK))throw new Error(o[a])}function _(t,e){var a=new f(e);if(a.push(t,!0),a.err)throw a.msg||o[a.err];return a.result}f.prototype.push=function(t,e){var a,o,l,h,f,_=this.strm,u=this.options.chunkSize,c=this.options.dictionary,b=!1;if(this.ended)return!1;o=e===~~e?e:!0===e?s.Z_FINISH:s.Z_NO_FLUSH,"string"==typeof t?_.input=r.binstring2buf(t):"[object ArrayBuffer]"===d.call(t)?_.input=new Uint8Array(t):_.input=t,_.next_in=0,_.avail_in=_.input.length;do{if(0===_.avail_out&&(_.output=new n.Buf8(u),_.next_out=0,_.avail_out=u),(a=i.inflate(_,s.Z_NO_FLUSH))===s.Z_NEED_DICT&&c&&(a=i.inflateSetDictionary(this.strm,c)),a===s.Z_BUF_ERROR&&!0===b&&(a=s.Z_OK,b=!1),a!==s.Z_STREAM_END&&a!==s.Z_OK)return this.onEnd(a),this.ended=!0,!1;_.next_out&&(0!==_.avail_out&&a!==s.Z_STREAM_END&&(0!==_.avail_in||o!==s.Z_FINISH&&o!==s.Z_SYNC_FLUSH)||("string"===this.options.to?(l=r.utf8border(_.output,_.next_out),h=_.next_out-l,f=r.buf2string(_.output,l),_.next_out=h,_.avail_out=u-h,h&&n.arraySet(_.output,_.output,l,h,0),this.onData(f)):this.onData(n.shrinkBuf(_.output,_.next_out)))),0===_.avail_in&&0===_.avail_out&&(b=!0)}while((_.avail_in>0||0===_.avail_out)&&a!==s.Z_STREAM_END);return a===s.Z_STREAM_END&&(o=s.Z_FINISH),o===s.Z_FINISH?(a=i.inflateEnd(this.strm),this.onEnd(a),this.ended=!0,a===s.Z_OK):o!==s.Z_SYNC_FLUSH||(this.onEnd(s.Z_OK),_.avail_out=0,!0)},f.prototype.onData=function(t){this.chunks.push(t)},f.prototype.onEnd=function(t){t===s.Z_OK&&("string"===this.options.to?this.result=this.chunks.join(""):this.result=n.flattenChunks(this.chunks)),this.chunks=[],this.err=t,this.msg=this.strm.msg},a.Inflate=f,a.inflate=_,a.inflateRaw=function(t,e){return(e=e||{}).raw=!0,_(t,e)},a.ungzip=_},{"./utils/common":3,"./utils/strings":4,"./zlib/constants":6,"./zlib/gzheader":9,"./zlib/inflate":11,"./zlib/messages":13,"./zlib/zstream":15}],3:[function(t,e,a){"use strict";var i="undefined"!=typeof Uint8Array&&"undefined"!=typeof Uint16Array&&"undefined"!=typeof Int32Array;function n(t,e){return Object.prototype.hasOwnProperty.call(t,e)}a.assign=function(t){for(var e=Array.prototype.slice.call(arguments,1);e.length;){var a=e.shift();if(a){if("object"!=typeof a)throw new TypeError(a+"must be non-object");for(var i in a)n(a,i)&&(t[i]=a[i])}}return t},a.shrinkBuf=function(t,e){return t.length===e?t:t.subarray?t.subarray(0,e):(t.length=e,t)};var r={arraySet:function(t,e,a,i,n){if(e.subarray&&t.subarray)t.set(e.subarray(a,a+i),n);else for(var r=0;r<i;r++)t[n+r]=e[a+r]},flattenChunks:function(t){var e,a,i,n,r,s;for(i=0,e=0,a=t.length;e<a;e++)i+=t[e].length;for(s=new Uint8Array(i),n=0,e=0,a=t.length;e<a;e++)r=t[e],s.set(r,n),n+=r.length;return s}},s={arraySet:function(t,e,a,i,n){for(var r=0;r<i;r++)t[n+r]=e[a+r]},flattenChunks:function(t){return[].concat.apply([],t)}};a.setTyped=function(t){t?(a.Buf8=Uint8Array,a.Buf16=Uint16Array,a.Buf32=Int32Array,a.assign(a,r)):(a.Buf8=Array,a.Buf16=Array,a.Buf32=Array,a.assign(a,s))},a.setTyped(i)},{}],4:[function(t,e,a){"use strict";var i=t("./common"),n=!0,r=!0;try{String.fromCharCode.apply(null,[0])}catch(t){n=!1}try{String.fromCharCode.apply(null,new Uint8Array(1))}catch(t){r=!1}for(var s=new i.Buf8(256),o=0;o<256;o++)s[o]=o>=252?6:o>=248?5:o>=240?4:o>=224?3:o>=192?2:1;function l(t,e){if(e<65534&&(t.subarray&&r||!t.subarray&&n))return String.fromCharCode.apply(null,i.shrinkBuf(t,e));for(var a="",s=0;s<e;s++)a+=String.fromCharCode(t[s]);return a}s[254]=s[254]=1,a.string2buf=function(t){var e,a,n,r,s,o=t.length,l=0;for(r=0;r<o;r++)55296==(64512&(a=t.charCodeAt(r)))&&r+1<o&&56320==(64512&(n=t.charCodeAt(r+1)))&&(a=65536+(a-55296<<10)+(n-56320),r++),l+=a<128?1:a<2048?2:a<65536?3:4;for(e=new i.Buf8(l),s=0,r=0;s<l;r++)55296==(64512&(a=t.charCodeAt(r)))&&r+1<o&&56320==(64512&(n=t.charCodeAt(r+1)))&&(a=65536+(a-55296<<10)+(n-56320),r++),a<128?e[s++]=a:a<2048?(e[s++]=192|a>>>6,e[s++]=128|63&a):a<65536?(e[s++]=224|a>>>12,e[s++]=128|a>>>6&63,e[s++]=128|63&a):(e[s++]=240|a>>>18,e[s++]=128|a>>>12&63,e[s++]=128|a>>>6&63,e[s++]=128|63&a);return e},a.buf2binstring=function(t){return l(t,t.length)},a.binstring2buf=function(t){for(var e=new i.Buf8(t.length),a=0,n=e.length;a<n;a++)e[a]=t.charCodeAt(a);return e},a.buf2string=function(t,e){var a,i,n,r,o=e||t.length,h=new Array(2*o);for(i=0,a=0;a<o;)if((n=t[a++])<128)h[i++]=n;else if((r=s[n])>4)h[i++]=65533,a+=r-1;else{for(n&=2===r?31:3===r?15:7;r>1&&a<o;)n=n<<6|63&t[a++],r--;r>1?h[i++]=65533:n<65536?h[i++]=n:(n-=65536,h[i++]=55296|n>>10&1023,h[i++]=56320|1023&n)}return l(h,i)},a.utf8border=function(t,e){var a;for((e=e||t.length)>t.length&&(e=t.length),a=e-1;a>=0&&128==(192&t[a]);)a--;return a<0?e:0===a?e:a+s[t[a]]>e?a:e}},{"./common":3}],5:[function(t,e,a){"use strict";e.exports=function(t,e,a,i){for(var n=65535&t|0,r=t>>>16&65535|0,s=0;0!==a;){a-=s=a>2e3?2e3:a;do{r=r+(n=n+e[i++]|0)|0}while(--s);n%=65521,r%=65521}return n|r<<16|0}},{}],6:[function(t,e,a){"use strict";e.exports={Z_NO_FLUSH:0,Z_PARTIAL_FLUSH:1,Z_SYNC_FLUSH:2,Z_FULL_FLUSH:3,Z_FINISH:4,Z_BLOCK:5,Z_TREES:6,Z_OK:0,Z_STREAM_END:1,Z_NEED_DICT:2,Z_ERRNO:-1,Z_STREAM_ERROR:-2,Z_DATA_ERROR:-3,Z_BUF_ERROR:-5,Z_NO_COMPRESSION:0,Z_BEST_SPEED:1,Z_BEST_COMPRESSION:9,Z_DEFAULT_COMPRESSION:-1,Z_FILTERED:1,Z_HUFFMAN_ONLY:2,Z_RLE:3,Z_FIXED:4,Z_DEFAULT_STRATEGY:0,Z_BINARY:0,Z_TEXT:1,Z_UNKNOWN:2,Z_DEFLATED:8}},{}],7:[function(t,e,a){"use strict";var i=function(){for(var t,e=[],a=0;a<256;a++){t=a;for(var i=0;i<8;i++)t=1&t?3988292384^t>>>1:t>>>1;e[a]=t}return e}();e.exports=function(t,e,a,n){var r=i,s=n+a;t^=-1;for(var o=n;o<s;o++)t=t>>>8^r[255&(t^e[o])];return-1^t}},{}],8:[function(t,e,a){"use strict";var i,n=t("../utils/common"),r=t("./trees"),s=t("./adler32"),o=t("./crc32"),l=t("./messages"),h=0,d=1,f=3,_=4,u=5,c=0,b=1,g=-2,m=-3,w=-5,p=-1,v=1,k=2,y=3,x=4,z=0,B=2,S=8,E=9,A=15,Z=8,R=286,C=30,N=19,O=2*R+1,D=15,I=3,U=258,T=U+I+1,F=32,L=42,H=69,j=73,K=91,M=103,P=113,Y=666,q=1,G=2,X=3,W=4,J=3;function Q(t,e){return t.msg=l[e],e}function V(t){return(t<<1)-(t>4?9:0)}function $(t){for(var e=t.length;--e>=0;)t[e]=0}function tt(t){var e=t.state,a=e.pending;a>t.avail_out&&(a=t.avail_out),0!==a&&(n.arraySet(t.output,e.pending_buf,e.pending_out,a,t.next_out),t.next_out+=a,e.pending_out+=a,t.total_out+=a,t.avail_out-=a,e.pending-=a,0===e.pending&&(e.pending_out=0))}function et(t,e){r._tr_flush_block(t,t.block_start>=0?t.block_start:-1,t.strstart-t.block_start,e),t.block_start=t.strstart,tt(t.strm)}function at(t,e){t.pending_buf[t.pending++]=e}function it(t,e){t.pending_buf[t.pending++]=e>>>8&255,t.pending_buf[t.pending++]=255&e}function nt(t,e){var a,i,n=t.max_chain_length,r=t.strstart,s=t.prev_length,o=t.nice_match,l=t.strstart>t.w_size-T?t.strstart-(t.w_size-T):0,h=t.window,d=t.w_mask,f=t.prev,_=t.strstart+U,u=h[r+s-1],c=h[r+s];t.prev_length>=t.good_match&&(n>>=2),o>t.lookahead&&(o=t.lookahead);do{if(h[(a=e)+s]===c&&h[a+s-1]===u&&h[a]===h[r]&&h[++a]===h[r+1]){r+=2,a++;do{}while(h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&r<_);if(i=U-(_-r),r=_-U,i>s){if(t.match_start=e,s=i,i>=o)break;u=h[r+s-1],c=h[r+s]}}}while((e=f[e&d])>l&&0!=--n);return s<=t.lookahead?s:t.lookahead}function rt(t){var e,a,i,r,l,h,d,f,_,u,c=t.w_size;do{if(r=t.window_size-t.lookahead-t.strstart,t.strstart>=c+(c-T)){n.arraySet(t.window,t.window,c,c,0),t.match_start-=c,t.strstart-=c,t.block_start-=c,e=a=t.hash_size;do{i=t.head[--e],t.head[e]=i>=c?i-c:0}while(--a);e=a=c;do{i=t.prev[--e],t.prev[e]=i>=c?i-c:0}while(--a);r+=c}if(0===t.strm.avail_in)break;if(h=t.strm,d=t.window,f=t.strstart+t.lookahead,_=r,u=void 0,(u=h.avail_in)>_&&(u=_),a=0===u?0:(h.avail_in-=u,n.arraySet(d,h.input,h.next_in,u,f),1===h.state.wrap?h.adler=s(h.adler,d,u,f):2===h.state.wrap&&(h.adler=o(h.adler,d,u,f)),h.next_in+=u,h.total_in+=u,u),t.lookahead+=a,t.lookahead+t.insert>=I)for(l=t.strstart-t.insert,t.ins_h=t.window[l],t.ins_h=(t.ins_h<<t.hash_shift^t.window[l+1])&t.hash_mask;t.insert&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[l+I-1])&t.hash_mask,t.prev[l&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=l,l++,t.insert--,!(t.lookahead+t.insert<I)););}while(t.lookahead<T&&0!==t.strm.avail_in)}function st(t,e){for(var a,i;;){if(t.lookahead<T){if(rt(t),t.lookahead<T&&e===h)return q;if(0===t.lookahead)break}if(a=0,t.lookahead>=I&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart),0!==a&&t.strstart-a<=t.w_size-T&&(t.match_length=nt(t,a)),t.match_length>=I)if(i=r._tr_tally(t,t.strstart-t.match_start,t.match_length-I),t.lookahead-=t.match_length,t.match_length<=t.max_lazy_match&&t.lookahead>=I){t.match_length--;do{t.strstart++,t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart}while(0!=--t.match_length);t.strstart++}else t.strstart+=t.match_length,t.match_length=0,t.ins_h=t.window[t.strstart],t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+1])&t.hash_mask;else i=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++;if(i&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=t.strstart<I-1?t.strstart:I-1,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}function ot(t,e){for(var a,i,n;;){if(t.lookahead<T){if(rt(t),t.lookahead<T&&e===h)return q;if(0===t.lookahead)break}if(a=0,t.lookahead>=I&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart),t.prev_length=t.match_length,t.prev_match=t.match_start,t.match_length=I-1,0!==a&&t.prev_length<t.max_lazy_match&&t.strstart-a<=t.w_size-T&&(t.match_length=nt(t,a),t.match_length<=5&&(t.strategy===v||t.match_length===I&&t.strstart-t.match_start>4096)&&(t.match_length=I-1)),t.prev_length>=I&&t.match_length<=t.prev_length){n=t.strstart+t.lookahead-I,i=r._tr_tally(t,t.strstart-1-t.prev_match,t.prev_length-I),t.lookahead-=t.prev_length-1,t.prev_length-=2;do{++t.strstart<=n&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart)}while(0!=--t.prev_length);if(t.match_available=0,t.match_length=I-1,t.strstart++,i&&(et(t,!1),0===t.strm.avail_out))return q}else if(t.match_available){if((i=r._tr_tally(t,0,t.window[t.strstart-1]))&&et(t,!1),t.strstart++,t.lookahead--,0===t.strm.avail_out)return q}else t.match_available=1,t.strstart++,t.lookahead--}return t.match_available&&(i=r._tr_tally(t,0,t.window[t.strstart-1]),t.match_available=0),t.insert=t.strstart<I-1?t.strstart:I-1,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}function lt(t,e,a,i,n){this.good_length=t,this.max_lazy=e,this.nice_length=a,this.max_chain=i,this.func=n}function ht(){this.strm=null,this.status=0,this.pending_buf=null,this.pending_buf_size=0,this.pending_out=0,this.pending=0,this.wrap=0,this.gzhead=null,this.gzindex=0,this.method=S,this.last_flush=-1,this.w_size=0,this.w_bits=0,this.w_mask=0,this.window=null,this.window_size=0,this.prev=null,this.head=null,this.ins_h=0,this.hash_size=0,this.hash_bits=0,this.hash_mask=0,this.hash_shift=0,this.block_start=0,this.match_length=0,this.prev_match=0,this.match_available=0,this.strstart=0,this.match_start=0,this.lookahead=0,this.prev_length=0,this.max_chain_length=0,this.max_lazy_match=0,this.level=0,this.strategy=0,this.good_match=0,this.nice_match=0,this.dyn_ltree=new n.Buf16(2*O),this.dyn_dtree=new n.Buf16(2*(2*C+1)),this.bl_tree=new n.Buf16(2*(2*N+1)),$(this.dyn_ltree),$(this.dyn_dtree),$(this.bl_tree),this.l_desc=null,this.d_desc=null,this.bl_desc=null,this.bl_count=new n.Buf16(D+1),this.heap=new n.Buf16(2*R+1),$(this.heap),this.heap_len=0,this.heap_max=0,this.depth=new n.Buf16(2*R+1),$(this.depth),this.l_buf=0,this.lit_bufsize=0,this.last_lit=0,this.d_buf=0,this.opt_len=0,this.static_len=0,this.matches=0,this.insert=0,this.bi_buf=0,this.bi_valid=0}function dt(t){var e;return t&&t.state?(t.total_in=t.total_out=0,t.data_type=B,(e=t.state).pending=0,e.pending_out=0,e.wrap<0&&(e.wrap=-e.wrap),e.status=e.wrap?L:P,t.adler=2===e.wrap?0:1,e.last_flush=h,r._tr_init(e),c):Q(t,g)}function ft(t){var e,a=dt(t);return a===c&&((e=t.state).window_size=2*e.w_size,$(e.head),e.max_lazy_match=i[e.level].max_lazy,e.good_match=i[e.level].good_length,e.nice_match=i[e.level].nice_length,e.max_chain_length=i[e.level].max_chain,e.strstart=0,e.block_start=0,e.lookahead=0,e.insert=0,e.match_length=e.prev_length=I-1,e.match_available=0,e.ins_h=0),a}function _t(t,e,a,i,r,s){if(!t)return g;var o=1;if(e===p&&(e=6),i<0?(o=0,i=-i):i>15&&(o=2,i-=16),r<1||r>E||a!==S||i<8||i>15||e<0||e>9||s<0||s>x)return Q(t,g);8===i&&(i=9);var l=new ht;return t.state=l,l.strm=t,l.wrap=o,l.gzhead=null,l.w_bits=i,l.w_size=1<<l.w_bits,l.w_mask=l.w_size-1,l.hash_bits=r+7,l.hash_size=1<<l.hash_bits,l.hash_mask=l.hash_size-1,l.hash_shift=~~((l.hash_bits+I-1)/I),l.window=new n.Buf8(2*l.w_size),l.head=new n.Buf16(l.hash_size),l.prev=new n.Buf16(l.w_size),l.lit_bufsize=1<<r+6,l.pending_buf_size=4*l.lit_bufsize,l.pending_buf=new n.Buf8(l.pending_buf_size),l.d_buf=1*l.lit_bufsize,l.l_buf=3*l.lit_bufsize,l.level=e,l.strategy=s,l.method=a,ft(t)}i=[new lt(0,0,0,0,function(t,e){var a=65535;for(a>t.pending_buf_size-5&&(a=t.pending_buf_size-5);;){if(t.lookahead<=1){if(rt(t),0===t.lookahead&&e===h)return q;if(0===t.lookahead)break}t.strstart+=t.lookahead,t.lookahead=0;var i=t.block_start+a;if((0===t.strstart||t.strstart>=i)&&(t.lookahead=t.strstart-i,t.strstart=i,et(t,!1),0===t.strm.avail_out))return q;if(t.strstart-t.block_start>=t.w_size-T&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):(t.strstart>t.block_start&&(et(t,!1),t.strm.avail_out),q)}),new lt(4,4,8,4,st),new lt(4,5,16,8,st),new lt(4,6,32,32,st),new lt(4,4,16,16,ot),new lt(8,16,32,32,ot),new lt(8,16,128,128,ot),new lt(8,32,128,256,ot),new lt(32,128,258,1024,ot),new lt(32,258,258,4096,ot)],a.deflateInit=function(t,e){return _t(t,e,S,A,Z,z)},a.deflateInit2=_t,a.deflateReset=ft,a.deflateResetKeep=dt,a.deflateSetHeader=function(t,e){return t&&t.state?2!==t.state.wrap?g:(t.state.gzhead=e,c):g},a.deflate=function(t,e){var a,n,s,l;if(!t||!t.state||e>u||e<0)return t?Q(t,g):g;if(n=t.state,!t.output||!t.input&&0!==t.avail_in||n.status===Y&&e!==_)return Q(t,0===t.avail_out?w:g);if(n.strm=t,a=n.last_flush,n.last_flush=e,n.status===L)if(2===n.wrap)t.adler=0,at(n,31),at(n,139),at(n,8),n.gzhead?(at(n,(n.gzhead.text?1:0)+(n.gzhead.hcrc?2:0)+(n.gzhead.extra?4:0)+(n.gzhead.name?8:0)+(n.gzhead.comment?16:0)),at(n,255&n.gzhead.time),at(n,n.gzhead.time>>8&255),at(n,n.gzhead.time>>16&255),at(n,n.gzhead.time>>24&255),at(n,9===n.level?2:n.strategy>=k||n.level<2?4:0),at(n,255&n.gzhead.os),n.gzhead.extra&&n.gzhead.extra.length&&(at(n,255&n.gzhead.extra.length),at(n,n.gzhead.extra.length>>8&255)),n.gzhead.hcrc&&(t.adler=o(t.adler,n.pending_buf,n.pending,0)),n.gzindex=0,n.status=H):(at(n,0),at(n,0),at(n,0),at(n,0),at(n,0),at(n,9===n.level?2:n.strategy>=k||n.level<2?4:0),at(n,J),n.status=P);else{var m=S+(n.w_bits-8<<4)<<8;m|=(n.strategy>=k||n.level<2?0:n.level<6?1:6===n.level?2:3)<<6,0!==n.strstart&&(m|=F),m+=31-m%31,n.status=P,it(n,m),0!==n.strstart&&(it(n,t.adler>>>16),it(n,65535&t.adler)),t.adler=1}if(n.status===H)if(n.gzhead.extra){for(s=n.pending;n.gzindex<(65535&n.gzhead.extra.length)&&(n.pending!==n.pending_buf_size||(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending!==n.pending_buf_size));)at(n,255&n.gzhead.extra[n.gzindex]),n.gzindex++;n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),n.gzindex===n.gzhead.extra.length&&(n.gzindex=0,n.status=j)}else n.status=j;if(n.status===j)if(n.gzhead.name){s=n.pending;do{if(n.pending===n.pending_buf_size&&(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending===n.pending_buf_size)){l=1;break}l=n.gzindex<n.gzhead.name.length?255&n.gzhead.name.charCodeAt(n.gzindex++):0,at(n,l)}while(0!==l);n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),0===l&&(n.gzindex=0,n.status=K)}else n.status=K;if(n.status===K)if(n.gzhead.comment){s=n.pending;do{if(n.pending===n.pending_buf_size&&(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending===n.pending_buf_size)){l=1;break}l=n.gzindex<n.gzhead.comment.length?255&n.gzhead.comment.charCodeAt(n.gzindex++):0,at(n,l)}while(0!==l);n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),0===l&&(n.status=M)}else n.status=M;if(n.status===M&&(n.gzhead.hcrc?(n.pending+2>n.pending_buf_size&&tt(t),n.pending+2<=n.pending_buf_size&&(at(n,255&t.adler),at(n,t.adler>>8&255),t.adler=0,n.status=P)):n.status=P),0!==n.pending){if(tt(t),0===t.avail_out)return n.last_flush=-1,c}else if(0===t.avail_in&&V(e)<=V(a)&&e!==_)return Q(t,w);if(n.status===Y&&0!==t.avail_in)return Q(t,w);if(0!==t.avail_in||0!==n.lookahead||e!==h&&n.status!==Y){var p=n.strategy===k?function(t,e){for(var a;;){if(0===t.lookahead&&(rt(t),0===t.lookahead)){if(e===h)return q;break}if(t.match_length=0,a=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++,a&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}(n,e):n.strategy===y?function(t,e){for(var a,i,n,s,o=t.window;;){if(t.lookahead<=U){if(rt(t),t.lookahead<=U&&e===h)return q;if(0===t.lookahead)break}if(t.match_length=0,t.lookahead>=I&&t.strstart>0&&(i=o[n=t.strstart-1])===o[++n]&&i===o[++n]&&i===o[++n]){s=t.strstart+U;do{}while(i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&n<s);t.match_length=U-(s-n),t.match_length>t.lookahead&&(t.match_length=t.lookahead)}if(t.match_length>=I?(a=r._tr_tally(t,1,t.match_length-I),t.lookahead-=t.match_length,t.strstart+=t.match_length,t.match_length=0):(a=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++),a&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}(n,e):i[n.level].func(n,e);if(p!==X&&p!==W||(n.status=Y),p===q||p===X)return 0===t.avail_out&&(n.last_flush=-1),c;if(p===G&&(e===d?r._tr_align(n):e!==u&&(r._tr_stored_block(n,0,0,!1),e===f&&($(n.head),0===n.lookahead&&(n.strstart=0,n.block_start=0,n.insert=0))),tt(t),0===t.avail_out))return n.last_flush=-1,c}return e!==_?c:n.wrap<=0?b:(2===n.wrap?(at(n,255&t.adler),at(n,t.adler>>8&255),at(n,t.adler>>16&255),at(n,t.adler>>24&255),at(n,255&t.total_in),at(n,t.total_in>>8&255),at(n,t.total_in>>16&255),at(n,t.total_in>>24&255)):(it(n,t.adler>>>16),it(n,65535&t.adler)),tt(t),n.wrap>0&&(n.wrap=-n.wrap),0!==n.pending?c:b)},a.deflateEnd=function(t){var e;return t&&t.state?(e=t.state.status)!==L&&e!==H&&e!==j&&e!==K&&e!==M&&e!==P&&e!==Y?Q(t,g):(t.state=null,e===P?Q(t,m):c):g},a.deflateSetDictionary=function(t,e){var a,i,r,o,l,h,d,f,_=e.length;if(!t||!t.state)return g;if(2===(o=(a=t.state).wrap)||1===o&&a.status!==L||a.lookahead)return g;for(1===o&&(t.adler=s(t.adler,e,_,0)),a.wrap=0,_>=a.w_size&&(0===o&&($(a.head),a.strstart=0,a.block_start=0,a.insert=0),f=new n.Buf8(a.w_size),n.arraySet(f,e,_-a.w_size,a.w_size,0),e=f,_=a.w_size),l=t.avail_in,h=t.next_in,d=t.input,t.avail_in=_,t.next_in=0,t.input=e,rt(a);a.lookahead>=I;){i=a.strstart,r=a.lookahead-(I-1);do{a.ins_h=(a.ins_h<<a.hash_shift^a.window[i+I-1])&a.hash_mask,a.prev[i&a.w_mask]=a.head[a.ins_h],a.head[a.ins_h]=i,i++}while(--r);a.strstart=i,a.lookahead=I-1,rt(a)}return a.strstart+=a.lookahead,a.block_start=a.strstart,a.insert=a.lookahead,a.lookahead=0,a.match_length=a.prev_length=I-1,a.match_available=0,t.next_in=h,t.input=d,t.avail_in=l,a.wrap=o,c},a.deflateInfo="pako deflate (from Nodeca project)"},{"../utils/common":3,"./adler32":5,"./crc32":7,"./messages":13,"./trees":14}],9:[function(t,e,a){"use strict";e.exports=function(){this.text=0,this.time=0,this.xflags=0,this.os=0,this.extra=null,this.extra_len=0,this.name="",this.comment="",this.hcrc=0,this.done=!1}},{}],10:[function(t,e,a){"use strict";e.exports=function(t,e){var a,i,n,r,s,o,l,h,d,f,_,u,c,b,g,m,w,p,v,k,y,x,z,B,S;a=t.state,i=t.next_in,B=t.input,n=i+(t.avail_in-5),r=t.next_out,S=t.output,s=r-(e-t.avail_out),o=r+(t.avail_out-257),l=a.dmax,h=a.wsize,d=a.whave,f=a.wnext,_=a.window,u=a.hold,c=a.bits,b=a.lencode,g=a.distcode,m=(1<<a.lenbits)-1,w=(1<<a.distbits)-1;t:do{c<15&&(u+=B[i++]<<c,c+=8,u+=B[i++]<<c,c+=8),p=b[u&m];e:for(;;){if(u>>>=v=p>>>24,c-=v,0===(v=p>>>16&255))S[r++]=65535&p;else{if(!(16&v)){if(0==(64&v)){p=b[(65535&p)+(u&(1<<v)-1)];continue e}if(32&v){a.mode=12;break t}t.msg="invalid literal/length code",a.mode=30;break t}k=65535&p,(v&=15)&&(c<v&&(u+=B[i++]<<c,c+=8),k+=u&(1<<v)-1,u>>>=v,c-=v),c<15&&(u+=B[i++]<<c,c+=8,u+=B[i++]<<c,c+=8),p=g[u&w];a:for(;;){if(u>>>=v=p>>>24,c-=v,!(16&(v=p>>>16&255))){if(0==(64&v)){p=g[(65535&p)+(u&(1<<v)-1)];continue a}t.msg="invalid distance code",a.mode=30;break t}if(y=65535&p,c<(v&=15)&&(u+=B[i++]<<c,(c+=8)<v&&(u+=B[i++]<<c,c+=8)),(y+=u&(1<<v)-1)>l){t.msg="invalid distance too far back",a.mode=30;break t}if(u>>>=v,c-=v,y>(v=r-s)){if((v=y-v)>d&&a.sane){t.msg="invalid distance too far back",a.mode=30;break t}if(x=0,z=_,0===f){if(x+=h-v,v<k){k-=v;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}}else if(f<v){if(x+=h+f-v,(v-=f)<k){k-=v;do{S[r++]=_[x++]}while(--v);if(x=0,f<k){k-=v=f;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}}}else if(x+=f-v,v<k){k-=v;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}for(;k>2;)S[r++]=z[x++],S[r++]=z[x++],S[r++]=z[x++],k-=3;k&&(S[r++]=z[x++],k>1&&(S[r++]=z[x++]))}else{x=r-y;do{S[r++]=S[x++],S[r++]=S[x++],S[r++]=S[x++],k-=3}while(k>2);k&&(S[r++]=S[x++],k>1&&(S[r++]=S[x++]))}break}}break}}while(i<n&&r<o);i-=k=c>>3,u&=(1<<(c-=k<<3))-1,t.next_in=i,t.next_out=r,t.avail_in=i<n?n-i+5:5-(i-n),t.avail_out=r<o?o-r+257:257-(r-o),a.hold=u,a.bits=c}},{}],11:[function(t,e,a){"use strict";var i=t("../utils/common"),n=t("./adler32"),r=t("./crc32"),s=t("./inffast"),o=t("./inftrees"),l=0,h=1,d=2,f=4,_=5,u=6,c=0,b=1,g=2,m=-2,w=-3,p=-4,v=-5,k=8,y=1,x=2,z=3,B=4,S=5,E=6,A=7,Z=8,R=9,C=10,N=11,O=12,D=13,I=14,U=15,T=16,F=17,L=18,H=19,j=20,K=21,M=22,P=23,Y=24,q=25,G=26,X=27,W=28,J=29,Q=30,V=31,$=32,tt=852,et=592,at=15;function it(t){return(t>>>24&255)+(t>>>8&65280)+((65280&t)<<8)+((255&t)<<24)}function nt(){this.mode=0,this.last=!1,this.wrap=0,this.havedict=!1,this.flags=0,this.dmax=0,this.check=0,this.total=0,this.head=null,this.wbits=0,this.wsize=0,this.whave=0,this.wnext=0,this.window=null,this.hold=0,this.bits=0,this.length=0,this.offset=0,this.extra=0,this.lencode=null,this.distcode=null,this.lenbits=0,this.distbits=0,this.ncode=0,this.nlen=0,this.ndist=0,this.have=0,this.next=null,this.lens=new i.Buf16(320),this.work=new i.Buf16(288),this.lendyn=null,this.distdyn=null,this.sane=0,this.back=0,this.was=0}function rt(t){var e;return t&&t.state?(e=t.state,t.total_in=t.total_out=e.total=0,t.msg="",e.wrap&&(t.adler=1&e.wrap),e.mode=y,e.last=0,e.havedict=0,e.dmax=32768,e.head=null,e.hold=0,e.bits=0,e.lencode=e.lendyn=new i.Buf32(tt),e.distcode=e.distdyn=new i.Buf32(et),e.sane=1,e.back=-1,c):m}function st(t){var e;return t&&t.state?((e=t.state).wsize=0,e.whave=0,e.wnext=0,rt(t)):m}function ot(t,e){var a,i;return t&&t.state?(i=t.state,e<0?(a=0,e=-e):(a=1+(e>>4),e<48&&(e&=15)),e&&(e<8||e>15)?m:(null!==i.window&&i.wbits!==e&&(i.window=null),i.wrap=a,i.wbits=e,st(t))):m}function lt(t,e){var a,i;return t?(i=new nt,t.state=i,i.window=null,(a=ot(t,e))!==c&&(t.state=null),a):m}var ht,dt,ft=!0;function _t(t){if(ft){var e;for(ht=new i.Buf32(512),dt=new i.Buf32(32),e=0;e<144;)t.lens[e++]=8;for(;e<256;)t.lens[e++]=9;for(;e<280;)t.lens[e++]=7;for(;e<288;)t.lens[e++]=8;for(o(h,t.lens,0,288,ht,0,t.work,{bits:9}),e=0;e<32;)t.lens[e++]=5;o(d,t.lens,0,32,dt,0,t.work,{bits:5}),ft=!1}t.lencode=ht,t.lenbits=9,t.distcode=dt,t.distbits=5}function ut(t,e,a,n){var r,s=t.state;return null===s.window&&(s.wsize=1<<s.wbits,s.wnext=0,s.whave=0,s.window=new i.Buf8(s.wsize)),n>=s.wsize?(i.arraySet(s.window,e,a-s.wsize,s.wsize,0),s.wnext=0,s.whave=s.wsize):((r=s.wsize-s.wnext)>n&&(r=n),i.arraySet(s.window,e,a-n,r,s.wnext),(n-=r)?(i.arraySet(s.window,e,a-n,n,0),s.wnext=n,s.whave=s.wsize):(s.wnext+=r,s.wnext===s.wsize&&(s.wnext=0),s.whave<s.wsize&&(s.whave+=r))),0}a.inflateReset=st,a.inflateReset2=ot,a.inflateResetKeep=rt,a.inflateInit=function(t){return lt(t,at)},a.inflateInit2=lt,a.inflate=function(t,e){var a,tt,et,at,nt,rt,st,ot,lt,ht,dt,ft,ct,bt,gt,mt,wt,pt,vt,kt,yt,xt,zt,Bt,St=0,Et=new i.Buf8(4),At=[16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15];if(!t||!t.state||!t.output||!t.input&&0!==t.avail_in)return m;(a=t.state).mode===O&&(a.mode=D),nt=t.next_out,et=t.output,st=t.avail_out,at=t.next_in,tt=t.input,rt=t.avail_in,ot=a.hold,lt=a.bits,ht=rt,dt=st,xt=c;t:for(;;)switch(a.mode){case y:if(0===a.wrap){a.mode=D;break}for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(2&a.wrap&&35615===ot){a.check=0,Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0),ot=0,lt=0,a.mode=x;break}if(a.flags=0,a.head&&(a.head.done=!1),!(1&a.wrap)||(((255&ot)<<8)+(ot>>8))%31){t.msg="incorrect header check",a.mode=Q;break}if((15&ot)!==k){t.msg="unknown compression method",a.mode=Q;break}if(lt-=4,yt=8+(15&(ot>>>=4)),0===a.wbits)a.wbits=yt;else if(yt>a.wbits){t.msg="invalid window size",a.mode=Q;break}a.dmax=1<<yt,t.adler=a.check=1,a.mode=512&ot?C:O,ot=0,lt=0;break;case x:for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(a.flags=ot,(255&a.flags)!==k){t.msg="unknown compression method",a.mode=Q;break}if(57344&a.flags){t.msg="unknown header flags set",a.mode=Q;break}a.head&&(a.head.text=ot>>8&1),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0,a.mode=z;case z:for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.head&&(a.head.time=ot),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,Et[2]=ot>>>16&255,Et[3]=ot>>>24&255,a.check=r(a.check,Et,4,0)),ot=0,lt=0,a.mode=B;case B:for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.head&&(a.head.xflags=255&ot,a.head.os=ot>>8),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0,a.mode=S;case S:if(1024&a.flags){for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.length=ot,a.head&&(a.head.extra_len=ot),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0}else a.head&&(a.head.extra=null);a.mode=E;case E:if(1024&a.flags&&((ft=a.length)>rt&&(ft=rt),ft&&(a.head&&(yt=a.head.extra_len-a.length,a.head.extra||(a.head.extra=new Array(a.head.extra_len)),i.arraySet(a.head.extra,tt,at,ft,yt)),512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,a.length-=ft),a.length))break t;a.length=0,a.mode=A;case A:if(2048&a.flags){if(0===rt)break t;ft=0;do{yt=tt[at+ft++],a.head&&yt&&a.length<65536&&(a.head.name+=String.fromCharCode(yt))}while(yt&&ft<rt);if(512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,yt)break t}else a.head&&(a.head.name=null);a.length=0,a.mode=Z;case Z:if(4096&a.flags){if(0===rt)break t;ft=0;do{yt=tt[at+ft++],a.head&&yt&&a.length<65536&&(a.head.comment+=String.fromCharCode(yt))}while(yt&&ft<rt);if(512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,yt)break t}else a.head&&(a.head.comment=null);a.mode=R;case R:if(512&a.flags){for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot!==(65535&a.check)){t.msg="header crc mismatch",a.mode=Q;break}ot=0,lt=0}a.head&&(a.head.hcrc=a.flags>>9&1,a.head.done=!0),t.adler=a.check=0,a.mode=O;break;case C:for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}t.adler=a.check=it(ot),ot=0,lt=0,a.mode=N;case N:if(0===a.havedict)return t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,g;t.adler=a.check=1,a.mode=O;case O:if(e===_||e===u)break t;case D:if(a.last){ot>>>=7<,lt-=7<,a.mode=X;break}for(;lt<3;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}switch(a.last=1&ot,lt-=1,3&(ot>>>=1)){case 0:a.mode=I;break;case 1:if(_t(a),a.mode=j,e===u){ot>>>=2,lt-=2;break t}break;case 2:a.mode=F;break;case 3:t.msg="invalid block type",a.mode=Q}ot>>>=2,lt-=2;break;case I:for(ot>>>=7<,lt-=7<lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if((65535&ot)!=(ot>>>16^65535)){t.msg="invalid stored block lengths",a.mode=Q;break}if(a.length=65535&ot,ot=0,lt=0,a.mode=U,e===u)break t;case U:a.mode=T;case T:if(ft=a.length){if(ft>rt&&(ft=rt),ft>st&&(ft=st),0===ft)break t;i.arraySet(et,tt,at,ft,nt),rt-=ft,at+=ft,st-=ft,nt+=ft,a.length-=ft;break}a.mode=O;break;case F:for(;lt<14;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(a.nlen=257+(31&ot),ot>>>=5,lt-=5,a.ndist=1+(31&ot),ot>>>=5,lt-=5,a.ncode=4+(15&ot),ot>>>=4,lt-=4,a.nlen>286||a.ndist>30){t.msg="too many length or distance symbols",a.mode=Q;break}a.have=0,a.mode=L;case L:for(;a.have<a.ncode;){for(;lt<3;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.lens[At[a.have++]]=7&ot,ot>>>=3,lt-=3}for(;a.have<19;)a.lens[At[a.have++]]=0;if(a.lencode=a.lendyn,a.lenbits=7,zt={bits:a.lenbits},xt=o(l,a.lens,0,19,a.lencode,0,a.work,zt),a.lenbits=zt.bits,xt){t.msg="invalid code lengths set",a.mode=Q;break}a.have=0,a.mode=H;case H:for(;a.have<a.nlen+a.ndist;){for(;mt=(St=a.lencode[ot&(1<<a.lenbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(wt<16)ot>>>=gt,lt-=gt,a.lens[a.have++]=wt;else{if(16===wt){for(Bt=gt+2;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot>>>=gt,lt-=gt,0===a.have){t.msg="invalid bit length repeat",a.mode=Q;break}yt=a.lens[a.have-1],ft=3+(3&ot),ot>>>=2,lt-=2}else if(17===wt){for(Bt=gt+3;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}lt-=gt,yt=0,ft=3+(7&(ot>>>=gt)),ot>>>=3,lt-=3}else{for(Bt=gt+7;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}lt-=gt,yt=0,ft=11+(127&(ot>>>=gt)),ot>>>=7,lt-=7}if(a.have+ft>a.nlen+a.ndist){t.msg="invalid bit length repeat",a.mode=Q;break}for(;ft--;)a.lens[a.have++]=yt}}if(a.mode===Q)break;if(0===a.lens[256]){t.msg="invalid code -- missing end-of-block",a.mode=Q;break}if(a.lenbits=9,zt={bits:a.lenbits},xt=o(h,a.lens,0,a.nlen,a.lencode,0,a.work,zt),a.lenbits=zt.bits,xt){t.msg="invalid literal/lengths set",a.mode=Q;break}if(a.distbits=6,a.distcode=a.distdyn,zt={bits:a.distbits},xt=o(d,a.lens,a.nlen,a.ndist,a.distcode,0,a.work,zt),a.distbits=zt.bits,xt){t.msg="invalid distances set",a.mode=Q;break}if(a.mode=j,e===u)break t;case j:a.mode=K;case K:if(rt>=6&&st>=258){t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,s(t,dt),nt=t.next_out,et=t.output,st=t.avail_out,at=t.next_in,tt=t.input,rt=t.avail_in,ot=a.hold,lt=a.bits,a.mode===O&&(a.back=-1);break}for(a.back=0;mt=(St=a.lencode[ot&(1<<a.lenbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(mt&&0==(240&mt)){for(pt=gt,vt=mt,kt=wt;mt=(St=a.lencode[kt+((ot&(1<<pt+vt)-1)>>pt)])>>>16&255,wt=65535&St,!(pt+(gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}ot>>>=pt,lt-=pt,a.back+=pt}if(ot>>>=gt,lt-=gt,a.back+=gt,a.length=wt,0===mt){a.mode=G;break}if(32&mt){a.back=-1,a.mode=O;break}if(64&mt){t.msg="invalid literal/length code",a.mode=Q;break}a.extra=15&mt,a.mode=M;case M:if(a.extra){for(Bt=a.extra;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.length+=ot&(1<<a.extra)-1,ot>>>=a.extra,lt-=a.extra,a.back+=a.extra}a.was=a.length,a.mode=P;case P:for(;mt=(St=a.distcode[ot&(1<<a.distbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(0==(240&mt)){for(pt=gt,vt=mt,kt=wt;mt=(St=a.distcode[kt+((ot&(1<<pt+vt)-1)>>pt)])>>>16&255,wt=65535&St,!(pt+(gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}ot>>>=pt,lt-=pt,a.back+=pt}if(ot>>>=gt,lt-=gt,a.back+=gt,64&mt){t.msg="invalid distance code",a.mode=Q;break}a.offset=wt,a.extra=15&mt,a.mode=Y;case Y:if(a.extra){for(Bt=a.extra;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.offset+=ot&(1<<a.extra)-1,ot>>>=a.extra,lt-=a.extra,a.back+=a.extra}if(a.offset>a.dmax){t.msg="invalid distance too far back",a.mode=Q;break}a.mode=q;case q:if(0===st)break t;if(ft=dt-st,a.offset>ft){if((ft=a.offset-ft)>a.whave&&a.sane){t.msg="invalid distance too far back",a.mode=Q;break}ft>a.wnext?(ft-=a.wnext,ct=a.wsize-ft):ct=a.wnext-ft,ft>a.length&&(ft=a.length),bt=a.window}else bt=et,ct=nt-a.offset,ft=a.length;ft>st&&(ft=st),st-=ft,a.length-=ft;do{et[nt++]=bt[ct++]}while(--ft);0===a.length&&(a.mode=K);break;case G:if(0===st)break t;et[nt++]=a.length,st--,a.mode=K;break;case X:if(a.wrap){for(;lt<32;){if(0===rt)break t;rt--,ot|=tt[at++]<<lt,lt+=8}if(dt-=st,t.total_out+=dt,a.total+=dt,dt&&(t.adler=a.check=a.flags?r(a.check,et,dt,nt-dt):n(a.check,et,dt,nt-dt)),dt=st,(a.flags?ot:it(ot))!==a.check){t.msg="incorrect data check",a.mode=Q;break}ot=0,lt=0}a.mode=W;case W:if(a.wrap&&a.flags){for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot!==(4294967295&a.total)){t.msg="incorrect length check",a.mode=Q;break}ot=0,lt=0}a.mode=J;case J:xt=b;break t;case Q:xt=w;break t;case V:return p;case $:default:return m}return t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,(a.wsize||dt!==t.avail_out&&a.mode<Q&&(a.mode<X||e!==f))&&ut(t,t.output,t.next_out,dt-t.avail_out)?(a.mode=V,p):(ht-=t.avail_in,dt-=t.avail_out,t.total_in+=ht,t.total_out+=dt,a.total+=dt,a.wrap&&dt&&(t.adler=a.check=a.flags?r(a.check,et,dt,t.next_out-dt):n(a.check,et,dt,t.next_out-dt)),t.data_type=a.bits+(a.last?64:0)+(a.mode===O?128:0)+(a.mode===j||a.mode===U?256:0),(0===ht&&0===dt||e===f)&&xt===c&&(xt=v),xt)},a.inflateEnd=function(t){if(!t||!t.state)return m;var e=t.state;return e.window&&(e.window=null),t.state=null,c},a.inflateGetHeader=function(t,e){var a;return t&&t.state?0==(2&(a=t.state).wrap)?m:(a.head=e,e.done=!1,c):m},a.inflateSetDictionary=function(t,e){var a,i=e.length;return t&&t.state?0!==(a=t.state).wrap&&a.mode!==N?m:a.mode===N&&n(1,e,i,0)!==a.check?w:ut(t,e,i,i)?(a.mode=V,p):(a.havedict=1,c):m},a.inflateInfo="pako inflate (from Nodeca project)"},{"../utils/common":3,"./adler32":5,"./crc32":7,"./inffast":10,"./inftrees":12}],12:[function(t,e,a){"use strict";var i=t("../utils/common"),n=[3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258,0,0],r=[16,16,16,16,16,16,16,16,17,17,17,17,18,18,18,18,19,19,19,19,20,20,20,20,21,21,21,21,16,72,78],s=[1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0],o=[16,16,16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,25,26,26,27,27,28,28,29,29,64,64];e.exports=function(t,e,a,l,h,d,f,_){var u,c,b,g,m,w,p,v,k,y=_.bits,x=0,z=0,B=0,S=0,E=0,A=0,Z=0,R=0,C=0,N=0,O=null,D=0,I=new i.Buf16(16),U=new i.Buf16(16),T=null,F=0;for(x=0;x<=15;x++)I[x]=0;for(z=0;z<l;z++)I[e[a+z]]++;for(E=y,S=15;S>=1&&0===I[S];S--);if(E>S&&(E=S),0===S)return h[d++]=20971520,h[d++]=20971520,_.bits=1,0;for(B=1;B<S&&0===I[B];B++);for(E<B&&(E=B),R=1,x=1;x<=15;x++)if(R<<=1,(R-=I[x])<0)return-1;if(R>0&&(0===t||1!==S))return-1;for(U[1]=0,x=1;x<15;x++)U[x+1]=U[x]+I[x];for(z=0;z<l;z++)0!==e[a+z]&&(f[U[e[a+z]]++]=z);if(0===t?(O=T=f,w=19):1===t?(O=n,D-=257,T=r,F-=257,w=256):(O=s,T=o,w=-1),N=0,z=0,x=B,m=d,A=E,Z=0,b=-1,g=(C=1<<E)-1,1===t&&C>852||2===t&&C>592)return 1;for(;;){p=x-Z,f[z]<w?(v=0,k=f[z]):f[z]>w?(v=T[F+f[z]],k=O[D+f[z]]):(v=96,k=0),u=1<<x-Z,B=c=1<<A;do{h[m+(N>>Z)+(c-=u)]=p<<24|v<<16|k|0}while(0!==c);for(u=1<<x-1;N&u;)u>>=1;if(0!==u?(N&=u-1,N+=u):N=0,z++,0==--I[x]){if(x===S)break;x=e[a+f[z]]}if(x>E&&(N&g)!==b){for(0===Z&&(Z=E),m+=B,R=1<<(A=x-Z);A+Z<S&&!((R-=I[A+Z])<=0);)A++,R<<=1;if(C+=1<<A,1===t&&C>852||2===t&&C>592)return 1;h[b=N&g]=E<<24|A<<16|m-d|0}}return 0!==N&&(h[m+N]=x-Z<<24|64<<16|0),_.bits=E,0}},{"../utils/common":3}],13:[function(t,e,a){"use strict";e.exports={2:"need dictionary",1:"stream end",0:"","-1":"file error","-2":"stream error","-3":"data error","-4":"insufficient memory","-5":"buffer error","-6":"incompatible version"}},{}],14:[function(t,e,a){"use strict";var i=t("../utils/common"),n=4,r=0,s=1,o=2;function l(t){for(var e=t.length;--e>=0;)t[e]=0}var h=0,d=1,f=2,_=29,u=256,c=u+1+_,b=30,g=19,m=2*c+1,w=15,p=16,v=7,k=256,y=16,x=17,z=18,B=[0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0],S=[0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13],E=[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,3,7],A=[16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15],Z=new Array(2*(c+2));l(Z);var R=new Array(2*b);l(R);var C=new Array(512);l(C);var N=new Array(256);l(N);var O=new Array(_);l(O);var D,I,U,T=new Array(b);function F(t,e,a,i,n){this.static_tree=t,this.extra_bits=e,this.extra_base=a,this.elems=i,this.max_length=n,this.has_stree=t&&t.length}function L(t,e){this.dyn_tree=t,this.max_code=0,this.stat_desc=e}function H(t){return t<256?C[t]:C[256+(t>>>7)]}function j(t,e){t.pending_buf[t.pending++]=255&e,t.pending_buf[t.pending++]=e>>>8&255}function K(t,e,a){t.bi_valid>p-a?(t.bi_buf|=e<<t.bi_valid&65535,j(t,t.bi_buf),t.bi_buf=e>>p-t.bi_valid,t.bi_valid+=a-p):(t.bi_buf|=e<<t.bi_valid&65535,t.bi_valid+=a)}function M(t,e,a){K(t,a[2*e],a[2*e+1])}function P(t,e){var a=0;do{a|=1&t,t>>>=1,a<<=1}while(--e>0);return a>>>1}function Y(t,e,a){var i,n,r=new Array(w+1),s=0;for(i=1;i<=w;i++)r[i]=s=s+a[i-1]<<1;for(n=0;n<=e;n++){var o=t[2*n+1];0!==o&&(t[2*n]=P(r[o]++,o))}}function q(t){var e;for(e=0;e<c;e++)t.dyn_ltree[2*e]=0;for(e=0;e<b;e++)t.dyn_dtree[2*e]=0;for(e=0;e<g;e++)t.bl_tree[2*e]=0;t.dyn_ltree[2*k]=1,t.opt_len=t.static_len=0,t.last_lit=t.matches=0}function G(t){t.bi_valid>8?j(t,t.bi_buf):t.bi_valid>0&&(t.pending_buf[t.pending++]=t.bi_buf),t.bi_buf=0,t.bi_valid=0}function X(t,e,a,i){var n=2*e,r=2*a;return t[n]<t[r]||t[n]===t[r]&&i[e]<=i[a]}function W(t,e,a){for(var i=t.heap[a],n=a<<1;n<=t.heap_len&&(n<t.heap_len&&X(e,t.heap[n+1],t.heap[n],t.depth)&&n++,!X(e,i,t.heap[n],t.depth));)t.heap[a]=t.heap[n],a=n,n<<=1;t.heap[a]=i}function J(t,e,a){var i,n,r,s,o=0;if(0!==t.last_lit)do{i=t.pending_buf[t.d_buf+2*o]<<8|t.pending_buf[t.d_buf+2*o+1],n=t.pending_buf[t.l_buf+o],o++,0===i?M(t,n,e):(M(t,(r=N[n])+u+1,e),0!==(s=B[r])&&K(t,n-=O[r],s),M(t,r=H(--i),a),0!==(s=S[r])&&K(t,i-=T[r],s))}while(o<t.last_lit);M(t,k,e)}function Q(t,e){var a,i,n,r=e.dyn_tree,s=e.stat_desc.static_tree,o=e.stat_desc.has_stree,l=e.stat_desc.elems,h=-1;for(t.heap_len=0,t.heap_max=m,a=0;a<l;a++)0!==r[2*a]?(t.heap[++t.heap_len]=h=a,t.depth[a]=0):r[2*a+1]=0;for(;t.heap_len<2;)r[2*(n=t.heap[++t.heap_len]=h<2?++h:0)]=1,t.depth[n]=0,t.opt_len--,o&&(t.static_len-=s[2*n+1]);for(e.max_code=h,a=t.heap_len>>1;a>=1;a--)W(t,r,a);n=l;do{a=t.heap[1],t.heap[1]=t.heap[t.heap_len--],W(t,r,1),i=t.heap[1],t.heap[--t.heap_max]=a,t.heap[--t.heap_max]=i,r[2*n]=r[2*a]+r[2*i],t.depth[n]=(t.depth[a]>=t.depth[i]?t.depth[a]:t.depth[i])+1,r[2*a+1]=r[2*i+1]=n,t.heap[1]=n++,W(t,r,1)}while(t.heap_len>=2);t.heap[--t.heap_max]=t.heap[1],function(t,e){var a,i,n,r,s,o,l=e.dyn_tree,h=e.max_code,d=e.stat_desc.static_tree,f=e.stat_desc.has_stree,_=e.stat_desc.extra_bits,u=e.stat_desc.extra_base,c=e.stat_desc.max_length,b=0;for(r=0;r<=w;r++)t.bl_count[r]=0;for(l[2*t.heap[t.heap_max]+1]=0,a=t.heap_max+1;a<m;a++)(r=l[2*l[2*(i=t.heap[a])+1]+1]+1)>c&&(r=c,b++),l[2*i+1]=r,i>h||(t.bl_count[r]++,s=0,i>=u&&(s=_[i-u]),o=l[2*i],t.opt_len+=o*(r+s),f&&(t.static_len+=o*(d[2*i+1]+s)));if(0!==b){do{for(r=c-1;0===t.bl_count[r];)r--;t.bl_count[r]--,t.bl_count[r+1]+=2,t.bl_count[c]--,b-=2}while(b>0);for(r=c;0!==r;r--)for(i=t.bl_count[r];0!==i;)(n=t.heap[--a])>h||(l[2*n+1]!==r&&(t.opt_len+=(r-l[2*n+1])*l[2*n],l[2*n+1]=r),i--)}}(t,e),Y(r,h,t.bl_count)}function V(t,e,a){var i,n,r=-1,s=e[1],o=0,l=7,h=4;for(0===s&&(l=138,h=3),e[2*(a+1)+1]=65535,i=0;i<=a;i++)n=s,s=e[2*(i+1)+1],++o<l&&n===s||(o<h?t.bl_tree[2*n]+=o:0!==n?(n!==r&&t.bl_tree[2*n]++,t.bl_tree[2*y]++):o<=10?t.bl_tree[2*x]++:t.bl_tree[2*z]++,o=0,r=n,0===s?(l=138,h=3):n===s?(l=6,h=3):(l=7,h=4))}function $(t,e,a){var i,n,r=-1,s=e[1],o=0,l=7,h=4;for(0===s&&(l=138,h=3),i=0;i<=a;i++)if(n=s,s=e[2*(i+1)+1],!(++o<l&&n===s)){if(o<h)do{M(t,n,t.bl_tree)}while(0!=--o);else 0!==n?(n!==r&&(M(t,n,t.bl_tree),o--),M(t,y,t.bl_tree),K(t,o-3,2)):o<=10?(M(t,x,t.bl_tree),K(t,o-3,3)):(M(t,z,t.bl_tree),K(t,o-11,7));o=0,r=n,0===s?(l=138,h=3):n===s?(l=6,h=3):(l=7,h=4)}}l(T);var tt=!1;function et(t,e,a,n){K(t,(h<<1)+(n?1:0),3),function(t,e,a,n){G(t),n&&(j(t,a),j(t,~a)),i.arraySet(t.pending_buf,t.window,e,a,t.pending),t.pending+=a}(t,e,a,!0)}a._tr_init=function(t){tt||(function(){var t,e,a,i,n,r=new Array(w+1);for(a=0,i=0;i<_-1;i++)for(O[i]=a,t=0;t<1<<B[i];t++)N[a++]=i;for(N[a-1]=i,n=0,i=0;i<16;i++)for(T[i]=n,t=0;t<1<<S[i];t++)C[n++]=i;for(n>>=7;i<b;i++)for(T[i]=n<<7,t=0;t<1<<S[i]-7;t++)C[256+n++]=i;for(e=0;e<=w;e++)r[e]=0;for(t=0;t<=143;)Z[2*t+1]=8,t++,r[8]++;for(;t<=255;)Z[2*t+1]=9,t++,r[9]++;for(;t<=279;)Z[2*t+1]=7,t++,r[7]++;for(;t<=287;)Z[2*t+1]=8,t++,r[8]++;for(Y(Z,c+1,r),t=0;t<b;t++)R[2*t+1]=5,R[2*t]=P(t,5);D=new F(Z,B,u+1,c,w),I=new F(R,S,0,b,w),U=new F(new Array(0),E,0,g,v)}(),tt=!0),t.l_desc=new L(t.dyn_ltree,D),t.d_desc=new L(t.dyn_dtree,I),t.bl_desc=new L(t.bl_tree,U),t.bi_buf=0,t.bi_valid=0,q(t)},a._tr_stored_block=et,a._tr_flush_block=function(t,e,a,i){var l,h,_=0;t.level>0?(t.strm.data_type===o&&(t.strm.data_type=function(t){var e,a=4093624447;for(e=0;e<=31;e++,a>>>=1)if(1&a&&0!==t.dyn_ltree[2*e])return r;if(0!==t.dyn_ltree[18]||0!==t.dyn_ltree[20]||0!==t.dyn_ltree[26])return s;for(e=32;e<u;e++)if(0!==t.dyn_ltree[2*e])return s;return r}(t)),Q(t,t.l_desc),Q(t,t.d_desc),_=function(t){var e;for(V(t,t.dyn_ltree,t.l_desc.max_code),V(t,t.dyn_dtree,t.d_desc.max_code),Q(t,t.bl_desc),e=g-1;e>=3&&0===t.bl_tree[2*A[e]+1];e--);return t.opt_len+=3*(e+1)+5+5+4,e}(t),l=t.opt_len+3+7>>>3,(h=t.static_len+3+7>>>3)<=l&&(l=h)):l=h=a+5,a+4<=l&&-1!==e?et(t,e,a,i):t.strategy===n||h===l?(K(t,(d<<1)+(i?1:0),3),J(t,Z,R)):(K(t,(f<<1)+(i?1:0),3),function(t,e,a,i){var n;for(K(t,e-257,5),K(t,a-1,5),K(t,i-4,4),n=0;n<i;n++)K(t,t.bl_tree[2*A[n]+1],3);$(t,t.dyn_ltree,e-1),$(t,t.dyn_dtree,a-1)}(t,t.l_desc.max_code+1,t.d_desc.max_code+1,_+1),J(t,t.dyn_ltree,t.dyn_dtree)),q(t),i&&G(t)},a._tr_tally=function(t,e,a){return t.pending_buf[t.d_buf+2*t.last_lit]=e>>>8&255,t.pending_buf[t.d_buf+2*t.last_lit+1]=255&e,t.pending_buf[t.l_buf+t.last_lit]=255&a,t.last_lit++,0===e?t.dyn_ltree[2*a]++:(t.matches++,e--,t.dyn_ltree[2*(N[a]+u+1)]++,t.dyn_dtree[2*H(e)]++),t.last_lit===t.lit_bufsize-1},a._tr_align=function(t){K(t,d<<1,3),M(t,k,Z),function(t){16===t.bi_valid?(j(t,t.bi_buf),t.bi_buf=0,t.bi_valid=0):t.bi_valid>=8&&(t.pending_buf[t.pending++]=255&t.bi_buf,t.bi_buf>>=8,t.bi_valid-=8)}(t)}},{"../utils/common":3}],15:[function(t,e,a){"use strict";e.exports=function(){this.input=null,this.next_in=0,this.avail_in=0,this.total_in=0,this.output=null,this.next_out=0,this.avail_out=0,this.total_out=0,this.msg="",this.state=null,this.data_type=2,this.adler=0}},{}],"/":[function(t,e,a){"use strict";var i={};(0,t("./lib/utils/common").assign)(i,t("./lib/deflate"),t("./lib/inflate"),t("./lib/zlib/constants")),e.exports=i},{"./lib/deflate":1,"./lib/inflate":2,"./lib/utils/common":3,"./lib/zlib/constants":6}]},{},[])("/")});
|
139 |
+
</script>
|
140 |
+
<script>
|
141 |
+
!function(){var e={};"object"==typeof module?module.exports=e:window.UPNG=e,function(e,r){e.toRGBA8=function(r){var t=r.width,n=r.height;if(null==r.tabs.acTL)return[e.toRGBA8.decodeImage(r.data,t,n,r).buffer];var i=[];null==r.frames[0].data&&(r.frames[0].data=r.data);for(var a,f=new Uint8Array(t*n*4),o=0;o<r.frames.length;o++){var s=r.frames[o],l=s.rect.x,c=s.rect.y,u=s.rect.width,d=s.rect.height,h=e.toRGBA8.decodeImage(s.data,u,d,r);if(0==o?a=h:0==s.blend?e._copyTile(h,u,d,a,t,n,l,c,0):1==s.blend&&e._copyTile(h,u,d,a,t,n,l,c,1),i.push(a.buffer),a=a.slice(0),0==s.dispose);else if(1==s.dispose)e._copyTile(f,u,d,a,t,n,l,c,0);else if(2==s.dispose){for(var v=o-1;2==r.frames[v].dispose;)v--;a=new Uint8Array(i[v]).slice(0)}}return i},e.toRGBA8.decodeImage=function(r,t,n,i){var a=t*n,f=e.decode._getBPP(i),o=Math.ceil(t*f/8),s=new Uint8Array(4*a),l=new Uint32Array(s.buffer),c=i.ctype,u=i.depth,d=e._bin.readUshort;if(6==c){var h=a<<2;if(8==u)for(var v=0;v<h;v++)s[v]=r[v];if(16==u)for(v=0;v<h;v++)s[v]=r[v<<1]}else if(2==c){var p=i.tabs.tRNS,b=-1,g=-1,m=-1;if(p&&(b=p[0],g=p[1],m=p[2]),8==u)for(v=0;v<a;v++){var y=3*v;s[M=v<<2]=r[y],s[M+1]=r[y+1],s[M+2]=r[y+2],s[M+3]=255,-1!=b&&r[y]==b&&r[y+1]==g&&r[y+2]==m&&(s[M+3]=0)}if(16==u)for(v=0;v<a;v++){y=6*v;s[M=v<<2]=r[y],s[M+1]=r[y+2],s[M+2]=r[y+4],s[M+3]=255,-1!=b&&d(r,y)==b&&d(r,y+2)==g&&d(r,y+4)==m&&(s[M+3]=0)}}else if(3==c){var w=i.tabs.PLTE,A=i.tabs.tRNS,U=A?A.length:0;if(1==u)for(var _=0;_<n;_++){var q=_*o,I=_*t;for(v=0;v<t;v++){var M=I+v<<2,T=3*(z=r[q+(v>>3)]>>7-((7&v)<<0)&1);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}}if(2==u)for(_=0;_<n;_++)for(q=_*o,I=_*t,v=0;v<t;v++){M=I+v<<2,T=3*(z=r[q+(v>>2)]>>6-((3&v)<<1)&3);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}if(4==u)for(_=0;_<n;_++)for(q=_*o,I=_*t,v=0;v<t;v++){M=I+v<<2,T=3*(z=r[q+(v>>1)]>>4-((1&v)<<2)&15);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}if(8==u)for(v=0;v<a;v++){var z;M=v<<2,T=3*(z=r[v]);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}}else if(4==c){if(8==u)for(v=0;v<a;v++){M=v<<2;var R=r[N=v<<1];s[M]=R,s[M+1]=R,s[M+2]=R,s[M+3]=r[N+1]}if(16==u)for(v=0;v<a;v++){var N;M=v<<2,R=r[N=v<<2];s[M]=R,s[M+1]=R,s[M+2]=R,s[M+3]=r[N+2]}}else if(0==c){b=i.tabs.tRNS?i.tabs.tRNS:-1;if(1==u)for(v=0;v<a;v++){var L=(R=255*(r[v>>3]>>7-(7&v)&1))==255*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(2==u)for(v=0;v<a;v++){L=(R=85*(r[v>>2]>>6-((3&v)<<1)&3))==85*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(4==u)for(v=0;v<a;v++){L=(R=17*(r[v>>1]>>4-((1&v)<<2)&15))==17*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(8==u)for(v=0;v<a;v++){L=(R=r[v])==b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(16==u)for(v=0;v<a;v++){R=r[v<<1],L=d(r,v<<1)==b?0:255;l[v]=L<<24|R<<16|R<<8|R}}return s},e.decode=function(r){for(var t,n=new Uint8Array(r),i=8,a=e._bin,f=a.readUshort,o=a.readUint,s={tabs:{},frames:[]},l=new Uint8Array(n.length),c=0,u=0,d=[137,80,78,71,13,10,26,10],h=0;h<8;h++)if(n[h]!=d[h])throw"The input is not a PNG file!";for(;i<n.length;){var v=a.readUint(n,i);i+=4;var p=a.readASCII(n,i,4);if(i+=4,"IHDR"==p)e.decode._IHDR(n,i,s);else if("IDAT"==p){for(h=0;h<v;h++)l[c+h]=n[i+h];c+=v}else if("acTL"==p)s.tabs[p]={num_frames:o(n,i),num_plays:o(n,i+4)},t=new Uint8Array(n.length);else if("fcTL"==p){var b;if(0!=u)(b=s.frames[s.frames.length-1]).data=e.decode._decompress(s,t.slice(0,u),b.rect.width,b.rect.height),u=0;var g={x:o(n,i+12),y:o(n,i+16),width:o(n,i+4),height:o(n,i+8)},m=f(n,i+22);m=f(n,i+20)/(0==m?100:m);var y={rect:g,delay:Math.round(1e3*m),dispose:n[i+24],blend:n[i+25]};s.frames.push(y)}else if("fdAT"==p){for(h=0;h<v-4;h++)t[u+h]=n[i+h+4];u+=v-4}else if("pHYs"==p)s.tabs[p]=[a.readUint(n,i),a.readUint(n,i+4),n[i+8]];else if("cHRM"==p){s.tabs[p]=[];for(h=0;h<8;h++)s.tabs[p].push(a.readUint(n,i+4*h))}else if("tEXt"==p){null==s.tabs[p]&&(s.tabs[p]={});var w=a.nextZero(n,i),A=a.readASCII(n,i,w-i),U=a.readASCII(n,w+1,i+v-w-1);s.tabs[p][A]=U}else if("iTXt"==p){null==s.tabs[p]&&(s.tabs[p]={});w=0;var _=i;w=a.nextZero(n,_);A=a.readASCII(n,_,w-_),n[_=w+1],n[_+1];_+=2,w=a.nextZero(n,_);a.readASCII(n,_,w-_);_=w+1,w=a.nextZero(n,_);a.readUTF8(n,_,w-_);_=w+1;U=a.readUTF8(n,_,v-(_-i));s.tabs[p][A]=U}else if("PLTE"==p)s.tabs[p]=a.readBytes(n,i,v);else if("hIST"==p){var q=s.tabs.PLTE.length/3;s.tabs[p]=[];for(h=0;h<q;h++)s.tabs[p].push(f(n,i+2*h))}else if("tRNS"==p)3==s.ctype?s.tabs[p]=a.readBytes(n,i,v):0==s.ctype?s.tabs[p]=f(n,i):2==s.ctype&&(s.tabs[p]=[f(n,i),f(n,i+2),f(n,i+4)]);else if("gAMA"==p)s.tabs[p]=a.readUint(n,i)/1e5;else if("sRGB"==p)s.tabs[p]=n[i];else if("bKGD"==p)0==s.ctype||4==s.ctype?s.tabs[p]=[f(n,i)]:2==s.ctype||6==s.ctype?s.tabs[p]=[f(n,i),f(n,i+2),f(n,i+4)]:3==s.ctype&&(s.tabs[p]=n[i]);else if("IEND"==p)break;i+=v;a.readUint(n,i);i+=4}0!=u&&((b=s.frames[s.frames.length-1]).data=e.decode._decompress(s,t.slice(0,u),b.rect.width,b.rect.height),u=0);return s.data=e.decode._decompress(s,l,s.width,s.height),delete s.compress,delete s.interlace,delete s.filter,s},e.decode._decompress=function(r,t,n,i){return 0==r.compress&&(t=e.decode._inflate(t)),0==r.interlace?t=e.decode._filterZero(t,r,0,n,i):1==r.interlace&&(t=e.decode._readInterlace(t,r)),t},e.decode._inflate=function(e){return r.inflate(e)},e.decode._readInterlace=function(r,t){for(var n=t.width,i=t.height,a=e.decode._getBPP(t),f=a>>3,o=Math.ceil(n*a/8),s=new Uint8Array(i*o),l=0,c=[0,0,4,0,2,0,1],u=[0,4,0,2,0,1,0],d=[8,8,8,4,4,2,2],h=[8,8,4,4,2,2,1],v=0;v<7;){for(var p=d[v],b=h[v],g=0,m=0,y=c[v];y<i;)y+=p,m++;for(var w=u[v];w<n;)w+=b,g++;var A=Math.ceil(g*a/8);e.decode._filterZero(r,t,l,g,m);for(var U=0,_=c[v];_<i;){for(var q=u[v],I=l+U*A<<3;q<n;){var M;if(1==a)M=(M=r[I>>3])>>7-(7&I)&1,s[_*o+(q>>3)]|=M<<7-((3&q)<<0);if(2==a)M=(M=r[I>>3])>>6-(7&I)&3,s[_*o+(q>>2)]|=M<<6-((3&q)<<1);if(4==a)M=(M=r[I>>3])>>4-(7&I)&15,s[_*o+(q>>1)]|=M<<4-((1&q)<<2);if(a>=8)for(var T=_*o+q*f,z=0;z<f;z++)s[T+z]=r[(I>>3)+z];I+=a,q+=b}U++,_+=p}g*m!=0&&(l+=m*(1+A)),v+=1}return s},e.decode._getBPP=function(e){return[1,null,3,1,2,null,4][e.ctype]*e.depth},e.decode._filterZero=function(r,t,n,i,a){var f=e.decode._getBPP(t),o=Math.ceil(i*f/8),s=e.decode._paeth;f=Math.ceil(f/8);for(var l=0;l<a;l++){var c=n+l*o,u=c+l+1,d=r[u-1];if(0==d)for(var h=0;h<o;h++)r[c+h]=r[u+h];else if(1==d){for(h=0;h<f;h++)r[c+h]=r[u+h];for(h=f;h<o;h++)r[c+h]=r[u+h]+r[c+h-f]&255}else if(0==l){for(h=0;h<f;h++)r[c+h]=r[u+h];if(2==d)for(h=f;h<o;h++)r[c+h]=255&r[u+h];if(3==d)for(h=f;h<o;h++)r[c+h]=r[u+h]+(r[c+h-f]>>1)&255;if(4==d)for(h=f;h<o;h++)r[c+h]=r[u+h]+s(r[c+h-f],0,0)&255}else{if(2==d)for(h=0;h<o;h++)r[c+h]=r[u+h]+r[c+h-o]&255;if(3==d){for(h=0;h<f;h++)r[c+h]=r[u+h]+(r[c+h-o]>>1)&255;for(h=f;h<o;h++)r[c+h]=r[u+h]+(r[c+h-o]+r[c+h-f]>>1)&255}if(4==d){for(h=0;h<f;h++)r[c+h]=r[u+h]+s(0,r[c+h-o],0)&255;for(h=f;h<o;h++)r[c+h]=r[u+h]+s(r[c+h-f],r[c+h-o],r[c+h-f-o])&255}}}return r},e.decode._paeth=function(e,r,t){var n=e+r-t,i=Math.abs(n-e),a=Math.abs(n-r),f=Math.abs(n-t);return i<=a&&i<=f?e:a<=f?r:t},e.decode._IHDR=function(r,t,n){var i=e._bin;n.width=i.readUint(r,t),t+=4,n.height=i.readUint(r,t),t+=4,n.depth=r[t],t++,n.ctype=r[t],t++,n.compress=r[t],t++,n.filter=r[t],t++,n.interlace=r[t],t++},e._bin={nextZero:function(e,r){for(;0!=e[r];)r++;return r},readUshort:function(e,r){return e[r]<<8|e[r+1]},writeUshort:function(e,r,t){e[r]=t>>8&255,e[r+1]=255&t},readUint:function(e,r){return 16777216*e[r]+(e[r+1]<<16|e[r+2]<<8|e[r+3])},writeUint:function(e,r,t){e[r]=t>>24&255,e[r+1]=t>>16&255,e[r+2]=t>>8&255,e[r+3]=255&t},readASCII:function(e,r,t){for(var n="",i=0;i<t;i++)n+=String.fromCharCode(e[r+i]);return n},writeASCII:function(e,r,t){for(var n=0;n<t.length;n++)e[r+n]=t.charCodeAt(n)},readBytes:function(e,r,t){for(var n=[],i=0;i<t;i++)n.push(e[r+i]);return n},pad:function(e){return e.length<2?"0"+e:e},readUTF8:function(r,t,n){for(var i,a="",f=0;f<n;f++)a+="%"+e._bin.pad(r[t+f].toString(16));try{i=decodeURIComponent(a)}catch(i){return e._bin.readASCII(r,t,n)}return i}},e._copyTile=function(e,r,t,n,i,a,f,o,s){for(var l=Math.min(r,i),c=Math.min(t,a),u=0,d=0,h=0;h<c;h++)for(var v=0;v<l;v++)if(f>=0&&o>=0?(u=h*r+v<<2,d=(o+h)*i+f+v<<2):(u=(-o+h)*r-f+v<<2,d=h*i+v<<2),0==s)n[d]=e[u],n[d+1]=e[u+1],n[d+2]=e[u+2],n[d+3]=e[u+3];else if(1==s){var p=e[u+3]*(1/255),b=e[u]*p,g=e[u+1]*p,m=e[u+2]*p,y=n[d+3]*(1/255),w=n[d]*y,A=n[d+1]*y,U=n[d+2]*y,_=1-p,q=p+y*_,I=0==q?0:1/q;n[d+3]=255*q,n[d+0]=(b+w*_)*I,n[d+1]=(g+A*_)*I,n[d+2]=(m+U*_)*I}else if(2==s){p=e[u+3],b=e[u],g=e[u+1],m=e[u+2],y=n[d+3],w=n[d],A=n[d+1],U=n[d+2];p==y&&b==w&&g==A&&m==U?(n[d]=0,n[d+1]=0,n[d+2]=0,n[d+3]=0):(n[d]=b,n[d+1]=g,n[d+2]=m,n[d+3]=p)}else if(3==s){p=e[u+3],b=e[u],g=e[u+1],m=e[u+2],y=n[d+3],w=n[d],A=n[d+1],U=n[d+2];if(p==y&&b==w&&g==A&&m==U)continue;if(p<220&&y>20)return!1}return!0},e.encode=function(r,t,n,i,a,f){null==i&&(i=0),null==f&&(f=!1);var o=e.encode.compress(r,t,n,i,!1,f);return e.encode.compressPNG(o,-1),e.encode._main(o,t,n,a)},e.encodeLL=function(r,t,n,i,a,f,o){for(var s={ctype:0+(1==i?0:2)+(0==a?0:4),depth:f,frames:[]},l=(i+a)*f,c=l*t,u=0;u<r.length;u++)s.frames.push({rect:{x:0,y:0,width:t,height:n},img:new Uint8Array(r[u]),blend:0,dispose:1,bpp:Math.ceil(l/8),bpl:Math.ceil(c/8)});return e.encode.compressPNG(s,4),e.encode._main(s,t,n,o)},e.encode._main=function(r,t,n,i){var a=e.crc.crc,f=e._bin.writeUint,o=e._bin.writeUshort,s=e._bin.writeASCII,l=8,c=r.frames.length>1,u=!1,d=46+(c?20:0);if(3==r.ctype){for(var h=r.plte.length,v=0;v<h;v++)r.plte[v]>>>24!=255&&(u=!0);d+=8+3*h+4+(u?8+1*h+4:0)}for(var p=0;p<r.frames.length;p++){c&&(d+=38),d+=(q=r.frames[p]).cimg.length+12,0!=p&&(d+=4)}d+=12;var b=new Uint8Array(d),g=[137,80,78,71,13,10,26,10];for(v=0;v<8;v++)b[v]=g[v];if(f(b,l,13),s(b,l+=4,"IHDR"),f(b,l+=4,t),f(b,l+=4,n),b[l+=4]=r.depth,b[++l]=r.ctype,b[++l]=0,b[++l]=0,b[++l]=0,f(b,++l,a(b,l-17,17)),f(b,l+=4,1),s(b,l+=4,"sRGB"),b[l+=4]=1,f(b,++l,a(b,l-5,5)),l+=4,c&&(f(b,l,8),s(b,l+=4,"acTL"),f(b,l+=4,r.frames.length),f(b,l+=4,0),f(b,l+=4,a(b,l-12,12)),l+=4),3==r.ctype){f(b,l,3*(h=r.plte.length)),s(b,l+=4,"PLTE"),l+=4;for(v=0;v<h;v++){var m=3*v,y=r.plte[v],w=255&y,A=y>>>8&255,U=y>>>16&255;b[l+m+0]=w,b[l+m+1]=A,b[l+m+2]=U}if(f(b,l+=3*h,a(b,l-3*h-4,3*h+4)),l+=4,u){f(b,l,h),s(b,l+=4,"tRNS"),l+=4;for(v=0;v<h;v++)b[l+v]=r.plte[v]>>>24&255;f(b,l+=h,a(b,l-h-4,h+4)),l+=4}}var _=0;for(p=0;p<r.frames.length;p++){var q=r.frames[p];c&&(f(b,l,26),s(b,l+=4,"fcTL"),f(b,l+=4,_++),f(b,l+=4,q.rect.width),f(b,l+=4,q.rect.height),f(b,l+=4,q.rect.x),f(b,l+=4,q.rect.y),o(b,l+=4,i[p]),o(b,l+=2,1e3),b[l+=2]=q.dispose,b[++l]=q.blend,f(b,++l,a(b,l-30,30)),l+=4);var I=q.cimg;f(b,l,(h=I.length)+(0==p?0:4));var M=l+=4;s(b,l,0==p?"IDAT":"fdAT"),l+=4,0!=p&&(f(b,l,_++),l+=4);for(v=0;v<h;v++)b[l+v]=I[v];f(b,l+=h,a(b,M,l-M)),l+=4}return f(b,l,0),s(b,l+=4,"IEND"),f(b,l+=4,a(b,l-4,4)),l+=4,b.buffer},e.encode.compressPNG=function(r,t){for(var n=0;n<r.frames.length;n++){var i=r.frames[n],a=(i.rect.width,i.rect.height),f=new Uint8Array(a*i.bpl+a);i.cimg=e.encode._filterZero(i.img,a,i.bpp,i.bpl,f,t)}},e.encode.compress=function(r,t,n,i,a,f){null==f&&(f=!1);for(var o=6,s=8,l=255,c=0;c<r.length;c++)for(var u=new Uint8Array(r[c]),d=u.length,h=0;h<d;h+=4)l&=u[h+3];var v=255!=l,p=v&&a,b=e.encode.framize(r,t,n,a,p),g={},m=[],y=[];if(0!=i){var w=[];for(h=0;h<b.length;h++)w.push(b[h].img.buffer);var A=e.encode.concatRGBA(w,a),U=e.quantize(A,i),_=0,q=new Uint8Array(U.abuf);for(h=0;h<b.length;h++){var I=(F=b[h].img).length;y.push(new Uint8Array(U.inds.buffer,_>>2,I>>2));for(c=0;c<I;c+=4)F[c]=q[_+c],F[c+1]=q[_+c+1],F[c+2]=q[_+c+2],F[c+3]=q[_+c+3];_+=I}for(h=0;h<U.plte.length;h++)m.push(U.plte[h].est.rgba)}else for(c=0;c<b.length;c++){var M=b[c],T=new Uint32Array(M.img.buffer),z=M.rect.width,R=(d=T.length,new Uint8Array(d));y.push(R);for(h=0;h<d;h++){var N=T[h];if(0!=h&&N==T[h-1])R[h]=R[h-1];else if(h>z&&N==T[h-z])R[h]=R[h-z];else{var L=g[N];if(null==L&&(g[N]=L=m.length,m.push(N),m.length>=300))break;R[h]=L}}}var P=m.length;P<=256&&0==f&&(s=P<=2?1:P<=4?2:P<=16?4:8,a&&(s=8));for(c=0;c<b.length;c++){(M=b[c]).rect.x,M.rect.y,z=M.rect.width;var S=M.rect.height,D=M.img,B=(new Uint32Array(D.buffer),4*z),x=4;if(P<=256&&0==f){B=Math.ceil(s*z/8);for(var C=new Uint8Array(B*S),G=y[c],Z=0;Z<S;Z++){h=Z*B;var k=Z*z;if(8==s)for(var E=0;E<z;E++)C[h+E]=G[k+E];else if(4==s)for(E=0;E<z;E++)C[h+(E>>1)]|=G[k+E]<<4-4*(1&E);else if(2==s)for(E=0;E<z;E++)C[h+(E>>2)]|=G[k+E]<<6-2*(3&E);else if(1==s)for(E=0;E<z;E++)C[h+(E>>3)]|=G[k+E]<<7-1*(7&E)}D=C,o=3,x=1}else if(0==v&&1==b.length){C=new Uint8Array(z*S*3);var H=z*S;for(h=0;h<H;h++){var F,K=4*h;C[F=3*h]=D[K],C[F+1]=D[K+1],C[F+2]=D[K+2]}D=C,o=2,x=3,B=3*z}M.img=D,M.bpl=B,M.bpp=x}return{ctype:o,depth:s,plte:m,frames:b}},e.encode.framize=function(r,t,n,i,a){for(var f=[],o=0;o<r.length;o++){var s=new Uint8Array(r[o]),l=new Uint32Array(s.buffer),c=0,u=0,d=t,h=n,v=0;if(0==o||a)s=s.slice(0);else{for(var p=i||1==o||2==f[f.length-2].dispose?1:2,b=0,g=1e9,m=0;m<p;m++){for(var y=new Uint8Array(r[o-1-m]),w=new Uint32Array(r[o-1-m]),A=t,U=n,_=-1,q=-1,I=0;I<n;I++)for(var M=0;M<t;M++){var T=I*t+M;l[T]!=w[T]&&(M<A&&(A=M),M>_&&(_=M),I<U&&(U=I),I>q&&(q=I))}var z=-1==_?1:(_-A+1)*(q-U+1);z<g&&(g=z,b=m,-1==_?(c=u=0,d=h=1):(c=A,u=U,d=_-A+1,h=q-U+1))}y=new Uint8Array(r[o-1-b]);1==b&&(f[f.length-1].dispose=2);var R=new Uint8Array(d*h*4);new Uint32Array(R.buffer);e._copyTile(y,t,n,R,d,h,-c,-u,0),e._copyTile(s,t,n,R,d,h,-c,-u,3)?(e._copyTile(s,t,n,R,d,h,-c,-u,2),v=1):(e._copyTile(s,t,n,R,d,h,-c,-u,0),v=0),s=R}f.push({rect:{x:c,y:u,width:d,height:h},img:s,blend:v,dispose:a?1:0})}return f},e.encode._filterZero=function(t,n,i,a,f,o){if(-1!=o){for(var s=0;s<n;s++)e.encode._filterLine(f,t,s,a,i,o);return r.deflate(f)}for(var l=[],c=0;c<5;c++)if(!(n*a>5e5)||2!=c&&3!=c&&4!=c){for(s=0;s<n;s++)e.encode._filterLine(f,t,s,a,i,c);if(l.push(r.deflate(f)),1==i)break}for(var u,d=1e9,h=0;h<l.length;h++)l[h].length<d&&(u=h,d=l[h].length);return l[u]},e.encode._filterLine=function(r,t,n,i,a,f){var o=n*i,s=o+n,l=e.decode._paeth;if(r[s]=f,s++,0==f)for(var c=0;c<i;c++)r[s+c]=t[o+c];else if(1==f){for(c=0;c<a;c++)r[s+c]=t[o+c];for(c=a;c<i;c++)r[s+c]=t[o+c]-t[o+c-a]+256&255}else if(0==n){for(c=0;c<a;c++)r[s+c]=t[o+c];if(2==f)for(c=a;c<i;c++)r[s+c]=t[o+c];if(3==f)for(c=a;c<i;c++)r[s+c]=t[o+c]-(t[o+c-a]>>1)+256&255;if(4==f)for(c=a;c<i;c++)r[s+c]=t[o+c]-l(t[o+c-a],0,0)+256&255}else{if(2==f)for(c=0;c<i;c++)r[s+c]=t[o+c]+256-t[o+c-i]&255;if(3==f){for(c=0;c<a;c++)r[s+c]=t[o+c]+256-(t[o+c-i]>>1)&255;for(c=a;c<i;c++)r[s+c]=t[o+c]+256-(t[o+c-i]+t[o+c-a]>>1)&255}if(4==f){for(c=0;c<a;c++)r[s+c]=t[o+c]+256-l(0,t[o+c-i],0)&255;for(c=a;c<i;c++)r[s+c]=t[o+c]+256-l(t[o+c-a],t[o+c-i],t[o+c-a-i])&255}}},e.crc={table:function(){for(var e=new Uint32Array(256),r=0;r<256;r++){for(var t=r,n=0;n<8;n++)1&t?t=3988292384^t>>>1:t>>>=1;e[r]=t}return e}(),update:function(r,t,n,i){for(var a=0;a<i;a++)r=e.crc.table[255&(r^t[n+a])]^r>>>8;return r},crc:function(r,t,n){return 4294967295^e.crc.update(4294967295,r,t,n)}},e.quantize=function(r,t){for(var n=new Uint8Array(r),i=n.slice(0),a=new Uint32Array(i.buffer),f=e.quantize.getKDtree(i,t),o=f[0],s=f[1],l=(e.quantize.planeDst,n),c=a,u=l.length,d=new Uint8Array(n.length>>2),h=0;h<u;h+=4){var v=l[h]*(1/255),p=l[h+1]*(1/255),b=l[h+2]*(1/255),g=l[h+3]*(1/255),m=e.quantize.getNearest(o,v,p,b,g);d[h>>2]=m.ind,c[h>>2]=m.est.rgba}return{abuf:i.buffer,inds:d,plte:s}},e.quantize.getKDtree=function(r,t,n){null==n&&(n=1e-4);var i=new Uint32Array(r.buffer),a={i0:0,i1:r.length,bst:null,est:null,tdst:0,left:null,right:null};a.bst=e.quantize.stats(r,a.i0,a.i1),a.est=e.quantize.estats(a.bst);for(var f=[a];f.length<t;){for(var o=0,s=0,l=0;l<f.length;l++)f[l].est.L>o&&(o=f[l].est.L,s=l);if(o<n)break;var c=f[s],u=e.quantize.splitPixels(r,i,c.i0,c.i1,c.est.e,c.est.eMq255);if(c.i0>=u||c.i1<=u)c.est.L=0;else{var d={i0:c.i0,i1:u,bst:null,est:null,tdst:0,left:null,right:null};d.bst=e.quantize.stats(r,d.i0,d.i1),d.est=e.quantize.estats(d.bst);var h={i0:u,i1:c.i1,bst:null,est:null,tdst:0,left:null,right:null};h.bst={R:[],m:[],N:c.bst.N-d.bst.N};for(l=0;l<16;l++)h.bst.R[l]=c.bst.R[l]-d.bst.R[l];for(l=0;l<4;l++)h.bst.m[l]=c.bst.m[l]-d.bst.m[l];h.est=e.quantize.estats(h.bst),c.left=d,c.right=h,f[s]=d,f.push(h)}}f.sort(function(e,r){return r.bst.N-e.bst.N});for(l=0;l<f.length;l++)f[l].ind=l;return[a,f]},e.quantize.getNearest=function(r,t,n,i,a){if(null==r.left)return r.tdst=e.quantize.dist(r.est.q,t,n,i,a),r;var f=e.quantize.planeDst(r.est,t,n,i,a),o=r.left,s=r.right;f>0&&(o=r.right,s=r.left);var l=e.quantize.getNearest(o,t,n,i,a);if(l.tdst<=f*f)return l;var c=e.quantize.getNearest(s,t,n,i,a);return c.tdst<l.tdst?c:l},e.quantize.planeDst=function(e,r,t,n,i){var a=e.e;return a[0]*r+a[1]*t+a[2]*n+a[3]*i-e.eMq},e.quantize.dist=function(e,r,t,n,i){var a=r-e[0],f=t-e[1],o=n-e[2],s=i-e[3];return a*a+f*f+o*o+s*s},e.quantize.splitPixels=function(r,t,n,i,a,f){var o=e.quantize.vecDot;i-=4;for(;n<i;){for(;o(r,n,a)<=f;)n+=4;for(;o(r,i,a)>f;)i-=4;if(n>=i)break;var s=t[n>>2];t[n>>2]=t[i>>2],t[i>>2]=s,n+=4,i-=4}for(;o(r,n,a)>f;)n-=4;return n+4},e.quantize.vecDot=function(e,r,t){return e[r]*t[0]+e[r+1]*t[1]+e[r+2]*t[2]+e[r+3]*t[3]},e.quantize.stats=function(e,r,t){for(var n=[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],i=[0,0,0,0],a=t-r>>2,f=r;f<t;f+=4){var o=e[f]*(1/255),s=e[f+1]*(1/255),l=e[f+2]*(1/255),c=e[f+3]*(1/255);i[0]+=o,i[1]+=s,i[2]+=l,i[3]+=c,n[0]+=o*o,n[1]+=o*s,n[2]+=o*l,n[3]+=o*c,n[5]+=s*s,n[6]+=s*l,n[7]+=s*c,n[10]+=l*l,n[11]+=l*c,n[15]+=c*c}return n[4]=n[1],n[8]=n[2],n[9]=n[6],n[12]=n[3],n[13]=n[7],n[14]=n[11],{R:n,m:i,N:a}},e.quantize.estats=function(r){var t=r.R,n=r.m,i=r.N,a=n[0],f=n[1],o=n[2],s=n[3],l=0==i?0:1/i,c=[t[0]-a*a*l,t[1]-a*f*l,t[2]-a*o*l,t[3]-a*s*l,t[4]-f*a*l,t[5]-f*f*l,t[6]-f*o*l,t[7]-f*s*l,t[8]-o*a*l,t[9]-o*f*l,t[10]-o*o*l,t[11]-o*s*l,t[12]-s*a*l,t[13]-s*f*l,t[14]-s*o*l,t[15]-s*s*l],u=c,d=e.M4,h=[.5,.5,.5,.5],v=0,p=0;if(0!=i)for(var b=0;b<10&&(h=d.multVec(u,h),p=Math.sqrt(d.dot(h,h)),h=d.sml(1/p,h),!(Math.abs(p-v)<1e-9));b++)v=p;var g=[a*l,f*l,o*l,s*l];return{Cov:c,q:g,e:h,L:v,eMq255:d.dot(d.sml(255,g),h),eMq:d.dot(h,g),rgba:(Math.round(255*g[3])<<24|Math.round(255*g[2])<<16|Math.round(255*g[1])<<8|Math.round(255*g[0])<<0)>>>0}},e.M4={multVec:function(e,r){return[e[0]*r[0]+e[1]*r[1]+e[2]*r[2]+e[3]*r[3],e[4]*r[0]+e[5]*r[1]+e[6]*r[2]+e[7]*r[3],e[8]*r[0]+e[9]*r[1]+e[10]*r[2]+e[11]*r[3],e[12]*r[0]+e[13]*r[1]+e[14]*r[2]+e[15]*r[3]]},dot:function(e,r){return e[0]*r[0]+e[1]*r[1]+e[2]*r[2]+e[3]*r[3]},sml:function(e,r){return[e*r[0],e*r[1],e*r[2],e*r[3]]}},e.encode.concatRGBA=function(e,r){for(var t=0,n=0;n<e.length;n++)t+=e[n].byteLength;var i=new Uint8Array(t),a=0;for(n=0;n<e.length;n++){for(var f=new Uint8Array(e[n]),o=f.length,s=0;s<o;s+=4){var l=f[s],c=f[s+1],u=f[s+2],d=f[s+3];r&&(d=0==(128&d)?0:255),0==d&&(l=c=u=0),i[a+s]=l,i[a+s+1]=c,i[a+s+2]=u,i[a+s+3]=d}a+=o}return i.buffer}}(e,"function"==typeof require?require("pako"):window.pako)}();
|
142 |
+
</script>
|
143 |
+
|
144 |
+
<script>
|
145 |
+
class Player {
|
146 |
+
|
147 |
+
constructor(container) {
|
148 |
+
this.container = container
|
149 |
+
this.global_frac = 0.0
|
150 |
+
this.container = document.getElementById(container)
|
151 |
+
this.progress = null;
|
152 |
+
this.mat = [[]]
|
153 |
+
|
154 |
+
this.player = this.container.querySelector('audio')
|
155 |
+
this.demo_img = this.container.querySelector('.underlay > img')
|
156 |
+
this.overlay = this.container.querySelector('.overlay')
|
157 |
+
this.playpause = this.container.querySelector(".playpause");
|
158 |
+
this.download = this.container.querySelector(".download");
|
159 |
+
this.play_img = this.container.querySelector('.play-img')
|
160 |
+
this.pause_img = this.container.querySelector('.pause-img')
|
161 |
+
this.canvas = this.container.querySelector('.response-canvas')
|
162 |
+
this.response_container = this.container.querySelector('.response')
|
163 |
+
this.context = this.canvas.getContext('2d');
|
164 |
+
|
165 |
+
// console.log(this.player.duration)
|
166 |
+
var togglePlayPause = () => {
|
167 |
+
if (this.player.networkState !== 1) {
|
168 |
+
return
|
169 |
+
}
|
170 |
+
if (this.player.paused || this.player.ended) {
|
171 |
+
this.play()
|
172 |
+
} else {
|
173 |
+
this.pause()
|
174 |
+
}
|
175 |
+
}
|
176 |
+
|
177 |
+
this.update = () => {
|
178 |
+
this.global_frac = this.player.currentTime / this.player.duration
|
179 |
+
// this.global_frac = frac
|
180 |
+
// console.log(this.player.currentTime, this.player.duration, this.global_frac)
|
181 |
+
this.overlay.style.width = (100*(1.0 - this.global_frac)).toString() + '%'
|
182 |
+
this.redraw()
|
183 |
+
}
|
184 |
+
|
185 |
+
// var start = null;
|
186 |
+
this.updateLoop = (timestamp) => {
|
187 |
+
// if (!start) start = timestamp;
|
188 |
+
// var progress = timestamp - start;
|
189 |
+
this.update()
|
190 |
+
// this.progress = setTimeout(this.updateLoop, 10)
|
191 |
+
this.progress = window.requestAnimationFrame(this.updateLoop)
|
192 |
+
}
|
193 |
+
|
194 |
+
this.seek = (e) => {
|
195 |
+
this.global_frac = e.offsetX / this.demo_img.width
|
196 |
+
this.player.currentTime = this.global_frac * this.player.duration
|
197 |
+
// console.log(this.global_frac)
|
198 |
+
this.overlay.style.width = (100*(1.0 - this.global_frac)).toString() + '%'
|
199 |
+
this.redraw()
|
200 |
+
}
|
201 |
+
|
202 |
+
var download_audio = () => {
|
203 |
+
var url = this.player.querySelector('#src').src
|
204 |
+
const a = document.createElement('a')
|
205 |
+
a.href = url
|
206 |
+
a.download = "download"
|
207 |
+
document.body.appendChild(a)
|
208 |
+
a.click()
|
209 |
+
document.body.removeChild(a)
|
210 |
+
}
|
211 |
+
|
212 |
+
this.demo_img.onclick = this.seek;
|
213 |
+
this.playpause.disabled = true
|
214 |
+
this.player.onplay = this.updateLoop
|
215 |
+
this.player.onpause = () => {
|
216 |
+
window.cancelAnimationFrame(this.progress)
|
217 |
+
this.update();
|
218 |
+
}
|
219 |
+
this.player.onended = () => {this.pause()}
|
220 |
+
this.playpause.onclick = togglePlayPause;
|
221 |
+
this.download.onclick = download_audio;
|
222 |
+
}
|
223 |
+
|
224 |
+
load(audio_fname, img_fname, levels_fname) {
|
225 |
+
this.pause()
|
226 |
+
window.cancelAnimationFrame(this.progress)
|
227 |
+
this.playpause.disabled = true
|
228 |
+
|
229 |
+
this.player.querySelector('#src').setAttribute("src", audio_fname)
|
230 |
+
this.player.load()
|
231 |
+
this.demo_img.setAttribute("src", img_fname)
|
232 |
+
this.overlay.style.width = '0%'
|
233 |
+
|
234 |
+
fetch(levels_fname)
|
235 |
+
.then(response => response.arrayBuffer())
|
236 |
+
.then(text => {
|
237 |
+
this.mat = this.parse(text);
|
238 |
+
this.playpause.disabled = false;
|
239 |
+
this.redraw();
|
240 |
+
})
|
241 |
+
}
|
242 |
+
|
243 |
+
parse(buffer) {
|
244 |
+
var img = UPNG.decode(buffer)
|
245 |
+
var dat = UPNG.toRGBA8(img)[0]
|
246 |
+
var view = new DataView(dat)
|
247 |
+
var data = new Array(img.width).fill(0).map(() => new Array(img.height).fill(0));
|
248 |
+
|
249 |
+
var min =100
|
250 |
+
var max = -100
|
251 |
+
var idx = 0
|
252 |
+
for (let i=0; i < img.height*img.width*4; i+=4) {
|
253 |
+
var rgba = [view.getUint8(i, 1) / 255, view.getUint8(i + 1, 1) / 255, view.getUint8(i + 2, 1) / 255, view.getUint8(i + 3, 1) / 255]
|
254 |
+
var norm = Math.pow(Math.pow(rgba[0], 2) + Math.pow(rgba[1], 2) + Math.pow(rgba[2], 2), 0.5)
|
255 |
+
data[idx % img.width][img.height - Math.floor(idx / img.width) - 1] = norm
|
256 |
+
|
257 |
+
idx += 1
|
258 |
+
min = Math.min(min, norm)
|
259 |
+
max = Math.max(max, norm)
|
260 |
+
}
|
261 |
+
for (let i = 0; i < data.length; i++) {
|
262 |
+
for (let j = 0; j < data[i].length; j++) {
|
263 |
+
data[i][j] = Math.pow((data[i][j] - min) / (max - min), 1.5)
|
264 |
+
}
|
265 |
+
}
|
266 |
+
var data3 = new Array(img.width).fill(0).map(() => new Array(img.height).fill(0));
|
267 |
+
for (let i = 0; i < data.length; i++) {
|
268 |
+
for (let j = 0; j < data[i].length; j++) {
|
269 |
+
if (i == 0 || i == (data.length - 1)) {
|
270 |
+
data3[i][j] = data[i][j]
|
271 |
+
} else{
|
272 |
+
data3[i][j] = 0.33*(data[i - 1][j]) + 0.33*(data[i][j]) + 0.33*(data[i + 1][j])
|
273 |
+
// data3[i][j] = 0.00*(data[i - 1][j]) + 1.00*(data[i][j]) + 0.00*(data[i + 1][j])
|
274 |
+
}
|
275 |
+
}
|
276 |
+
}
|
277 |
+
|
278 |
+
var scale = 5
|
279 |
+
var data2 = new Array(scale*img.width).fill(0).map(() => new Array(img.height).fill(0));
|
280 |
+
for (let j = 0; j < data[0].length; j++) {
|
281 |
+
for (let i = 0; i < data.length - 1; i++) {
|
282 |
+
for (let k = 0; k < scale; k++) {
|
283 |
+
data2[scale*i + k][j] = (1.0 - (k/scale))*data3[i][j] + (k / scale)*data3[i + 1][j]
|
284 |
+
}
|
285 |
+
}
|
286 |
+
}
|
287 |
+
return data2
|
288 |
+
}
|
289 |
+
|
290 |
+
play() {
|
291 |
+
this.player.play();
|
292 |
+
this.play_img.style.display = 'none'
|
293 |
+
this.pause_img.style.display = 'block'
|
294 |
+
}
|
295 |
+
|
296 |
+
pause() {
|
297 |
+
this.player.pause();
|
298 |
+
this.pause_img.style.display = 'none'
|
299 |
+
this.play_img.style.display = 'block'
|
300 |
+
}
|
301 |
+
|
302 |
+
redraw() {
|
303 |
+
this.canvas.width = window.devicePixelRatio*this.response_container.offsetWidth;
|
304 |
+
this.canvas.height = window.devicePixelRatio*this.response_container.offsetHeight;
|
305 |
+
|
306 |
+
this.context.clearRect(0, 0, this.canvas.width, this.canvas.height)
|
307 |
+
this.canvas.style.width = (this.canvas.width / window.devicePixelRatio).toString() + "px";
|
308 |
+
this.canvas.style.height = (this.canvas.height / window.devicePixelRatio).toString() + "px";
|
309 |
+
|
310 |
+
var f = this.global_frac*this.mat.length
|
311 |
+
var tstep = Math.min(Math.floor(f), this.mat.length - 2)
|
312 |
+
var heights = this.mat[tstep]
|
313 |
+
var bar_width = (this.canvas.width / heights.length) - 1
|
314 |
+
|
315 |
+
for (let k = 0; k < heights.length - 1; k++) {
|
316 |
+
var height = Math.max(Math.round((heights[k])*this.canvas.height), 3)
|
317 |
+
this.context.fillStyle = '#696f7b';
|
318 |
+
this.context.fillRect(k*(bar_width + 1), (this.canvas.height - height) / 2, bar_width, height);
|
319 |
+
}
|
320 |
+
}
|
321 |
+
}
|
322 |
+
</script>
|
audiotools/core/templates/pandoc.css
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
Copyright (c) 2017 Chris Patuzzo
|
3 |
+
https://twitter.com/chrispatuzzo
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
22 |
+
*/
|
23 |
+
|
24 |
+
body {
|
25 |
+
font-family: Helvetica, arial, sans-serif;
|
26 |
+
font-size: 14px;
|
27 |
+
line-height: 1.6;
|
28 |
+
padding-top: 10px;
|
29 |
+
padding-bottom: 10px;
|
30 |
+
background-color: white;
|
31 |
+
padding: 30px;
|
32 |
+
color: #333;
|
33 |
+
}
|
34 |
+
|
35 |
+
body > *:first-child {
|
36 |
+
margin-top: 0 !important;
|
37 |
+
}
|
38 |
+
|
39 |
+
body > *:last-child {
|
40 |
+
margin-bottom: 0 !important;
|
41 |
+
}
|
42 |
+
|
43 |
+
a {
|
44 |
+
color: #4183C4;
|
45 |
+
text-decoration: none;
|
46 |
+
}
|
47 |
+
|
48 |
+
a.absent {
|
49 |
+
color: #cc0000;
|
50 |
+
}
|
51 |
+
|
52 |
+
a.anchor {
|
53 |
+
display: block;
|
54 |
+
padding-left: 30px;
|
55 |
+
margin-left: -30px;
|
56 |
+
cursor: pointer;
|
57 |
+
position: absolute;
|
58 |
+
top: 0;
|
59 |
+
left: 0;
|
60 |
+
bottom: 0;
|
61 |
+
}
|
62 |
+
|
63 |
+
h1, h2, h3, h4, h5, h6 {
|
64 |
+
margin: 20px 0 10px;
|
65 |
+
padding: 0;
|
66 |
+
font-weight: bold;
|
67 |
+
-webkit-font-smoothing: antialiased;
|
68 |
+
cursor: text;
|
69 |
+
position: relative;
|
70 |
+
}
|
71 |
+
|
72 |
+
h2:first-child, h1:first-child, h1:first-child + h2, h3:first-child, h4:first-child, h5:first-child, h6:first-child {
|
73 |
+
margin-top: 0;
|
74 |
+
padding-top: 0;
|
75 |
+
}
|
76 |
+
|
77 |
+
h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor {
|
78 |
+
text-decoration: none;
|
79 |
+
}
|
80 |
+
|
81 |
+
h1 tt, h1 code {
|
82 |
+
font-size: inherit;
|
83 |
+
}
|
84 |
+
|
85 |
+
h2 tt, h2 code {
|
86 |
+
font-size: inherit;
|
87 |
+
}
|
88 |
+
|
89 |
+
h3 tt, h3 code {
|
90 |
+
font-size: inherit;
|
91 |
+
}
|
92 |
+
|
93 |
+
h4 tt, h4 code {
|
94 |
+
font-size: inherit;
|
95 |
+
}
|
96 |
+
|
97 |
+
h5 tt, h5 code {
|
98 |
+
font-size: inherit;
|
99 |
+
}
|
100 |
+
|
101 |
+
h6 tt, h6 code {
|
102 |
+
font-size: inherit;
|
103 |
+
}
|
104 |
+
|
105 |
+
h1 {
|
106 |
+
font-size: 28px;
|
107 |
+
color: black;
|
108 |
+
}
|
109 |
+
|
110 |
+
h2 {
|
111 |
+
font-size: 24px;
|
112 |
+
border-bottom: 1px solid #cccccc;
|
113 |
+
color: black;
|
114 |
+
}
|
115 |
+
|
116 |
+
h3 {
|
117 |
+
font-size: 18px;
|
118 |
+
}
|
119 |
+
|
120 |
+
h4 {
|
121 |
+
font-size: 16px;
|
122 |
+
}
|
123 |
+
|
124 |
+
h5 {
|
125 |
+
font-size: 14px;
|
126 |
+
}
|
127 |
+
|
128 |
+
h6 {
|
129 |
+
color: #777777;
|
130 |
+
font-size: 14px;
|
131 |
+
}
|
132 |
+
|
133 |
+
p, blockquote, ul, ol, dl, li, table, pre {
|
134 |
+
margin: 15px 0;
|
135 |
+
}
|
136 |
+
|
137 |
+
hr {
|
138 |
+
border: 0 none;
|
139 |
+
color: #cccccc;
|
140 |
+
height: 4px;
|
141 |
+
padding: 0;
|
142 |
+
}
|
143 |
+
|
144 |
+
body > h2:first-child {
|
145 |
+
margin-top: 0;
|
146 |
+
padding-top: 0;
|
147 |
+
}
|
148 |
+
|
149 |
+
body > h1:first-child {
|
150 |
+
margin-top: 0;
|
151 |
+
padding-top: 0;
|
152 |
+
}
|
153 |
+
|
154 |
+
body > h1:first-child + h2 {
|
155 |
+
margin-top: 0;
|
156 |
+
padding-top: 0;
|
157 |
+
}
|
158 |
+
|
159 |
+
body > h3:first-child, body > h4:first-child, body > h5:first-child, body > h6:first-child {
|
160 |
+
margin-top: 0;
|
161 |
+
padding-top: 0;
|
162 |
+
}
|
163 |
+
|
164 |
+
a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 {
|
165 |
+
margin-top: 0;
|
166 |
+
padding-top: 0;
|
167 |
+
}
|
168 |
+
|
169 |
+
h1 p, h2 p, h3 p, h4 p, h5 p, h6 p {
|
170 |
+
margin-top: 0;
|
171 |
+
}
|
172 |
+
|
173 |
+
li p.first {
|
174 |
+
display: inline-block;
|
175 |
+
}
|
176 |
+
|
177 |
+
ul, ol {
|
178 |
+
padding-left: 30px;
|
179 |
+
}
|
180 |
+
|
181 |
+
ul :first-child, ol :first-child {
|
182 |
+
margin-top: 0;
|
183 |
+
}
|
184 |
+
|
185 |
+
ul :last-child, ol :last-child {
|
186 |
+
margin-bottom: 0;
|
187 |
+
}
|
188 |
+
|
189 |
+
dl {
|
190 |
+
padding: 0;
|
191 |
+
}
|
192 |
+
|
193 |
+
dl dt {
|
194 |
+
font-size: 14px;
|
195 |
+
font-weight: bold;
|
196 |
+
font-style: italic;
|
197 |
+
padding: 0;
|
198 |
+
margin: 15px 0 5px;
|
199 |
+
}
|
200 |
+
|
201 |
+
dl dt:first-child {
|
202 |
+
padding: 0;
|
203 |
+
}
|
204 |
+
|
205 |
+
dl dt > :first-child {
|
206 |
+
margin-top: 0;
|
207 |
+
}
|
208 |
+
|
209 |
+
dl dt > :last-child {
|
210 |
+
margin-bottom: 0;
|
211 |
+
}
|
212 |
+
|
213 |
+
dl dd {
|
214 |
+
margin: 0 0 15px;
|
215 |
+
padding: 0 15px;
|
216 |
+
}
|
217 |
+
|
218 |
+
dl dd > :first-child {
|
219 |
+
margin-top: 0;
|
220 |
+
}
|
221 |
+
|
222 |
+
dl dd > :last-child {
|
223 |
+
margin-bottom: 0;
|
224 |
+
}
|
225 |
+
|
226 |
+
blockquote {
|
227 |
+
border-left: 4px solid #dddddd;
|
228 |
+
padding: 0 15px;
|
229 |
+
color: #777777;
|
230 |
+
}
|
231 |
+
|
232 |
+
blockquote > :first-child {
|
233 |
+
margin-top: 0;
|
234 |
+
}
|
235 |
+
|
236 |
+
blockquote > :last-child {
|
237 |
+
margin-bottom: 0;
|
238 |
+
}
|
239 |
+
|
240 |
+
table {
|
241 |
+
padding: 0;
|
242 |
+
}
|
243 |
+
table tr {
|
244 |
+
border-top: 1px solid #cccccc;
|
245 |
+
background-color: white;
|
246 |
+
margin: 0;
|
247 |
+
padding: 0;
|
248 |
+
}
|
249 |
+
|
250 |
+
table tr:nth-child(2n) {
|
251 |
+
background-color: #f8f8f8;
|
252 |
+
}
|
253 |
+
|
254 |
+
table tr th {
|
255 |
+
font-weight: bold;
|
256 |
+
border: 1px solid #cccccc;
|
257 |
+
text-align: left;
|
258 |
+
margin: 0;
|
259 |
+
padding: 6px 13px;
|
260 |
+
}
|
261 |
+
|
262 |
+
table tr td {
|
263 |
+
border: 1px solid #cccccc;
|
264 |
+
text-align: left;
|
265 |
+
margin: 0;
|
266 |
+
padding: 6px 13px;
|
267 |
+
}
|
268 |
+
|
269 |
+
table tr th :first-child, table tr td :first-child {
|
270 |
+
margin-top: 0;
|
271 |
+
}
|
272 |
+
|
273 |
+
table tr th :last-child, table tr td :last-child {
|
274 |
+
margin-bottom: 0;
|
275 |
+
}
|
276 |
+
|
277 |
+
img {
|
278 |
+
max-width: 100%;
|
279 |
+
}
|
280 |
+
|
281 |
+
span.frame {
|
282 |
+
display: block;
|
283 |
+
overflow: hidden;
|
284 |
+
}
|
285 |
+
|
286 |
+
span.frame > span {
|
287 |
+
border: 1px solid #dddddd;
|
288 |
+
display: block;
|
289 |
+
float: left;
|
290 |
+
overflow: hidden;
|
291 |
+
margin: 13px 0 0;
|
292 |
+
padding: 7px;
|
293 |
+
width: auto;
|
294 |
+
}
|
295 |
+
|
296 |
+
span.frame span img {
|
297 |
+
display: block;
|
298 |
+
float: left;
|
299 |
+
}
|
300 |
+
|
301 |
+
span.frame span span {
|
302 |
+
clear: both;
|
303 |
+
color: #333333;
|
304 |
+
display: block;
|
305 |
+
padding: 5px 0 0;
|
306 |
+
}
|
307 |
+
|
308 |
+
span.align-center {
|
309 |
+
display: block;
|
310 |
+
overflow: hidden;
|
311 |
+
clear: both;
|
312 |
+
}
|
313 |
+
|
314 |
+
span.align-center > span {
|
315 |
+
display: block;
|
316 |
+
overflow: hidden;
|
317 |
+
margin: 13px auto 0;
|
318 |
+
text-align: center;
|
319 |
+
}
|
320 |
+
|
321 |
+
span.align-center span img {
|
322 |
+
margin: 0 auto;
|
323 |
+
text-align: center;
|
324 |
+
}
|
325 |
+
|
326 |
+
span.align-right {
|
327 |
+
display: block;
|
328 |
+
overflow: hidden;
|
329 |
+
clear: both;
|
330 |
+
}
|
331 |
+
|
332 |
+
span.align-right > span {
|
333 |
+
display: block;
|
334 |
+
overflow: hidden;
|
335 |
+
margin: 13px 0 0;
|
336 |
+
text-align: right;
|
337 |
+
}
|
338 |
+
|
339 |
+
span.align-right span img {
|
340 |
+
margin: 0;
|
341 |
+
text-align: right;
|
342 |
+
}
|
343 |
+
|
344 |
+
span.float-left {
|
345 |
+
display: block;
|
346 |
+
margin-right: 13px;
|
347 |
+
overflow: hidden;
|
348 |
+
float: left;
|
349 |
+
}
|
350 |
+
|
351 |
+
span.float-left span {
|
352 |
+
margin: 13px 0 0;
|
353 |
+
}
|
354 |
+
|
355 |
+
span.float-right {
|
356 |
+
display: block;
|
357 |
+
margin-left: 13px;
|
358 |
+
overflow: hidden;
|
359 |
+
float: right;
|
360 |
+
}
|
361 |
+
|
362 |
+
span.float-right > span {
|
363 |
+
display: block;
|
364 |
+
overflow: hidden;
|
365 |
+
margin: 13px auto 0;
|
366 |
+
text-align: right;
|
367 |
+
}
|
368 |
+
|
369 |
+
code, tt {
|
370 |
+
margin: 0 2px;
|
371 |
+
padding: 0 5px;
|
372 |
+
white-space: nowrap;
|
373 |
+
border-radius: 3px;
|
374 |
+
}
|
375 |
+
|
376 |
+
pre code {
|
377 |
+
margin: 0;
|
378 |
+
padding: 0;
|
379 |
+
white-space: pre;
|
380 |
+
border: none;
|
381 |
+
background: transparent;
|
382 |
+
}
|
383 |
+
|
384 |
+
.highlight pre {
|
385 |
+
font-size: 13px;
|
386 |
+
line-height: 19px;
|
387 |
+
overflow: auto;
|
388 |
+
padding: 6px 10px;
|
389 |
+
border-radius: 3px;
|
390 |
+
}
|
391 |
+
|
392 |
+
pre {
|
393 |
+
font-size: 13px;
|
394 |
+
line-height: 19px;
|
395 |
+
overflow: auto;
|
396 |
+
padding: 6px 10px;
|
397 |
+
border-radius: 3px;
|
398 |
+
}
|
399 |
+
|
400 |
+
pre code, pre tt {
|
401 |
+
background-color: transparent;
|
402 |
+
border: none;
|
403 |
+
}
|
404 |
+
|
405 |
+
body {
|
406 |
+
max-width: 600px;
|
407 |
+
}
|
audiotools/core/templates/widget.html
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div id='PLAYER_ID' class='player' style="max-width: MAX_WIDTH;">
|
2 |
+
<div class='spectrogram' style="padding-top: PADDING_AMOUNT;">
|
3 |
+
<div class='overlay'></div>
|
4 |
+
<div class='underlay'>
|
5 |
+
<img>
|
6 |
+
</div>
|
7 |
+
</div>
|
8 |
+
|
9 |
+
<div class='audio-controls'>
|
10 |
+
<button id="playpause" disabled class='playpause' title="play">
|
11 |
+
<svg class='play-img' width="14px" height="19px" viewBox="0 0 14 19">
|
12 |
+
<polygon id="Triangle" fill="#000000" transform="translate(9, 9.5) rotate(90) translate(-7, -9.5) " points="7 2.5 16.5 16.5 -2.5 16.5"></polygon>
|
13 |
+
</svg>
|
14 |
+
<svg class='pause-img' width="16px" height="19px" viewBox="0 0 16 19">
|
15 |
+
<g fill="#000000" stroke="#000000">
|
16 |
+
<rect id="Rectangle" x="0.5" y="0.5" width="4" height="18"></rect>
|
17 |
+
<rect id="Rectangle" x="11.5" y="0.5" width="4" height="18"></rect>
|
18 |
+
</g>
|
19 |
+
</svg>
|
20 |
+
</button>
|
21 |
+
|
22 |
+
<audio class='play'>
|
23 |
+
<source id='src'>
|
24 |
+
</audio>
|
25 |
+
<div class='response'>
|
26 |
+
<canvas class='response-canvas'></canvas>
|
27 |
+
</div>
|
28 |
+
|
29 |
+
<button id="download" class='download' title="download">
|
30 |
+
<svg class='download-img' x="0px" y="0px" viewBox="0 0 29.978 29.978" style="enable-background:new 0 0 29.978 29.978;" xml:space="preserve">
|
31 |
+
<g>
|
32 |
+
<path d="M25.462,19.105v6.848H4.515v-6.848H0.489v8.861c0,1.111,0.9,2.012,2.016,2.012h24.967c1.115,0,2.016-0.9,2.016-2.012
|
33 |
+
v-8.861H25.462z"/>
|
34 |
+
<path d="M14.62,18.426l-5.764-6.965c0,0-0.877-0.828,0.074-0.828s3.248,0,3.248,0s0-0.557,0-1.416c0-2.449,0-6.906,0-8.723
|
35 |
+
c0,0-0.129-0.494,0.615-0.494c0.75,0,4.035,0,4.572,0c0.536,0,0.524,0.416,0.524,0.416c0,1.762,0,6.373,0,8.742
|
36 |
+
c0,0.768,0,1.266,0,1.266s1.842,0,2.998,0c1.154,0,0.285,0.867,0.285,0.867s-4.904,6.51-5.588,7.193
|
37 |
+
C15.092,18.979,14.62,18.426,14.62,18.426z"/>
|
38 |
+
</g>
|
39 |
+
</svg>
|
40 |
+
</button>
|
41 |
+
</div>
|
42 |
+
</div>
|
43 |
+
|
44 |
+
<script>
|
45 |
+
var PLAYER_ID = new Player('PLAYER_ID')
|
46 |
+
PLAYER_ID.load(
|
47 |
+
"AUDIO_SRC",
|
48 |
+
"IMAGE_SRC",
|
49 |
+
"LEVELS_SRC"
|
50 |
+
)
|
51 |
+
window.addEventListener("resize", function() {PLAYER_ID.redraw()})
|
52 |
+
</script>
|
audiotools/core/util.py
ADDED
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import glob
|
3 |
+
import math
|
4 |
+
import numbers
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import typing
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Dict
|
12 |
+
from typing import List
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torchaudio
|
17 |
+
from flatten_dict import flatten
|
18 |
+
from flatten_dict import unflatten
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class Info:
|
23 |
+
"""Shim for torchaudio.info API changes."""
|
24 |
+
|
25 |
+
sample_rate: float
|
26 |
+
num_frames: int
|
27 |
+
|
28 |
+
@property
|
29 |
+
def duration(self) -> float:
|
30 |
+
return self.num_frames / self.sample_rate
|
31 |
+
|
32 |
+
|
33 |
+
def info(audio_path: str):
|
34 |
+
"""Shim for torchaudio.info to make 0.7.2 API match 0.8.0.
|
35 |
+
|
36 |
+
Parameters
|
37 |
+
----------
|
38 |
+
audio_path : str
|
39 |
+
Path to audio file.
|
40 |
+
"""
|
41 |
+
# try default backend first, then fallback to soundfile
|
42 |
+
try:
|
43 |
+
info = torchaudio.info(str(audio_path))
|
44 |
+
except: # pragma: no cover
|
45 |
+
info = torchaudio.backend.soundfile_backend.info(str(audio_path))
|
46 |
+
|
47 |
+
if isinstance(info, tuple): # pragma: no cover
|
48 |
+
signal_info = info[0]
|
49 |
+
info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length)
|
50 |
+
else:
|
51 |
+
info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames)
|
52 |
+
|
53 |
+
return info
|
54 |
+
|
55 |
+
|
56 |
+
def ensure_tensor(
|
57 |
+
x: typing.Union[np.ndarray, torch.Tensor, float, int],
|
58 |
+
ndim: int = None,
|
59 |
+
batch_size: int = None,
|
60 |
+
):
|
61 |
+
"""Ensures that the input ``x`` is a tensor of specified
|
62 |
+
dimensions and batch size.
|
63 |
+
|
64 |
+
Parameters
|
65 |
+
----------
|
66 |
+
x : typing.Union[np.ndarray, torch.Tensor, float, int]
|
67 |
+
Data that will become a tensor on its way out.
|
68 |
+
ndim : int, optional
|
69 |
+
How many dimensions should be in the output, by default None
|
70 |
+
batch_size : int, optional
|
71 |
+
The batch size of the output, by default None
|
72 |
+
|
73 |
+
Returns
|
74 |
+
-------
|
75 |
+
torch.Tensor
|
76 |
+
Modified version of ``x`` as a tensor.
|
77 |
+
"""
|
78 |
+
if not torch.is_tensor(x):
|
79 |
+
x = torch.as_tensor(x)
|
80 |
+
if ndim is not None:
|
81 |
+
assert x.ndim <= ndim
|
82 |
+
while x.ndim < ndim:
|
83 |
+
x = x.unsqueeze(-1)
|
84 |
+
if batch_size is not None:
|
85 |
+
if x.shape[0] != batch_size:
|
86 |
+
shape = list(x.shape)
|
87 |
+
shape[0] = batch_size
|
88 |
+
x = x.expand(*shape)
|
89 |
+
return x
|
90 |
+
|
91 |
+
|
92 |
+
def _get_value(other):
|
93 |
+
from . import AudioSignal
|
94 |
+
|
95 |
+
if isinstance(other, AudioSignal):
|
96 |
+
return other.audio_data
|
97 |
+
return other
|
98 |
+
|
99 |
+
|
100 |
+
def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int):
|
101 |
+
"""Closest frequency bin given a frequency, number
|
102 |
+
of bins, and a sampling rate.
|
103 |
+
|
104 |
+
Parameters
|
105 |
+
----------
|
106 |
+
hz : torch.Tensor
|
107 |
+
Tensor of frequencies in Hz.
|
108 |
+
n_fft : int
|
109 |
+
Number of FFT bins.
|
110 |
+
sample_rate : int
|
111 |
+
Sample rate of audio.
|
112 |
+
|
113 |
+
Returns
|
114 |
+
-------
|
115 |
+
torch.Tensor
|
116 |
+
Closest bins to the data.
|
117 |
+
"""
|
118 |
+
shape = hz.shape
|
119 |
+
hz = hz.flatten()
|
120 |
+
freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2)
|
121 |
+
hz[hz > sample_rate / 2] = sample_rate / 2
|
122 |
+
|
123 |
+
closest = (hz[None, :] - freqs[:, None]).abs()
|
124 |
+
closest_bins = closest.min(dim=0).indices
|
125 |
+
|
126 |
+
return closest_bins.reshape(*shape)
|
127 |
+
|
128 |
+
|
129 |
+
def random_state(seed: typing.Union[int, np.random.RandomState]):
|
130 |
+
"""
|
131 |
+
Turn seed into a np.random.RandomState instance.
|
132 |
+
|
133 |
+
Parameters
|
134 |
+
----------
|
135 |
+
seed : typing.Union[int, np.random.RandomState] or None
|
136 |
+
If seed is None, return the RandomState singleton used by np.random.
|
137 |
+
If seed is an int, return a new RandomState instance seeded with seed.
|
138 |
+
If seed is already a RandomState instance, return it.
|
139 |
+
Otherwise raise ValueError.
|
140 |
+
|
141 |
+
Returns
|
142 |
+
-------
|
143 |
+
np.random.RandomState
|
144 |
+
Random state object.
|
145 |
+
|
146 |
+
Raises
|
147 |
+
------
|
148 |
+
ValueError
|
149 |
+
If seed is not valid, an error is thrown.
|
150 |
+
"""
|
151 |
+
if seed is None or seed is np.random:
|
152 |
+
return np.random.mtrand._rand
|
153 |
+
elif isinstance(seed, (numbers.Integral, np.integer, int)):
|
154 |
+
return np.random.RandomState(seed)
|
155 |
+
elif isinstance(seed, np.random.RandomState):
|
156 |
+
return seed
|
157 |
+
else:
|
158 |
+
raise ValueError(
|
159 |
+
"%r cannot be used to seed a numpy.random.RandomState" " instance" % seed
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
def seed(random_seed, set_cudnn=False):
|
164 |
+
"""
|
165 |
+
Seeds all random states with the same random seed
|
166 |
+
for reproducibility. Seeds ``numpy``, ``random`` and ``torch``
|
167 |
+
random generators.
|
168 |
+
For full reproducibility, two further options must be set
|
169 |
+
according to the torch documentation:
|
170 |
+
https://pytorch.org/docs/stable/notes/randomness.html
|
171 |
+
To do this, ``set_cudnn`` must be True. It defaults to
|
172 |
+
False, since setting it to True results in a performance
|
173 |
+
hit.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
random_seed (int): integer corresponding to random seed to
|
177 |
+
use.
|
178 |
+
set_cudnn (bool): Whether or not to set cudnn into determinstic
|
179 |
+
mode and off of benchmark mode. Defaults to False.
|
180 |
+
"""
|
181 |
+
|
182 |
+
torch.manual_seed(random_seed)
|
183 |
+
np.random.seed(random_seed)
|
184 |
+
random.seed(random_seed)
|
185 |
+
|
186 |
+
if set_cudnn:
|
187 |
+
torch.backends.cudnn.deterministic = True
|
188 |
+
torch.backends.cudnn.benchmark = False
|
189 |
+
|
190 |
+
|
191 |
+
@contextmanager
|
192 |
+
def _close_temp_files(tmpfiles: list):
|
193 |
+
"""Utility function for creating a context and closing all temporary files
|
194 |
+
once the context is exited. For correct functionality, all temporary file
|
195 |
+
handles created inside the context must be appended to the ```tmpfiles```
|
196 |
+
list.
|
197 |
+
|
198 |
+
This function is taken wholesale from Scaper.
|
199 |
+
|
200 |
+
Parameters
|
201 |
+
----------
|
202 |
+
tmpfiles : list
|
203 |
+
List of temporary file handles
|
204 |
+
"""
|
205 |
+
|
206 |
+
def _close():
|
207 |
+
for t in tmpfiles:
|
208 |
+
try:
|
209 |
+
t.close()
|
210 |
+
os.unlink(t.name)
|
211 |
+
except:
|
212 |
+
pass
|
213 |
+
|
214 |
+
try:
|
215 |
+
yield
|
216 |
+
except: # pragma: no cover
|
217 |
+
_close()
|
218 |
+
raise
|
219 |
+
_close()
|
220 |
+
|
221 |
+
|
222 |
+
AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]
|
223 |
+
|
224 |
+
|
225 |
+
def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
|
226 |
+
"""Finds all audio files in a directory recursively.
|
227 |
+
Returns a list.
|
228 |
+
|
229 |
+
Parameters
|
230 |
+
----------
|
231 |
+
folder : str
|
232 |
+
Folder to look for audio files in, recursively.
|
233 |
+
ext : List[str], optional
|
234 |
+
Extensions to look for without the ., by default
|
235 |
+
``['.wav', '.flac', '.mp3', '.mp4']``.
|
236 |
+
"""
|
237 |
+
folder = Path(folder)
|
238 |
+
# Take care of case where user has passed in an audio file directly
|
239 |
+
# into one of the calling functions.
|
240 |
+
if str(folder).endswith(tuple(ext)):
|
241 |
+
# if, however, there's a glob in the path, we need to
|
242 |
+
# return the glob, not the file.
|
243 |
+
if "*" in str(folder):
|
244 |
+
return glob.glob(str(folder), recursive=("**" in str(folder)))
|
245 |
+
else:
|
246 |
+
return [folder]
|
247 |
+
|
248 |
+
files = []
|
249 |
+
for x in ext:
|
250 |
+
files += folder.glob(f"**/*{x}")
|
251 |
+
return files
|
252 |
+
|
253 |
+
|
254 |
+
def read_sources(
|
255 |
+
sources: List[str],
|
256 |
+
remove_empty: bool = True,
|
257 |
+
relative_path: str = "",
|
258 |
+
ext: List[str] = AUDIO_EXTENSIONS,
|
259 |
+
):
|
260 |
+
"""Reads audio sources that can either be folders
|
261 |
+
full of audio files, or CSV files that contain paths
|
262 |
+
to audio files. CSV files that adhere to the expected
|
263 |
+
format can be generated by
|
264 |
+
:py:func:`audiotools.data.preprocess.create_csv`.
|
265 |
+
|
266 |
+
Parameters
|
267 |
+
----------
|
268 |
+
sources : List[str]
|
269 |
+
List of audio sources to be converted into a
|
270 |
+
list of lists of audio files.
|
271 |
+
remove_empty : bool, optional
|
272 |
+
Whether or not to remove rows with an empty "path"
|
273 |
+
from each CSV file, by default True.
|
274 |
+
|
275 |
+
Returns
|
276 |
+
-------
|
277 |
+
list
|
278 |
+
List of lists of rows of CSV files.
|
279 |
+
"""
|
280 |
+
files = []
|
281 |
+
relative_path = Path(relative_path)
|
282 |
+
for source in sources:
|
283 |
+
source = str(source)
|
284 |
+
_files = []
|
285 |
+
if source.endswith(".csv"):
|
286 |
+
with open(source, "r") as f:
|
287 |
+
reader = csv.DictReader(f)
|
288 |
+
for x in reader:
|
289 |
+
if remove_empty and x["path"] == "":
|
290 |
+
continue
|
291 |
+
if x["path"] != "":
|
292 |
+
x["path"] = str(relative_path / x["path"])
|
293 |
+
_files.append(x)
|
294 |
+
else:
|
295 |
+
for x in find_audio(source, ext=ext):
|
296 |
+
x = str(relative_path / x)
|
297 |
+
_files.append({"path": x})
|
298 |
+
files.append(sorted(_files, key=lambda x: x["path"]))
|
299 |
+
return files
|
300 |
+
|
301 |
+
|
302 |
+
def choose_from_list_of_lists(
|
303 |
+
state: np.random.RandomState, list_of_lists: list, p: float = None
|
304 |
+
):
|
305 |
+
"""Choose a single item from a list of lists.
|
306 |
+
|
307 |
+
Parameters
|
308 |
+
----------
|
309 |
+
state : np.random.RandomState
|
310 |
+
Random state to use when choosing an item.
|
311 |
+
list_of_lists : list
|
312 |
+
A list of lists from which items will be drawn.
|
313 |
+
p : float, optional
|
314 |
+
Probabilities of each list, by default None
|
315 |
+
|
316 |
+
Returns
|
317 |
+
-------
|
318 |
+
typing.Any
|
319 |
+
An item from the list of lists.
|
320 |
+
"""
|
321 |
+
source_idx = state.choice(list(range(len(list_of_lists))), p=p)
|
322 |
+
item_idx = state.randint(len(list_of_lists[source_idx]))
|
323 |
+
return list_of_lists[source_idx][item_idx], source_idx, item_idx
|
324 |
+
|
325 |
+
|
326 |
+
@contextmanager
|
327 |
+
def chdir(newdir: typing.Union[Path, str]):
|
328 |
+
"""
|
329 |
+
Context manager for switching directories to run a
|
330 |
+
function. Useful for when you want to use relative
|
331 |
+
paths to different runs.
|
332 |
+
|
333 |
+
Parameters
|
334 |
+
----------
|
335 |
+
newdir : typing.Union[Path, str]
|
336 |
+
Directory to switch to.
|
337 |
+
"""
|
338 |
+
curdir = os.getcwd()
|
339 |
+
try:
|
340 |
+
os.chdir(newdir)
|
341 |
+
yield
|
342 |
+
finally:
|
343 |
+
os.chdir(curdir)
|
344 |
+
|
345 |
+
|
346 |
+
def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"):
|
347 |
+
"""Moves items in a batch (typically generated by a DataLoader as a list
|
348 |
+
or a dict) to the specified device. This works even if dictionaries
|
349 |
+
are nested.
|
350 |
+
|
351 |
+
Parameters
|
352 |
+
----------
|
353 |
+
batch : typing.Union[dict, list, torch.Tensor]
|
354 |
+
Batch, typically generated by a dataloader, that will be moved to
|
355 |
+
the device.
|
356 |
+
device : str, optional
|
357 |
+
Device to move batch to, by default "cpu"
|
358 |
+
|
359 |
+
Returns
|
360 |
+
-------
|
361 |
+
typing.Union[dict, list, torch.Tensor]
|
362 |
+
Batch with all values moved to the specified device.
|
363 |
+
"""
|
364 |
+
if isinstance(batch, dict):
|
365 |
+
batch = flatten(batch)
|
366 |
+
for key, val in batch.items():
|
367 |
+
try:
|
368 |
+
batch[key] = val.to(device)
|
369 |
+
except:
|
370 |
+
pass
|
371 |
+
batch = unflatten(batch)
|
372 |
+
elif torch.is_tensor(batch):
|
373 |
+
batch = batch.to(device)
|
374 |
+
elif isinstance(batch, list):
|
375 |
+
for i in range(len(batch)):
|
376 |
+
try:
|
377 |
+
batch[i] = batch[i].to(device)
|
378 |
+
except:
|
379 |
+
pass
|
380 |
+
return batch
|
381 |
+
|
382 |
+
|
383 |
+
def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
|
384 |
+
"""Samples from a distribution defined by a tuple. The first
|
385 |
+
item in the tuple is the distribution type, and the rest of the
|
386 |
+
items are arguments to that distribution. The distribution function
|
387 |
+
is gotten from the ``np.random.RandomState`` object.
|
388 |
+
|
389 |
+
Parameters
|
390 |
+
----------
|
391 |
+
dist_tuple : tuple
|
392 |
+
Distribution tuple
|
393 |
+
state : np.random.RandomState, optional
|
394 |
+
Random state, or seed to use, by default None
|
395 |
+
|
396 |
+
Returns
|
397 |
+
-------
|
398 |
+
typing.Union[float, int, str]
|
399 |
+
Draw from the distribution.
|
400 |
+
|
401 |
+
Examples
|
402 |
+
--------
|
403 |
+
Sample from a uniform distribution:
|
404 |
+
|
405 |
+
>>> dist_tuple = ("uniform", 0, 1)
|
406 |
+
>>> sample_from_dist(dist_tuple)
|
407 |
+
|
408 |
+
Sample from a constant distribution:
|
409 |
+
|
410 |
+
>>> dist_tuple = ("const", 0)
|
411 |
+
>>> sample_from_dist(dist_tuple)
|
412 |
+
|
413 |
+
Sample from a normal distribution:
|
414 |
+
|
415 |
+
>>> dist_tuple = ("normal", 0, 0.5)
|
416 |
+
>>> sample_from_dist(dist_tuple)
|
417 |
+
|
418 |
+
"""
|
419 |
+
if dist_tuple[0] == "const":
|
420 |
+
return dist_tuple[1]
|
421 |
+
state = random_state(state)
|
422 |
+
dist_fn = getattr(state, dist_tuple[0])
|
423 |
+
return dist_fn(*dist_tuple[1:])
|
424 |
+
|
425 |
+
|
426 |
+
def collate(list_of_dicts: list, n_splits: int = None):
|
427 |
+
"""Collates a list of dictionaries (e.g. as returned by a
|
428 |
+
dataloader) into a dictionary with batched values. This routine
|
429 |
+
uses the default torch collate function for everything
|
430 |
+
except AudioSignal objects, which are handled by the
|
431 |
+
:py:func:`audiotools.core.audio_signal.AudioSignal.batch`
|
432 |
+
function.
|
433 |
+
|
434 |
+
This function takes n_splits to enable splitting a batch
|
435 |
+
into multiple sub-batches for the purposes of gradient accumulation,
|
436 |
+
etc.
|
437 |
+
|
438 |
+
Parameters
|
439 |
+
----------
|
440 |
+
list_of_dicts : list
|
441 |
+
List of dictionaries to be collated.
|
442 |
+
n_splits : int
|
443 |
+
Number of splits to make when creating the batches (split into
|
444 |
+
sub-batches). Useful for things like gradient accumulation.
|
445 |
+
|
446 |
+
Returns
|
447 |
+
-------
|
448 |
+
dict
|
449 |
+
Dictionary containing batched data.
|
450 |
+
"""
|
451 |
+
|
452 |
+
from . import AudioSignal
|
453 |
+
|
454 |
+
batches = []
|
455 |
+
list_len = len(list_of_dicts)
|
456 |
+
|
457 |
+
return_list = False if n_splits is None else True
|
458 |
+
n_splits = 1 if n_splits is None else n_splits
|
459 |
+
n_items = int(math.ceil(list_len / n_splits))
|
460 |
+
|
461 |
+
for i in range(0, list_len, n_items):
|
462 |
+
# Flatten the dictionaries to avoid recursion.
|
463 |
+
list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
|
464 |
+
dict_of_lists = {
|
465 |
+
k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
|
466 |
+
}
|
467 |
+
|
468 |
+
batch = {}
|
469 |
+
for k, v in dict_of_lists.items():
|
470 |
+
if isinstance(v, list):
|
471 |
+
if all(isinstance(s, AudioSignal) for s in v):
|
472 |
+
batch[k] = AudioSignal.batch(v, pad_signals=True)
|
473 |
+
else:
|
474 |
+
# Borrow the default collate fn from torch.
|
475 |
+
batch[k] = torch.utils.data._utils.collate.default_collate(v)
|
476 |
+
batches.append(unflatten(batch))
|
477 |
+
|
478 |
+
batches = batches[0] if not return_list else batches
|
479 |
+
return batches
|
480 |
+
|
481 |
+
|
482 |
+
BASE_SIZE = 864
|
483 |
+
DEFAULT_FIG_SIZE = (9, 3)
|
484 |
+
|
485 |
+
|
486 |
+
def format_figure(
|
487 |
+
fig_size: tuple = None,
|
488 |
+
title: str = None,
|
489 |
+
fig=None,
|
490 |
+
format_axes: bool = True,
|
491 |
+
format: bool = True,
|
492 |
+
font_color: str = "white",
|
493 |
+
):
|
494 |
+
"""Prettifies the spectrogram and waveform plots. A title
|
495 |
+
can be inset into the top right corner, and the axes can be
|
496 |
+
inset into the figure, allowing the data to take up the entire
|
497 |
+
image. Used in
|
498 |
+
|
499 |
+
- :py:func:`audiotools.core.display.DisplayMixin.specshow`
|
500 |
+
- :py:func:`audiotools.core.display.DisplayMixin.waveplot`
|
501 |
+
- :py:func:`audiotools.core.display.DisplayMixin.wavespec`
|
502 |
+
|
503 |
+
Parameters
|
504 |
+
----------
|
505 |
+
fig_size : tuple, optional
|
506 |
+
Size of figure, by default (9, 3)
|
507 |
+
title : str, optional
|
508 |
+
Title to inset in top right, by default None
|
509 |
+
fig : matplotlib.figure.Figure, optional
|
510 |
+
Figure object, if None ``plt.gcf()`` will be used, by default None
|
511 |
+
format_axes : bool, optional
|
512 |
+
Format the axes to be inside the figure, by default True
|
513 |
+
format : bool, optional
|
514 |
+
This formatting can be skipped entirely by passing ``format=False``
|
515 |
+
to any of the plotting functions that use this formater, by default True
|
516 |
+
font_color : str, optional
|
517 |
+
Color of font of axes, by default "white"
|
518 |
+
"""
|
519 |
+
import matplotlib
|
520 |
+
import matplotlib.pyplot as plt
|
521 |
+
|
522 |
+
if fig_size is None:
|
523 |
+
fig_size = DEFAULT_FIG_SIZE
|
524 |
+
if not format:
|
525 |
+
return
|
526 |
+
if fig is None:
|
527 |
+
fig = plt.gcf()
|
528 |
+
fig.set_size_inches(*fig_size)
|
529 |
+
axs = fig.axes
|
530 |
+
|
531 |
+
pixels = (fig.get_size_inches() * fig.dpi)[0]
|
532 |
+
font_scale = pixels / BASE_SIZE
|
533 |
+
|
534 |
+
if format_axes:
|
535 |
+
axs = fig.axes
|
536 |
+
|
537 |
+
for ax in axs:
|
538 |
+
ymin, _ = ax.get_ylim()
|
539 |
+
xmin, _ = ax.get_xlim()
|
540 |
+
|
541 |
+
ticks = ax.get_yticks()
|
542 |
+
for t in ticks[2:-1]:
|
543 |
+
t = axs[0].annotate(
|
544 |
+
f"{(t / 1000):2.1f}k",
|
545 |
+
xy=(xmin, t),
|
546 |
+
xycoords="data",
|
547 |
+
xytext=(5, -5),
|
548 |
+
textcoords="offset points",
|
549 |
+
ha="left",
|
550 |
+
va="top",
|
551 |
+
color=font_color,
|
552 |
+
fontsize=12 * font_scale,
|
553 |
+
alpha=0.75,
|
554 |
+
)
|
555 |
+
|
556 |
+
ticks = ax.get_xticks()[2:]
|
557 |
+
for t in ticks[:-1]:
|
558 |
+
t = axs[0].annotate(
|
559 |
+
f"{t:2.1f}s",
|
560 |
+
xy=(t, ymin),
|
561 |
+
xycoords="data",
|
562 |
+
xytext=(5, 5),
|
563 |
+
textcoords="offset points",
|
564 |
+
ha="center",
|
565 |
+
va="bottom",
|
566 |
+
color=font_color,
|
567 |
+
fontsize=12 * font_scale,
|
568 |
+
alpha=0.75,
|
569 |
+
)
|
570 |
+
|
571 |
+
ax.margins(0, 0)
|
572 |
+
ax.set_axis_off()
|
573 |
+
ax.xaxis.set_major_locator(plt.NullLocator())
|
574 |
+
ax.yaxis.set_major_locator(plt.NullLocator())
|
575 |
+
|
576 |
+
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
577 |
+
|
578 |
+
if title is not None:
|
579 |
+
t = axs[0].annotate(
|
580 |
+
title,
|
581 |
+
xy=(1, 1),
|
582 |
+
xycoords="axes fraction",
|
583 |
+
fontsize=20 * font_scale,
|
584 |
+
xytext=(-5, -5),
|
585 |
+
textcoords="offset points",
|
586 |
+
ha="right",
|
587 |
+
va="top",
|
588 |
+
color="white",
|
589 |
+
)
|
590 |
+
t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
|
591 |
+
|
592 |
+
|
593 |
+
def generate_chord_dataset(
|
594 |
+
max_voices: int = 8,
|
595 |
+
sample_rate: int = 44100,
|
596 |
+
num_items: int = 5,
|
597 |
+
duration: float = 1.0,
|
598 |
+
min_note: str = "C2",
|
599 |
+
max_note: str = "C6",
|
600 |
+
output_dir: Path = "chords",
|
601 |
+
):
|
602 |
+
"""
|
603 |
+
Generates a toy multitrack dataset of chords, synthesized from sine waves.
|
604 |
+
|
605 |
+
|
606 |
+
Parameters
|
607 |
+
----------
|
608 |
+
max_voices : int, optional
|
609 |
+
Maximum number of voices in a chord, by default 8
|
610 |
+
sample_rate : int, optional
|
611 |
+
Sample rate of audio, by default 44100
|
612 |
+
num_items : int, optional
|
613 |
+
Number of items to generate, by default 5
|
614 |
+
duration : float, optional
|
615 |
+
Duration of each item, by default 1.0
|
616 |
+
min_note : str, optional
|
617 |
+
Minimum note in the dataset, by default "C2"
|
618 |
+
max_note : str, optional
|
619 |
+
Maximum note in the dataset, by default "C6"
|
620 |
+
output_dir : Path, optional
|
621 |
+
Directory to save the dataset, by default "chords"
|
622 |
+
|
623 |
+
"""
|
624 |
+
import librosa
|
625 |
+
from . import AudioSignal
|
626 |
+
from ..data.preprocess import create_csv
|
627 |
+
|
628 |
+
min_midi = librosa.note_to_midi(min_note)
|
629 |
+
max_midi = librosa.note_to_midi(max_note)
|
630 |
+
|
631 |
+
tracks = []
|
632 |
+
for idx in range(num_items):
|
633 |
+
track = {}
|
634 |
+
# figure out how many voices to put in this track
|
635 |
+
num_voices = random.randint(1, max_voices)
|
636 |
+
for voice_idx in range(num_voices):
|
637 |
+
# choose some random params
|
638 |
+
midinote = random.randint(min_midi, max_midi)
|
639 |
+
dur = random.uniform(0.85 * duration, duration)
|
640 |
+
|
641 |
+
sig = AudioSignal.wave(
|
642 |
+
frequency=librosa.midi_to_hz(midinote),
|
643 |
+
duration=dur,
|
644 |
+
sample_rate=sample_rate,
|
645 |
+
shape="sine",
|
646 |
+
)
|
647 |
+
track[f"voice_{voice_idx}"] = sig
|
648 |
+
tracks.append(track)
|
649 |
+
|
650 |
+
# save the tracks to disk
|
651 |
+
output_dir = Path(output_dir)
|
652 |
+
output_dir.mkdir(exist_ok=True)
|
653 |
+
for idx, track in enumerate(tracks):
|
654 |
+
track_dir = output_dir / f"track_{idx}"
|
655 |
+
track_dir.mkdir(exist_ok=True)
|
656 |
+
for voice_name, sig in track.items():
|
657 |
+
sig.write(track_dir / f"{voice_name}.wav")
|
658 |
+
|
659 |
+
all_voices = list(set([k for track in tracks for k in track.keys()]))
|
660 |
+
voice_lists = {voice: [] for voice in all_voices}
|
661 |
+
for track in tracks:
|
662 |
+
for voice_name in all_voices:
|
663 |
+
if voice_name in track:
|
664 |
+
voice_lists[voice_name].append(track[voice_name].path_to_file)
|
665 |
+
else:
|
666 |
+
voice_lists[voice_name].append("")
|
667 |
+
|
668 |
+
for voice_name, paths in voice_lists.items():
|
669 |
+
create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)
|
670 |
+
|
671 |
+
return output_dir
|
audiotools/core/whisper.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class WhisperMixin:
|
5 |
+
is_initialized = False
|
6 |
+
|
7 |
+
def setup_whisper(
|
8 |
+
self,
|
9 |
+
pretrained_model_name_or_path: str = "openai/whisper-base.en",
|
10 |
+
device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
11 |
+
):
|
12 |
+
from transformers import WhisperForConditionalGeneration
|
13 |
+
from transformers import WhisperProcessor
|
14 |
+
|
15 |
+
self.whisper_device = device
|
16 |
+
self.whisper_processor = WhisperProcessor.from_pretrained(
|
17 |
+
pretrained_model_name_or_path
|
18 |
+
)
|
19 |
+
self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
|
20 |
+
pretrained_model_name_or_path
|
21 |
+
).to(self.whisper_device)
|
22 |
+
self.is_initialized = True
|
23 |
+
|
24 |
+
def get_whisper_features(self) -> torch.Tensor:
|
25 |
+
"""Preprocess audio signal as per the whisper model's training config.
|
26 |
+
|
27 |
+
Returns
|
28 |
+
-------
|
29 |
+
torch.Tensor
|
30 |
+
The prepinput features of the audio signal. Shape: (1, channels, seq_len)
|
31 |
+
"""
|
32 |
+
import torch
|
33 |
+
|
34 |
+
if not self.is_initialized:
|
35 |
+
self.setup_whisper()
|
36 |
+
|
37 |
+
signal = self.to(self.device)
|
38 |
+
raw_speech = list(
|
39 |
+
(
|
40 |
+
signal.clone()
|
41 |
+
.resample(self.whisper_processor.feature_extractor.sampling_rate)
|
42 |
+
.audio_data[:, 0, :]
|
43 |
+
.numpy()
|
44 |
+
)
|
45 |
+
)
|
46 |
+
|
47 |
+
with torch.inference_mode():
|
48 |
+
input_features = self.whisper_processor(
|
49 |
+
raw_speech,
|
50 |
+
sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
|
51 |
+
return_tensors="pt",
|
52 |
+
).input_features
|
53 |
+
|
54 |
+
return input_features
|
55 |
+
|
56 |
+
def get_whisper_transcript(self) -> str:
|
57 |
+
"""Get the transcript of the audio signal using the whisper model.
|
58 |
+
|
59 |
+
Returns
|
60 |
+
-------
|
61 |
+
str
|
62 |
+
The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>.
|
63 |
+
"""
|
64 |
+
|
65 |
+
if not self.is_initialized:
|
66 |
+
self.setup_whisper()
|
67 |
+
|
68 |
+
input_features = self.get_whisper_features()
|
69 |
+
|
70 |
+
with torch.inference_mode():
|
71 |
+
input_features = input_features.to(self.whisper_device)
|
72 |
+
generated_ids = self.whisper_model.generate(inputs=input_features)
|
73 |
+
|
74 |
+
transcription = self.whisper_processor.batch_decode(generated_ids)
|
75 |
+
return transcription[0]
|
76 |
+
|
77 |
+
def get_whisper_embeddings(self) -> torch.Tensor:
|
78 |
+
"""Get the last hidden state embeddings of the audio signal using the whisper model.
|
79 |
+
|
80 |
+
Returns
|
81 |
+
-------
|
82 |
+
torch.Tensor
|
83 |
+
The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size)
|
84 |
+
"""
|
85 |
+
import torch
|
86 |
+
|
87 |
+
if not self.is_initialized:
|
88 |
+
self.setup_whisper()
|
89 |
+
|
90 |
+
input_features = self.get_whisper_features()
|
91 |
+
encoder = self.whisper_model.get_encoder()
|
92 |
+
|
93 |
+
with torch.inference_mode():
|
94 |
+
input_features = input_features.to(self.whisper_device)
|
95 |
+
embeddings = encoder(input_features)
|
96 |
+
|
97 |
+
return embeddings.last_hidden_state
|
audiotools/data/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from . import datasets
|
2 |
+
from . import preprocess
|
3 |
+
from . import transforms
|
audiotools/data/datasets.py
ADDED
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Callable
|
3 |
+
from typing import Dict
|
4 |
+
from typing import List
|
5 |
+
from typing import Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from torch.utils.data import SequentialSampler
|
9 |
+
from torch.utils.data.distributed import DistributedSampler
|
10 |
+
|
11 |
+
from ..core import AudioSignal
|
12 |
+
from ..core import util
|
13 |
+
|
14 |
+
|
15 |
+
class AudioLoader:
|
16 |
+
"""Loads audio endlessly from a list of audio sources
|
17 |
+
containing paths to audio files. Audio sources can be
|
18 |
+
folders full of audio files (which are found via file
|
19 |
+
extension) or by providing a CSV file which contains paths
|
20 |
+
to audio files.
|
21 |
+
|
22 |
+
Parameters
|
23 |
+
----------
|
24 |
+
sources : List[str], optional
|
25 |
+
Sources containing folders, or CSVs with
|
26 |
+
paths to audio files, by default None
|
27 |
+
weights : List[float], optional
|
28 |
+
Weights to sample audio files from each source, by default None
|
29 |
+
relative_path : str, optional
|
30 |
+
Path audio should be loaded relative to, by default ""
|
31 |
+
transform : Callable, optional
|
32 |
+
Transform to instantiate alongside audio sample,
|
33 |
+
by default None
|
34 |
+
ext : List[str]
|
35 |
+
List of extensions to find audio within each source by. Can
|
36 |
+
also be a file name (e.g. "vocals.wav"). by default
|
37 |
+
``['.wav', '.flac', '.mp3', '.mp4']``.
|
38 |
+
shuffle: bool
|
39 |
+
Whether to shuffle the files within the dataloader. Defaults to True.
|
40 |
+
shuffle_state: int
|
41 |
+
State to use to seed the shuffle of the files.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
sources: List[str] = None,
|
47 |
+
weights: List[float] = None,
|
48 |
+
transform: Callable = None,
|
49 |
+
relative_path: str = "",
|
50 |
+
ext: List[str] = util.AUDIO_EXTENSIONS,
|
51 |
+
shuffle: bool = True,
|
52 |
+
shuffle_state: int = 0,
|
53 |
+
):
|
54 |
+
self.audio_lists = util.read_sources(
|
55 |
+
sources, relative_path=relative_path, ext=ext
|
56 |
+
)
|
57 |
+
|
58 |
+
self.audio_indices = [
|
59 |
+
(src_idx, item_idx)
|
60 |
+
for src_idx, src in enumerate(self.audio_lists)
|
61 |
+
for item_idx in range(len(src))
|
62 |
+
]
|
63 |
+
if shuffle:
|
64 |
+
state = util.random_state(shuffle_state)
|
65 |
+
state.shuffle(self.audio_indices)
|
66 |
+
|
67 |
+
self.sources = sources
|
68 |
+
self.weights = weights
|
69 |
+
self.transform = transform
|
70 |
+
|
71 |
+
def __call__(
|
72 |
+
self,
|
73 |
+
state,
|
74 |
+
sample_rate: int,
|
75 |
+
duration: float,
|
76 |
+
loudness_cutoff: float = -40,
|
77 |
+
num_channels: int = 1,
|
78 |
+
offset: float = None,
|
79 |
+
source_idx: int = None,
|
80 |
+
item_idx: int = None,
|
81 |
+
global_idx: int = None,
|
82 |
+
):
|
83 |
+
if source_idx is not None and item_idx is not None:
|
84 |
+
try:
|
85 |
+
audio_info = self.audio_lists[source_idx][item_idx]
|
86 |
+
except:
|
87 |
+
audio_info = {"path": "none"}
|
88 |
+
elif global_idx is not None:
|
89 |
+
source_idx, item_idx = self.audio_indices[
|
90 |
+
global_idx % len(self.audio_indices)
|
91 |
+
]
|
92 |
+
audio_info = self.audio_lists[source_idx][item_idx]
|
93 |
+
else:
|
94 |
+
audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
|
95 |
+
state, self.audio_lists, p=self.weights
|
96 |
+
)
|
97 |
+
|
98 |
+
path = audio_info["path"]
|
99 |
+
signal = AudioSignal.zeros(duration, sample_rate, num_channels)
|
100 |
+
|
101 |
+
if path != "none":
|
102 |
+
if offset is None:
|
103 |
+
signal = AudioSignal.salient_excerpt(
|
104 |
+
path,
|
105 |
+
duration=duration,
|
106 |
+
state=state,
|
107 |
+
loudness_cutoff=loudness_cutoff,
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
signal = AudioSignal(
|
111 |
+
path,
|
112 |
+
offset=offset,
|
113 |
+
duration=duration,
|
114 |
+
)
|
115 |
+
|
116 |
+
if num_channels == 1:
|
117 |
+
signal = signal.to_mono()
|
118 |
+
signal = signal.resample(sample_rate)
|
119 |
+
|
120 |
+
if signal.duration < duration:
|
121 |
+
signal = signal.zero_pad_to(int(duration * sample_rate))
|
122 |
+
|
123 |
+
for k, v in audio_info.items():
|
124 |
+
signal.metadata[k] = v
|
125 |
+
|
126 |
+
item = {
|
127 |
+
"signal": signal,
|
128 |
+
"source_idx": source_idx,
|
129 |
+
"item_idx": item_idx,
|
130 |
+
"source": str(self.sources[source_idx]),
|
131 |
+
"path": str(path),
|
132 |
+
}
|
133 |
+
if self.transform is not None:
|
134 |
+
item["transform_args"] = self.transform.instantiate(state, signal=signal)
|
135 |
+
return item
|
136 |
+
|
137 |
+
|
138 |
+
def default_matcher(x, y):
|
139 |
+
return Path(x).parent == Path(y).parent
|
140 |
+
|
141 |
+
|
142 |
+
def align_lists(lists, matcher: Callable = default_matcher):
|
143 |
+
longest_list = lists[np.argmax([len(l) for l in lists])]
|
144 |
+
for i, x in enumerate(longest_list):
|
145 |
+
for l in lists:
|
146 |
+
if i >= len(l):
|
147 |
+
l.append({"path": "none"})
|
148 |
+
elif not matcher(l[i]["path"], x["path"]):
|
149 |
+
l.insert(i, {"path": "none"})
|
150 |
+
return lists
|
151 |
+
|
152 |
+
|
153 |
+
class AudioDataset:
|
154 |
+
"""Loads audio from multiple loaders (with associated transforms)
|
155 |
+
for a specified number of samples. Excerpts are drawn randomly
|
156 |
+
of the specified duration, above a specified loudness threshold
|
157 |
+
and are resampled on the fly to the desired sample rate
|
158 |
+
(if it is different from the audio source sample rate).
|
159 |
+
|
160 |
+
This takes either a single AudioLoader object,
|
161 |
+
a dictionary of AudioLoader objects, or a dictionary of AudioLoader
|
162 |
+
objects. Each AudioLoader is called by the dataset, and the
|
163 |
+
result is placed in the output dictionary. A transform can also be
|
164 |
+
specified for the entire dataset, rather than for each specific
|
165 |
+
loader. This transform can be applied to the output of all the
|
166 |
+
loaders if desired.
|
167 |
+
|
168 |
+
AudioLoader objects can be specified as aligned, which means the
|
169 |
+
loaders correspond to multitrack audio (e.g. a vocals, bass,
|
170 |
+
drums, and other loader for multitrack music mixtures).
|
171 |
+
|
172 |
+
|
173 |
+
Parameters
|
174 |
+
----------
|
175 |
+
loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
|
176 |
+
AudioLoaders to sample audio from.
|
177 |
+
sample_rate : int
|
178 |
+
Desired sample rate.
|
179 |
+
n_examples : int, optional
|
180 |
+
Number of examples (length of dataset), by default 1000
|
181 |
+
duration : float, optional
|
182 |
+
Duration of audio samples, by default 0.5
|
183 |
+
loudness_cutoff : float, optional
|
184 |
+
Loudness cutoff threshold for audio samples, by default -40
|
185 |
+
num_channels : int, optional
|
186 |
+
Number of channels in output audio, by default 1
|
187 |
+
transform : Callable, optional
|
188 |
+
Transform to instantiate alongside each dataset item, by default None
|
189 |
+
aligned : bool, optional
|
190 |
+
Whether the loaders should be sampled in an aligned manner (e.g. same
|
191 |
+
offset, duration, and matched file name), by default False
|
192 |
+
shuffle_loaders : bool, optional
|
193 |
+
Whether to shuffle the loaders before sampling from them, by default False
|
194 |
+
matcher : Callable
|
195 |
+
How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
|
196 |
+
by default uses the parent directory of each file.
|
197 |
+
without_replacement : bool
|
198 |
+
Whether to choose files with or without replacement, by default True.
|
199 |
+
|
200 |
+
|
201 |
+
Examples
|
202 |
+
--------
|
203 |
+
>>> from audiotools.data.datasets import AudioLoader
|
204 |
+
>>> from audiotools.data.datasets import AudioDataset
|
205 |
+
>>> from audiotools import transforms as tfm
|
206 |
+
>>> import numpy as np
|
207 |
+
>>>
|
208 |
+
>>> loaders = [
|
209 |
+
>>> AudioLoader(
|
210 |
+
>>> sources=[f"tests/audio/spk"],
|
211 |
+
>>> transform=tfm.Equalizer(),
|
212 |
+
>>> ext=["wav"],
|
213 |
+
>>> )
|
214 |
+
>>> for i in range(5)
|
215 |
+
>>> ]
|
216 |
+
>>>
|
217 |
+
>>> dataset = AudioDataset(
|
218 |
+
>>> loaders = loaders,
|
219 |
+
>>> sample_rate = 44100,
|
220 |
+
>>> duration = 1.0,
|
221 |
+
>>> transform = tfm.RescaleAudio(),
|
222 |
+
>>> )
|
223 |
+
>>>
|
224 |
+
>>> item = dataset[np.random.randint(len(dataset))]
|
225 |
+
>>>
|
226 |
+
>>> for i in range(len(loaders)):
|
227 |
+
>>> item[i]["signal"] = loaders[i].transform(
|
228 |
+
>>> item[i]["signal"], **item[i]["transform_args"]
|
229 |
+
>>> )
|
230 |
+
>>> item[i]["signal"].widget(i)
|
231 |
+
>>>
|
232 |
+
>>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
|
233 |
+
>>> mix = dataset.transform(mix, **item["transform_args"])
|
234 |
+
>>> mix.widget("mix")
|
235 |
+
|
236 |
+
Below is an example of how one could load MUSDB multitrack data:
|
237 |
+
|
238 |
+
>>> import audiotools as at
|
239 |
+
>>> from pathlib import Path
|
240 |
+
>>> from audiotools import transforms as tfm
|
241 |
+
>>> import numpy as np
|
242 |
+
>>> import torch
|
243 |
+
>>>
|
244 |
+
>>> def build_dataset(
|
245 |
+
>>> sample_rate: int = 44100,
|
246 |
+
>>> duration: float = 5.0,
|
247 |
+
>>> musdb_path: str = "~/.data/musdb/",
|
248 |
+
>>> ):
|
249 |
+
>>> musdb_path = Path(musdb_path).expanduser()
|
250 |
+
>>> loaders = {
|
251 |
+
>>> src: at.datasets.AudioLoader(
|
252 |
+
>>> sources=[musdb_path],
|
253 |
+
>>> transform=tfm.Compose(
|
254 |
+
>>> tfm.VolumeNorm(("uniform", -20, -10)),
|
255 |
+
>>> tfm.Silence(prob=0.1),
|
256 |
+
>>> ),
|
257 |
+
>>> ext=[f"{src}.wav"],
|
258 |
+
>>> )
|
259 |
+
>>> for src in ["vocals", "bass", "drums", "other"]
|
260 |
+
>>> }
|
261 |
+
>>>
|
262 |
+
>>> dataset = at.datasets.AudioDataset(
|
263 |
+
>>> loaders=loaders,
|
264 |
+
>>> sample_rate=sample_rate,
|
265 |
+
>>> duration=duration,
|
266 |
+
>>> num_channels=1,
|
267 |
+
>>> aligned=True,
|
268 |
+
>>> transform=tfm.RescaleAudio(),
|
269 |
+
>>> shuffle_loaders=True,
|
270 |
+
>>> )
|
271 |
+
>>> return dataset, list(loaders.keys())
|
272 |
+
>>>
|
273 |
+
>>> train_data, sources = build_dataset()
|
274 |
+
>>> dataloader = torch.utils.data.DataLoader(
|
275 |
+
>>> train_data,
|
276 |
+
>>> batch_size=16,
|
277 |
+
>>> num_workers=0,
|
278 |
+
>>> collate_fn=train_data.collate,
|
279 |
+
>>> )
|
280 |
+
>>> batch = next(iter(dataloader))
|
281 |
+
>>>
|
282 |
+
>>> for k in sources:
|
283 |
+
>>> src = batch[k]
|
284 |
+
>>> src["transformed"] = train_data.loaders[k].transform(
|
285 |
+
>>> src["signal"].clone(), **src["transform_args"]
|
286 |
+
>>> )
|
287 |
+
>>>
|
288 |
+
>>> mixture = sum(batch[k]["transformed"] for k in sources)
|
289 |
+
>>> mixture = train_data.transform(mixture, **batch["transform_args"])
|
290 |
+
>>>
|
291 |
+
>>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
|
292 |
+
>>> # Construct the targets:
|
293 |
+
>>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)
|
294 |
+
|
295 |
+
Similarly, here's example code for loading Slakh data:
|
296 |
+
|
297 |
+
>>> import audiotools as at
|
298 |
+
>>> from pathlib import Path
|
299 |
+
>>> from audiotools import transforms as tfm
|
300 |
+
>>> import numpy as np
|
301 |
+
>>> import torch
|
302 |
+
>>> import glob
|
303 |
+
>>>
|
304 |
+
>>> def build_dataset(
|
305 |
+
>>> sample_rate: int = 16000,
|
306 |
+
>>> duration: float = 10.0,
|
307 |
+
>>> slakh_path: str = "~/.data/slakh/",
|
308 |
+
>>> ):
|
309 |
+
>>> slakh_path = Path(slakh_path).expanduser()
|
310 |
+
>>>
|
311 |
+
>>> # Find the max number of sources in Slakh
|
312 |
+
>>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)]
|
313 |
+
>>> n_sources = len(list(set(src_names)))
|
314 |
+
>>>
|
315 |
+
>>> loaders = {
|
316 |
+
>>> f"S{i:02d}": at.datasets.AudioLoader(
|
317 |
+
>>> sources=[slakh_path],
|
318 |
+
>>> transform=tfm.Compose(
|
319 |
+
>>> tfm.VolumeNorm(("uniform", -20, -10)),
|
320 |
+
>>> tfm.Silence(prob=0.1),
|
321 |
+
>>> ),
|
322 |
+
>>> ext=[f"S{i:02d}.wav"],
|
323 |
+
>>> )
|
324 |
+
>>> for i in range(n_sources)
|
325 |
+
>>> }
|
326 |
+
>>> dataset = at.datasets.AudioDataset(
|
327 |
+
>>> loaders=loaders,
|
328 |
+
>>> sample_rate=sample_rate,
|
329 |
+
>>> duration=duration,
|
330 |
+
>>> num_channels=1,
|
331 |
+
>>> aligned=True,
|
332 |
+
>>> transform=tfm.RescaleAudio(),
|
333 |
+
>>> shuffle_loaders=False,
|
334 |
+
>>> )
|
335 |
+
>>>
|
336 |
+
>>> return dataset, list(loaders.keys())
|
337 |
+
>>>
|
338 |
+
>>> train_data, sources = build_dataset()
|
339 |
+
>>> dataloader = torch.utils.data.DataLoader(
|
340 |
+
>>> train_data,
|
341 |
+
>>> batch_size=16,
|
342 |
+
>>> num_workers=0,
|
343 |
+
>>> collate_fn=train_data.collate,
|
344 |
+
>>> )
|
345 |
+
>>> batch = next(iter(dataloader))
|
346 |
+
>>>
|
347 |
+
>>> for k in sources:
|
348 |
+
>>> src = batch[k]
|
349 |
+
>>> src["transformed"] = train_data.loaders[k].transform(
|
350 |
+
>>> src["signal"].clone(), **src["transform_args"]
|
351 |
+
>>> )
|
352 |
+
>>>
|
353 |
+
>>> mixture = sum(batch[k]["transformed"] for k in sources)
|
354 |
+
>>> mixture = train_data.transform(mixture, **batch["transform_args"])
|
355 |
+
|
356 |
+
"""
|
357 |
+
|
358 |
+
def __init__(
|
359 |
+
self,
|
360 |
+
loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]],
|
361 |
+
sample_rate: int,
|
362 |
+
n_examples: int = 1000,
|
363 |
+
duration: float = 0.5,
|
364 |
+
offset: float = None,
|
365 |
+
loudness_cutoff: float = -40,
|
366 |
+
num_channels: int = 1,
|
367 |
+
transform: Callable = None,
|
368 |
+
aligned: bool = False,
|
369 |
+
shuffle_loaders: bool = False,
|
370 |
+
matcher: Callable = default_matcher,
|
371 |
+
without_replacement: bool = True,
|
372 |
+
):
|
373 |
+
# Internally we convert loaders to a dictionary
|
374 |
+
if isinstance(loaders, list):
|
375 |
+
loaders = {i: l for i, l in enumerate(loaders)}
|
376 |
+
elif isinstance(loaders, AudioLoader):
|
377 |
+
loaders = {0: loaders}
|
378 |
+
|
379 |
+
self.loaders = loaders
|
380 |
+
self.loudness_cutoff = loudness_cutoff
|
381 |
+
self.num_channels = num_channels
|
382 |
+
|
383 |
+
self.length = n_examples
|
384 |
+
self.transform = transform
|
385 |
+
self.sample_rate = sample_rate
|
386 |
+
self.duration = duration
|
387 |
+
self.offset = offset
|
388 |
+
self.aligned = aligned
|
389 |
+
self.shuffle_loaders = shuffle_loaders
|
390 |
+
self.without_replacement = without_replacement
|
391 |
+
|
392 |
+
if aligned:
|
393 |
+
loaders_list = list(loaders.values())
|
394 |
+
for i in range(len(loaders_list[0].audio_lists)):
|
395 |
+
input_lists = [l.audio_lists[i] for l in loaders_list]
|
396 |
+
# Alignment happens in-place
|
397 |
+
align_lists(input_lists, matcher)
|
398 |
+
|
399 |
+
def __getitem__(self, idx):
|
400 |
+
state = util.random_state(idx)
|
401 |
+
offset = None if self.offset is None else self.offset
|
402 |
+
item = {}
|
403 |
+
|
404 |
+
keys = list(self.loaders.keys())
|
405 |
+
if self.shuffle_loaders:
|
406 |
+
state.shuffle(keys)
|
407 |
+
|
408 |
+
loader_kwargs = {
|
409 |
+
"state": state,
|
410 |
+
"sample_rate": self.sample_rate,
|
411 |
+
"duration": self.duration,
|
412 |
+
"loudness_cutoff": self.loudness_cutoff,
|
413 |
+
"num_channels": self.num_channels,
|
414 |
+
"global_idx": idx if self.without_replacement else None,
|
415 |
+
}
|
416 |
+
|
417 |
+
# Draw item from first loader
|
418 |
+
loader = self.loaders[keys[0]]
|
419 |
+
item[keys[0]] = loader(**loader_kwargs)
|
420 |
+
|
421 |
+
for key in keys[1:]:
|
422 |
+
loader = self.loaders[key]
|
423 |
+
if self.aligned:
|
424 |
+
# Path mapper takes the current loader + everything
|
425 |
+
# returned by the first loader.
|
426 |
+
offset = item[keys[0]]["signal"].metadata["offset"]
|
427 |
+
loader_kwargs.update(
|
428 |
+
{
|
429 |
+
"offset": offset,
|
430 |
+
"source_idx": item[keys[0]]["source_idx"],
|
431 |
+
"item_idx": item[keys[0]]["item_idx"],
|
432 |
+
}
|
433 |
+
)
|
434 |
+
item[key] = loader(**loader_kwargs)
|
435 |
+
|
436 |
+
# Sort dictionary back into original order
|
437 |
+
keys = list(self.loaders.keys())
|
438 |
+
item = {k: item[k] for k in keys}
|
439 |
+
|
440 |
+
item["idx"] = idx
|
441 |
+
if self.transform is not None:
|
442 |
+
item["transform_args"] = self.transform.instantiate(
|
443 |
+
state=state, signal=item[keys[0]]["signal"]
|
444 |
+
)
|
445 |
+
|
446 |
+
# If there's only one loader, pop it up
|
447 |
+
# to the main dictionary, instead of keeping it
|
448 |
+
# nested.
|
449 |
+
if len(keys) == 1:
|
450 |
+
item.update(item.pop(keys[0]))
|
451 |
+
|
452 |
+
return item
|
453 |
+
|
454 |
+
def __len__(self):
|
455 |
+
return self.length
|
456 |
+
|
457 |
+
@staticmethod
|
458 |
+
def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
|
459 |
+
"""Collates items drawn from this dataset. Uses
|
460 |
+
:py:func:`audiotools.core.util.collate`.
|
461 |
+
|
462 |
+
Parameters
|
463 |
+
----------
|
464 |
+
list_of_dicts : typing.Union[list, dict]
|
465 |
+
Data drawn from each item.
|
466 |
+
n_splits : int
|
467 |
+
Number of splits to make when creating the batches (split into
|
468 |
+
sub-batches). Useful for things like gradient accumulation.
|
469 |
+
|
470 |
+
Returns
|
471 |
+
-------
|
472 |
+
dict
|
473 |
+
Dictionary of batched data.
|
474 |
+
"""
|
475 |
+
return util.collate(list_of_dicts, n_splits=n_splits)
|
476 |
+
|
477 |
+
|
478 |
+
class ConcatDataset(AudioDataset):
|
479 |
+
def __init__(self, datasets: list):
|
480 |
+
self.datasets = datasets
|
481 |
+
|
482 |
+
def __len__(self):
|
483 |
+
return sum([len(d) for d in self.datasets])
|
484 |
+
|
485 |
+
def __getitem__(self, idx):
|
486 |
+
dataset = self.datasets[idx % len(self.datasets)]
|
487 |
+
return dataset[idx // len(self.datasets)]
|
488 |
+
|
489 |
+
|
490 |
+
class ResumableDistributedSampler(DistributedSampler): # pragma: no cover
|
491 |
+
"""Distributed sampler that can be resumed from a given start index."""
|
492 |
+
|
493 |
+
def __init__(self, dataset, start_idx: int = None, **kwargs):
|
494 |
+
super().__init__(dataset, **kwargs)
|
495 |
+
# Start index, allows to resume an experiment at the index it was
|
496 |
+
self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0
|
497 |
+
|
498 |
+
def __iter__(self):
|
499 |
+
for i, idx in enumerate(super().__iter__()):
|
500 |
+
if i >= self.start_idx:
|
501 |
+
yield idx
|
502 |
+
self.start_idx = 0 # set the index back to 0 so for the next epoch
|
503 |
+
|
504 |
+
|
505 |
+
class ResumableSequentialSampler(SequentialSampler): # pragma: no cover
|
506 |
+
"""Sequential sampler that can be resumed from a given start index."""
|
507 |
+
|
508 |
+
def __init__(self, dataset, start_idx: int = None, **kwargs):
|
509 |
+
super().__init__(dataset, **kwargs)
|
510 |
+
# Start index, allows to resume an experiment at the index it was
|
511 |
+
self.start_idx = start_idx if start_idx is not None else 0
|
512 |
+
|
513 |
+
def __iter__(self):
|
514 |
+
for i, idx in enumerate(super().__iter__()):
|
515 |
+
if i >= self.start_idx:
|
516 |
+
yield idx
|
517 |
+
self.start_idx = 0 # set the index back to 0 so for the next epoch
|
audiotools/data/preprocess.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
from ..core import AudioSignal
|
8 |
+
|
9 |
+
|
10 |
+
def create_csv(
|
11 |
+
audio_files: list, output_csv: Path, loudness: bool = False, data_path: str = None
|
12 |
+
):
|
13 |
+
"""Converts a folder of audio files to a CSV file. If ``loudness = True``,
|
14 |
+
the output of this function will create a CSV file that looks something
|
15 |
+
like:
|
16 |
+
|
17 |
+
.. csv-table::
|
18 |
+
:header: path,loudness
|
19 |
+
|
20 |
+
daps/produced/f1_script1_produced.wav,-16.299999237060547
|
21 |
+
daps/produced/f1_script2_produced.wav,-16.600000381469727
|
22 |
+
daps/produced/f1_script3_produced.wav,-17.299999237060547
|
23 |
+
daps/produced/f1_script4_produced.wav,-16.100000381469727
|
24 |
+
daps/produced/f1_script5_produced.wav,-16.700000762939453
|
25 |
+
daps/produced/f3_script1_produced.wav,-16.5
|
26 |
+
|
27 |
+
.. note::
|
28 |
+
The paths above are written relative to the ``data_path`` argument
|
29 |
+
which defaults to the environment variable ``PATH_TO_DATA`` if
|
30 |
+
it isn't passed to this function, and defaults to the empty string
|
31 |
+
if that environment variable is not set.
|
32 |
+
|
33 |
+
You can produce a CSV file from a directory of audio files via:
|
34 |
+
|
35 |
+
>>> import audiotools
|
36 |
+
>>> directory = ...
|
37 |
+
>>> audio_files = audiotools.util.find_audio(directory)
|
38 |
+
>>> output_path = "train.csv"
|
39 |
+
>>> audiotools.data.preprocess.create_csv(
|
40 |
+
>>> audio_files, output_csv, loudness=True
|
41 |
+
>>> )
|
42 |
+
|
43 |
+
Note that you can create empty rows in the CSV file by passing an empty
|
44 |
+
string or None in the ``audio_files`` list. This is useful if you want to
|
45 |
+
sync multiple CSV files in a multitrack setting. The loudness of these
|
46 |
+
empty rows will be set to -inf.
|
47 |
+
|
48 |
+
Parameters
|
49 |
+
----------
|
50 |
+
audio_files : list
|
51 |
+
List of audio files.
|
52 |
+
output_csv : Path
|
53 |
+
Output CSV, with each row containing the relative path of every file
|
54 |
+
to ``data_path``, if specified (defaults to None).
|
55 |
+
loudness : bool
|
56 |
+
Compute loudness of entire file and store alongside path.
|
57 |
+
"""
|
58 |
+
|
59 |
+
info = []
|
60 |
+
pbar = tqdm(audio_files)
|
61 |
+
for af in pbar:
|
62 |
+
af = Path(af)
|
63 |
+
pbar.set_description(f"Processing {af.name}")
|
64 |
+
_info = {}
|
65 |
+
if af.name == "":
|
66 |
+
_info["path"] = ""
|
67 |
+
if loudness:
|
68 |
+
_info["loudness"] = -float("inf")
|
69 |
+
else:
|
70 |
+
_info["path"] = af.relative_to(data_path) if data_path is not None else af
|
71 |
+
if loudness:
|
72 |
+
_info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()
|
73 |
+
|
74 |
+
info.append(_info)
|
75 |
+
|
76 |
+
with open(output_csv, "w") as f:
|
77 |
+
writer = csv.DictWriter(f, fieldnames=list(info[0].keys()))
|
78 |
+
writer.writeheader()
|
79 |
+
|
80 |
+
for item in info:
|
81 |
+
writer.writerow(item)
|
audiotools/data/transforms.py
ADDED
@@ -0,0 +1,1592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from inspect import signature
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from flatten_dict import flatten
|
9 |
+
from flatten_dict import unflatten
|
10 |
+
from numpy.random import RandomState
|
11 |
+
|
12 |
+
from .. import ml
|
13 |
+
from ..core import AudioSignal
|
14 |
+
from ..core import util
|
15 |
+
from .datasets import AudioLoader
|
16 |
+
|
17 |
+
tt = torch.tensor
|
18 |
+
"""Shorthand for converting things to torch.tensor."""
|
19 |
+
|
20 |
+
|
21 |
+
class BaseTransform:
|
22 |
+
"""This is the base class for all transforms that are implemented
|
23 |
+
in this library. Transforms have two main operations: ``transform``
|
24 |
+
and ``instantiate``.
|
25 |
+
|
26 |
+
``instantiate`` sets the parameters randomly
|
27 |
+
from distribution tuples for each parameter. For example, for the
|
28 |
+
``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``)
|
29 |
+
is chosen randomly by instantiate. By default, it chosen uniformly
|
30 |
+
between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``).
|
31 |
+
|
32 |
+
``transform`` applies the transform using the instantiated parameters.
|
33 |
+
A simple example is as follows:
|
34 |
+
|
35 |
+
>>> seed = 0
|
36 |
+
>>> signal = ...
|
37 |
+
>>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0))
|
38 |
+
>>> kwargs = transform.instantiate()
|
39 |
+
>>> output = transform(signal.clone(), **kwargs)
|
40 |
+
|
41 |
+
By breaking apart the instantiation of parameters from the actual audio
|
42 |
+
processing of the transform, we can make things more reproducible, while
|
43 |
+
also applying the transform on batches of data efficiently on GPU,
|
44 |
+
rather than on individual audio samples.
|
45 |
+
|
46 |
+
.. note::
|
47 |
+
We call ``signal.clone()`` for the input to the ``transform`` function
|
48 |
+
because signals are modified in-place! If you don't clone the signal,
|
49 |
+
you will lose the original data.
|
50 |
+
|
51 |
+
Parameters
|
52 |
+
----------
|
53 |
+
keys : list, optional
|
54 |
+
Keys that the transform looks for when
|
55 |
+
calling ``self.transform``, by default []. In general this is
|
56 |
+
set automatically, and you won't need to manipulate this argument.
|
57 |
+
name : str, optional
|
58 |
+
Name of this transform, used to identify it in the dictionary
|
59 |
+
produced by ``self.instantiate``, by default None
|
60 |
+
prob : float, optional
|
61 |
+
Probability of applying this transform, by default 1.0
|
62 |
+
|
63 |
+
Examples
|
64 |
+
--------
|
65 |
+
|
66 |
+
>>> seed = 0
|
67 |
+
>>>
|
68 |
+
>>> audio_path = "tests/audio/spk/f10_script4_produced.wav"
|
69 |
+
>>> signal = AudioSignal(audio_path, offset=10, duration=2)
|
70 |
+
>>> transform = tfm.Compose(
|
71 |
+
>>> [
|
72 |
+
>>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
|
73 |
+
>>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
|
74 |
+
>>> ],
|
75 |
+
>>> )
|
76 |
+
>>>
|
77 |
+
>>> kwargs = transform.instantiate(seed, signal)
|
78 |
+
>>> output = transform(signal, **kwargs)
|
79 |
+
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, keys: list = [], name: str = None, prob: float = 1.0):
|
83 |
+
# Get keys from the _transform signature.
|
84 |
+
tfm_keys = list(signature(self._transform).parameters.keys())
|
85 |
+
|
86 |
+
# Filter out signal and kwargs keys.
|
87 |
+
ignore_keys = ["signal", "kwargs"]
|
88 |
+
tfm_keys = [k for k in tfm_keys if k not in ignore_keys]
|
89 |
+
|
90 |
+
# Combine keys specified by the child class, the keys found in
|
91 |
+
# _transform signature, and the mask key.
|
92 |
+
self.keys = keys + tfm_keys + ["mask"]
|
93 |
+
|
94 |
+
self.prob = prob
|
95 |
+
|
96 |
+
if name is None:
|
97 |
+
name = self.__class__.__name__
|
98 |
+
self.name = name
|
99 |
+
|
100 |
+
def _prepare(self, batch: dict):
|
101 |
+
sub_batch = batch[self.name]
|
102 |
+
|
103 |
+
for k in self.keys:
|
104 |
+
assert k in sub_batch.keys(), f"{k} not in batch"
|
105 |
+
|
106 |
+
return sub_batch
|
107 |
+
|
108 |
+
def _transform(self, signal):
|
109 |
+
return signal
|
110 |
+
|
111 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal = None):
|
112 |
+
return {}
|
113 |
+
|
114 |
+
@staticmethod
|
115 |
+
def apply_mask(batch: dict, mask: torch.Tensor):
|
116 |
+
"""Applies a mask to the batch.
|
117 |
+
|
118 |
+
Parameters
|
119 |
+
----------
|
120 |
+
batch : dict
|
121 |
+
Batch whose values will be masked in the ``transform`` pass.
|
122 |
+
mask : torch.Tensor
|
123 |
+
Mask to apply to batch.
|
124 |
+
|
125 |
+
Returns
|
126 |
+
-------
|
127 |
+
dict
|
128 |
+
A dictionary that contains values only where ``mask = True``.
|
129 |
+
"""
|
130 |
+
masked_batch = {k: v[mask] for k, v in flatten(batch).items()}
|
131 |
+
return unflatten(masked_batch)
|
132 |
+
|
133 |
+
def transform(self, signal: AudioSignal, **kwargs):
|
134 |
+
"""Apply the transform to the audio signal,
|
135 |
+
with given keyword arguments.
|
136 |
+
|
137 |
+
Parameters
|
138 |
+
----------
|
139 |
+
signal : AudioSignal
|
140 |
+
Signal that will be modified by the transforms in-place.
|
141 |
+
kwargs: dict
|
142 |
+
Keyword arguments to the specific transforms ``self._transform``
|
143 |
+
function.
|
144 |
+
|
145 |
+
Returns
|
146 |
+
-------
|
147 |
+
AudioSignal
|
148 |
+
Transformed AudioSignal.
|
149 |
+
|
150 |
+
Examples
|
151 |
+
--------
|
152 |
+
|
153 |
+
>>> for seed in range(10):
|
154 |
+
>>> kwargs = transform.instantiate(seed, signal)
|
155 |
+
>>> output = transform(signal.clone(), **kwargs)
|
156 |
+
|
157 |
+
"""
|
158 |
+
tfm_kwargs = self._prepare(kwargs)
|
159 |
+
mask = tfm_kwargs["mask"]
|
160 |
+
|
161 |
+
if torch.any(mask):
|
162 |
+
tfm_kwargs = self.apply_mask(tfm_kwargs, mask)
|
163 |
+
tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"}
|
164 |
+
signal[mask] = self._transform(signal[mask], **tfm_kwargs)
|
165 |
+
|
166 |
+
return signal
|
167 |
+
|
168 |
+
def __call__(self, *args, **kwargs):
|
169 |
+
return self.transform(*args, **kwargs)
|
170 |
+
|
171 |
+
def instantiate(
|
172 |
+
self,
|
173 |
+
state: RandomState = None,
|
174 |
+
signal: AudioSignal = None,
|
175 |
+
):
|
176 |
+
"""Instantiates parameters for the transform.
|
177 |
+
|
178 |
+
Parameters
|
179 |
+
----------
|
180 |
+
state : RandomState, optional
|
181 |
+
_description_, by default None
|
182 |
+
signal : AudioSignal, optional
|
183 |
+
_description_, by default None
|
184 |
+
|
185 |
+
Returns
|
186 |
+
-------
|
187 |
+
dict
|
188 |
+
Dictionary containing instantiated arguments for every keyword
|
189 |
+
argument to ``self._transform``.
|
190 |
+
|
191 |
+
Examples
|
192 |
+
--------
|
193 |
+
|
194 |
+
>>> for seed in range(10):
|
195 |
+
>>> kwargs = transform.instantiate(seed, signal)
|
196 |
+
>>> output = transform(signal.clone(), **kwargs)
|
197 |
+
|
198 |
+
"""
|
199 |
+
state = util.random_state(state)
|
200 |
+
|
201 |
+
# Not all instantiates need the signal. Check if signal
|
202 |
+
# is needed before passing it in, so that the end-user
|
203 |
+
# doesn't need to have variables they're not using flowing
|
204 |
+
# into their function.
|
205 |
+
needs_signal = "signal" in set(signature(self._instantiate).parameters.keys())
|
206 |
+
kwargs = {}
|
207 |
+
if needs_signal:
|
208 |
+
kwargs = {"signal": signal}
|
209 |
+
|
210 |
+
# Instantiate the parameters for the transform.
|
211 |
+
params = self._instantiate(state, **kwargs)
|
212 |
+
for k in list(params.keys()):
|
213 |
+
v = params[k]
|
214 |
+
if isinstance(v, (AudioSignal, torch.Tensor, dict)):
|
215 |
+
params[k] = v
|
216 |
+
else:
|
217 |
+
params[k] = tt(v)
|
218 |
+
mask = state.rand() <= self.prob
|
219 |
+
params[f"mask"] = tt(mask)
|
220 |
+
|
221 |
+
# Put the params into a nested dictionary that will be
|
222 |
+
# used later when calling the transform. This is to avoid
|
223 |
+
# collisions in the dictionary.
|
224 |
+
params = {self.name: params}
|
225 |
+
|
226 |
+
return params
|
227 |
+
|
228 |
+
def batch_instantiate(
|
229 |
+
self,
|
230 |
+
states: list = None,
|
231 |
+
signal: AudioSignal = None,
|
232 |
+
):
|
233 |
+
"""Instantiates arguments for every item in a batch,
|
234 |
+
given a list of states. Each state in the list
|
235 |
+
corresponds to one item in the batch.
|
236 |
+
|
237 |
+
Parameters
|
238 |
+
----------
|
239 |
+
states : list, optional
|
240 |
+
List of states, by default None
|
241 |
+
signal : AudioSignal, optional
|
242 |
+
AudioSignal to pass to the ``self.instantiate`` section
|
243 |
+
if it is needed for this transform, by default None
|
244 |
+
|
245 |
+
Returns
|
246 |
+
-------
|
247 |
+
dict
|
248 |
+
Collated dictionary of arguments.
|
249 |
+
|
250 |
+
Examples
|
251 |
+
--------
|
252 |
+
|
253 |
+
>>> batch_size = 4
|
254 |
+
>>> signal = AudioSignal(audio_path, offset=10, duration=2)
|
255 |
+
>>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)])
|
256 |
+
>>>
|
257 |
+
>>> states = [seed + idx for idx in list(range(batch_size))]
|
258 |
+
>>> kwargs = transform.batch_instantiate(states, signal_batch)
|
259 |
+
>>> batch_output = transform(signal_batch, **kwargs)
|
260 |
+
"""
|
261 |
+
kwargs = []
|
262 |
+
for state in states:
|
263 |
+
kwargs.append(self.instantiate(state, signal))
|
264 |
+
kwargs = util.collate(kwargs)
|
265 |
+
return kwargs
|
266 |
+
|
267 |
+
|
268 |
+
class Identity(BaseTransform):
|
269 |
+
"""This transform just returns the original signal."""
|
270 |
+
|
271 |
+
pass
|
272 |
+
|
273 |
+
|
274 |
+
class SpectralTransform(BaseTransform):
|
275 |
+
"""Spectral transforms require STFT data to exist, since manipulations
|
276 |
+
of the STFT require the spectrogram. This just calls ``stft`` before
|
277 |
+
the transform is called, and calls ``istft`` after the transform is
|
278 |
+
called so that the audio data is written to after the spectral
|
279 |
+
manipulation.
|
280 |
+
"""
|
281 |
+
|
282 |
+
def transform(self, signal, **kwargs):
|
283 |
+
signal.stft()
|
284 |
+
super().transform(signal, **kwargs)
|
285 |
+
signal.istft()
|
286 |
+
return signal
|
287 |
+
|
288 |
+
|
289 |
+
class Compose(BaseTransform):
|
290 |
+
"""Compose applies transforms in sequence, one after the other. The
|
291 |
+
transforms are passed in as positional arguments or as a list like so:
|
292 |
+
|
293 |
+
>>> transform = tfm.Compose(
|
294 |
+
>>> [
|
295 |
+
>>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
|
296 |
+
>>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
|
297 |
+
>>> ],
|
298 |
+
>>> )
|
299 |
+
|
300 |
+
This will convolve the signal with a room impulse response, and then
|
301 |
+
add background noise to the signal. Instantiate instantiates
|
302 |
+
all the parameters for every transform in the transform list so the
|
303 |
+
interface for using the Compose transform is the same as everything
|
304 |
+
else:
|
305 |
+
|
306 |
+
>>> kwargs = transform.instantiate()
|
307 |
+
>>> output = transform(signal.clone(), **kwargs)
|
308 |
+
|
309 |
+
Under the hood, the transform maps each transform to a unique name
|
310 |
+
under the hood of the form ``{position}.{name}``, where ``position``
|
311 |
+
is the index of the transform in the list. ``Compose`` can nest
|
312 |
+
within other ``Compose`` transforms, like so:
|
313 |
+
|
314 |
+
>>> preprocess = transforms.Compose(
|
315 |
+
>>> tfm.GlobalVolumeNorm(),
|
316 |
+
>>> tfm.CrossTalk(),
|
317 |
+
>>> name="preprocess",
|
318 |
+
>>> )
|
319 |
+
>>> augment = transforms.Compose(
|
320 |
+
>>> tfm.RoomImpulseResponse(),
|
321 |
+
>>> tfm.BackgroundNoise(),
|
322 |
+
>>> name="augment",
|
323 |
+
>>> )
|
324 |
+
>>> postprocess = transforms.Compose(
|
325 |
+
>>> tfm.VolumeChange(),
|
326 |
+
>>> tfm.RescaleAudio(),
|
327 |
+
>>> tfm.ShiftPhase(),
|
328 |
+
>>> name="postprocess",
|
329 |
+
>>> )
|
330 |
+
>>> transform = transforms.Compose(preprocess, augment, postprocess),
|
331 |
+
|
332 |
+
This defines 3 composed transforms, and then composes them in sequence
|
333 |
+
with one another.
|
334 |
+
|
335 |
+
Parameters
|
336 |
+
----------
|
337 |
+
*transforms : list
|
338 |
+
List of transforms to apply
|
339 |
+
name : str, optional
|
340 |
+
Name of this transform, used to identify it in the dictionary
|
341 |
+
produced by ``self.instantiate``, by default None
|
342 |
+
prob : float, optional
|
343 |
+
Probability of applying this transform, by default 1.0
|
344 |
+
"""
|
345 |
+
|
346 |
+
def __init__(self, *transforms: list, name: str = None, prob: float = 1.0):
|
347 |
+
if isinstance(transforms[0], list):
|
348 |
+
transforms = transforms[0]
|
349 |
+
|
350 |
+
for i, tfm in enumerate(transforms):
|
351 |
+
tfm.name = f"{i}.{tfm.name}"
|
352 |
+
|
353 |
+
keys = [tfm.name for tfm in transforms]
|
354 |
+
super().__init__(keys=keys, name=name, prob=prob)
|
355 |
+
|
356 |
+
self.transforms = transforms
|
357 |
+
self.transforms_to_apply = keys
|
358 |
+
|
359 |
+
@contextmanager
|
360 |
+
def filter(self, *names: list):
|
361 |
+
"""This can be used to skip transforms entirely when applying
|
362 |
+
the sequence of transforms to a signal. For example, take
|
363 |
+
the following transforms with the names ``preprocess, augment, postprocess``.
|
364 |
+
|
365 |
+
>>> preprocess = transforms.Compose(
|
366 |
+
>>> tfm.GlobalVolumeNorm(),
|
367 |
+
>>> tfm.CrossTalk(),
|
368 |
+
>>> name="preprocess",
|
369 |
+
>>> )
|
370 |
+
>>> augment = transforms.Compose(
|
371 |
+
>>> tfm.RoomImpulseResponse(),
|
372 |
+
>>> tfm.BackgroundNoise(),
|
373 |
+
>>> name="augment",
|
374 |
+
>>> )
|
375 |
+
>>> postprocess = transforms.Compose(
|
376 |
+
>>> tfm.VolumeChange(),
|
377 |
+
>>> tfm.RescaleAudio(),
|
378 |
+
>>> tfm.ShiftPhase(),
|
379 |
+
>>> name="postprocess",
|
380 |
+
>>> )
|
381 |
+
>>> transform = transforms.Compose(preprocess, augment, postprocess)
|
382 |
+
|
383 |
+
If we wanted to apply all 3 to a signal, we do:
|
384 |
+
|
385 |
+
>>> kwargs = transform.instantiate()
|
386 |
+
>>> output = transform(signal.clone(), **kwargs)
|
387 |
+
|
388 |
+
But if we only wanted to apply the ``preprocess`` and ``postprocess``
|
389 |
+
transforms to the signal, we do:
|
390 |
+
|
391 |
+
>>> with transform_fn.filter("preprocess", "postprocess"):
|
392 |
+
>>> output = transform(signal.clone(), **kwargs)
|
393 |
+
|
394 |
+
Parameters
|
395 |
+
----------
|
396 |
+
*names : list
|
397 |
+
List of transforms, identified by name, to apply to signal.
|
398 |
+
"""
|
399 |
+
old_transforms = self.transforms_to_apply
|
400 |
+
self.transforms_to_apply = names
|
401 |
+
yield
|
402 |
+
self.transforms_to_apply = old_transforms
|
403 |
+
|
404 |
+
def _transform(self, signal, **kwargs):
|
405 |
+
for transform in self.transforms:
|
406 |
+
if any([x in transform.name for x in self.transforms_to_apply]):
|
407 |
+
signal = transform(signal, **kwargs)
|
408 |
+
return signal
|
409 |
+
|
410 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal = None):
|
411 |
+
parameters = {}
|
412 |
+
for transform in self.transforms:
|
413 |
+
parameters.update(transform.instantiate(state, signal=signal))
|
414 |
+
return parameters
|
415 |
+
|
416 |
+
def __getitem__(self, idx):
|
417 |
+
return self.transforms[idx]
|
418 |
+
|
419 |
+
def __len__(self):
|
420 |
+
return len(self.transforms)
|
421 |
+
|
422 |
+
def __iter__(self):
|
423 |
+
for transform in self.transforms:
|
424 |
+
yield transform
|
425 |
+
|
426 |
+
|
427 |
+
class Choose(Compose):
|
428 |
+
"""Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`,
|
429 |
+
but instead of applying all the transforms in sequence, it applies just a single transform,
|
430 |
+
which is chosen for each item in the batch.
|
431 |
+
|
432 |
+
Parameters
|
433 |
+
----------
|
434 |
+
*transforms : list
|
435 |
+
List of transforms to apply
|
436 |
+
weights : list
|
437 |
+
Probability of choosing any specific transform.
|
438 |
+
name : str, optional
|
439 |
+
Name of this transform, used to identify it in the dictionary
|
440 |
+
produced by ``self.instantiate``, by default None
|
441 |
+
prob : float, optional
|
442 |
+
Probability of applying this transform, by default 1.0
|
443 |
+
|
444 |
+
Examples
|
445 |
+
--------
|
446 |
+
|
447 |
+
>>> transforms.Choose(tfm.LowPass(), tfm.HighPass())
|
448 |
+
"""
|
449 |
+
|
450 |
+
def __init__(
|
451 |
+
self,
|
452 |
+
*transforms: list,
|
453 |
+
weights: list = None,
|
454 |
+
name: str = None,
|
455 |
+
prob: float = 1.0,
|
456 |
+
):
|
457 |
+
super().__init__(*transforms, name=name, prob=prob)
|
458 |
+
|
459 |
+
if weights is None:
|
460 |
+
_len = len(self.transforms)
|
461 |
+
weights = [1 / _len for _ in range(_len)]
|
462 |
+
self.weights = np.array(weights)
|
463 |
+
|
464 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal = None):
|
465 |
+
kwargs = super()._instantiate(state, signal)
|
466 |
+
tfm_idx = list(range(len(self.transforms)))
|
467 |
+
tfm_idx = state.choice(tfm_idx, p=self.weights)
|
468 |
+
one_hot = []
|
469 |
+
for i, t in enumerate(self.transforms):
|
470 |
+
mask = kwargs[t.name]["mask"]
|
471 |
+
if mask.item():
|
472 |
+
kwargs[t.name]["mask"] = tt(i == tfm_idx)
|
473 |
+
one_hot.append(kwargs[t.name]["mask"])
|
474 |
+
kwargs["one_hot"] = one_hot
|
475 |
+
return kwargs
|
476 |
+
|
477 |
+
|
478 |
+
class Repeat(Compose):
|
479 |
+
"""Repeatedly applies a given transform ``n_repeat`` times."
|
480 |
+
|
481 |
+
Parameters
|
482 |
+
----------
|
483 |
+
transform : BaseTransform
|
484 |
+
Transform to repeat.
|
485 |
+
n_repeat : int, optional
|
486 |
+
Number of times to repeat transform, by default 1
|
487 |
+
"""
|
488 |
+
|
489 |
+
def __init__(
|
490 |
+
self,
|
491 |
+
transform,
|
492 |
+
n_repeat: int = 1,
|
493 |
+
name: str = None,
|
494 |
+
prob: float = 1.0,
|
495 |
+
):
|
496 |
+
transforms = [copy.copy(transform) for _ in range(n_repeat)]
|
497 |
+
super().__init__(transforms, name=name, prob=prob)
|
498 |
+
|
499 |
+
self.n_repeat = n_repeat
|
500 |
+
|
501 |
+
|
502 |
+
class RepeatUpTo(Choose):
|
503 |
+
"""Repeatedly applies a given transform up to ``max_repeat`` times."
|
504 |
+
|
505 |
+
Parameters
|
506 |
+
----------
|
507 |
+
transform : BaseTransform
|
508 |
+
Transform to repeat.
|
509 |
+
max_repeat : int, optional
|
510 |
+
Max number of times to repeat transform, by default 1
|
511 |
+
weights : list
|
512 |
+
Probability of choosing any specific number up to ``max_repeat``.
|
513 |
+
"""
|
514 |
+
|
515 |
+
def __init__(
|
516 |
+
self,
|
517 |
+
transform,
|
518 |
+
max_repeat: int = 5,
|
519 |
+
weights: list = None,
|
520 |
+
name: str = None,
|
521 |
+
prob: float = 1.0,
|
522 |
+
):
|
523 |
+
transforms = []
|
524 |
+
for n in range(1, max_repeat):
|
525 |
+
transforms.append(Repeat(transform, n_repeat=n))
|
526 |
+
super().__init__(transforms, name=name, prob=prob, weights=weights)
|
527 |
+
|
528 |
+
self.max_repeat = max_repeat
|
529 |
+
|
530 |
+
|
531 |
+
class ClippingDistortion(BaseTransform):
|
532 |
+
"""Adds clipping distortion to signal. Corresponds
|
533 |
+
to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`.
|
534 |
+
|
535 |
+
Parameters
|
536 |
+
----------
|
537 |
+
perc : tuple, optional
|
538 |
+
Clipping percentile. Values are between 0.0 to 1.0.
|
539 |
+
Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1)
|
540 |
+
name : str, optional
|
541 |
+
Name of this transform, used to identify it in the dictionary
|
542 |
+
produced by ``self.instantiate``, by default None
|
543 |
+
prob : float, optional
|
544 |
+
Probability of applying this transform, by default 1.0
|
545 |
+
"""
|
546 |
+
|
547 |
+
def __init__(
|
548 |
+
self,
|
549 |
+
perc: tuple = ("uniform", 0.0, 0.1),
|
550 |
+
name: str = None,
|
551 |
+
prob: float = 1.0,
|
552 |
+
):
|
553 |
+
super().__init__(name=name, prob=prob)
|
554 |
+
|
555 |
+
self.perc = perc
|
556 |
+
|
557 |
+
def _instantiate(self, state: RandomState):
|
558 |
+
return {"perc": util.sample_from_dist(self.perc, state)}
|
559 |
+
|
560 |
+
def _transform(self, signal, perc):
|
561 |
+
return signal.clip_distortion(perc)
|
562 |
+
|
563 |
+
|
564 |
+
class Equalizer(BaseTransform):
|
565 |
+
"""Applies an equalization curve to the audio signal. Corresponds
|
566 |
+
to :py:func:`audiotools.core.effects.EffectMixin.equalizer`.
|
567 |
+
|
568 |
+
Parameters
|
569 |
+
----------
|
570 |
+
eq_amount : tuple, optional
|
571 |
+
The maximum dB cut to apply to the audio in any band,
|
572 |
+
by default ("const", 1.0 dB)
|
573 |
+
n_bands : int, optional
|
574 |
+
Number of bands in EQ, by default 6
|
575 |
+
name : str, optional
|
576 |
+
Name of this transform, used to identify it in the dictionary
|
577 |
+
produced by ``self.instantiate``, by default None
|
578 |
+
prob : float, optional
|
579 |
+
Probability of applying this transform, by default 1.0
|
580 |
+
"""
|
581 |
+
|
582 |
+
def __init__(
|
583 |
+
self,
|
584 |
+
eq_amount: tuple = ("const", 1.0),
|
585 |
+
n_bands: int = 6,
|
586 |
+
name: str = None,
|
587 |
+
prob: float = 1.0,
|
588 |
+
):
|
589 |
+
super().__init__(name=name, prob=prob)
|
590 |
+
|
591 |
+
self.eq_amount = eq_amount
|
592 |
+
self.n_bands = n_bands
|
593 |
+
|
594 |
+
def _instantiate(self, state: RandomState):
|
595 |
+
eq_amount = util.sample_from_dist(self.eq_amount, state)
|
596 |
+
eq = -eq_amount * state.rand(self.n_bands)
|
597 |
+
return {"eq": eq}
|
598 |
+
|
599 |
+
def _transform(self, signal, eq):
|
600 |
+
return signal.equalizer(eq)
|
601 |
+
|
602 |
+
|
603 |
+
class Quantization(BaseTransform):
|
604 |
+
"""Applies quantization to the input waveform. Corresponds
|
605 |
+
to :py:func:`audiotools.core.effects.EffectMixin.quantization`.
|
606 |
+
|
607 |
+
Parameters
|
608 |
+
----------
|
609 |
+
channels : tuple, optional
|
610 |
+
Number of evenly spaced quantization channels to quantize
|
611 |
+
to, by default ("choice", [8, 32, 128, 256, 1024])
|
612 |
+
name : str, optional
|
613 |
+
Name of this transform, used to identify it in the dictionary
|
614 |
+
produced by ``self.instantiate``, by default None
|
615 |
+
prob : float, optional
|
616 |
+
Probability of applying this transform, by default 1.0
|
617 |
+
"""
|
618 |
+
|
619 |
+
def __init__(
|
620 |
+
self,
|
621 |
+
channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
|
622 |
+
name: str = None,
|
623 |
+
prob: float = 1.0,
|
624 |
+
):
|
625 |
+
super().__init__(name=name, prob=prob)
|
626 |
+
|
627 |
+
self.channels = channels
|
628 |
+
|
629 |
+
def _instantiate(self, state: RandomState):
|
630 |
+
return {"channels": util.sample_from_dist(self.channels, state)}
|
631 |
+
|
632 |
+
def _transform(self, signal, channels):
|
633 |
+
return signal.quantization(channels)
|
634 |
+
|
635 |
+
|
636 |
+
class MuLawQuantization(BaseTransform):
|
637 |
+
"""Applies mu-law quantization to the input waveform. Corresponds
|
638 |
+
to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`.
|
639 |
+
|
640 |
+
Parameters
|
641 |
+
----------
|
642 |
+
channels : tuple, optional
|
643 |
+
Number of mu-law spaced quantization channels to quantize
|
644 |
+
to, by default ("choice", [8, 32, 128, 256, 1024])
|
645 |
+
name : str, optional
|
646 |
+
Name of this transform, used to identify it in the dictionary
|
647 |
+
produced by ``self.instantiate``, by default None
|
648 |
+
prob : float, optional
|
649 |
+
Probability of applying this transform, by default 1.0
|
650 |
+
"""
|
651 |
+
|
652 |
+
def __init__(
|
653 |
+
self,
|
654 |
+
channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
|
655 |
+
name: str = None,
|
656 |
+
prob: float = 1.0,
|
657 |
+
):
|
658 |
+
super().__init__(name=name, prob=prob)
|
659 |
+
|
660 |
+
self.channels = channels
|
661 |
+
|
662 |
+
def _instantiate(self, state: RandomState):
|
663 |
+
return {"channels": util.sample_from_dist(self.channels, state)}
|
664 |
+
|
665 |
+
def _transform(self, signal, channels):
|
666 |
+
return signal.mulaw_quantization(channels)
|
667 |
+
|
668 |
+
|
669 |
+
class NoiseFloor(BaseTransform):
|
670 |
+
"""Adds a noise floor of Gaussian noise to the signal at a specified
|
671 |
+
dB.
|
672 |
+
|
673 |
+
Parameters
|
674 |
+
----------
|
675 |
+
db : tuple, optional
|
676 |
+
Level of noise to add to signal, by default ("const", -50.0)
|
677 |
+
name : str, optional
|
678 |
+
Name of this transform, used to identify it in the dictionary
|
679 |
+
produced by ``self.instantiate``, by default None
|
680 |
+
prob : float, optional
|
681 |
+
Probability of applying this transform, by default 1.0
|
682 |
+
"""
|
683 |
+
|
684 |
+
def __init__(
|
685 |
+
self,
|
686 |
+
db: tuple = ("const", -50.0),
|
687 |
+
name: str = None,
|
688 |
+
prob: float = 1.0,
|
689 |
+
):
|
690 |
+
super().__init__(name=name, prob=prob)
|
691 |
+
|
692 |
+
self.db = db
|
693 |
+
|
694 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal):
|
695 |
+
db = util.sample_from_dist(self.db, state)
|
696 |
+
audio_data = state.randn(signal.num_channels, signal.signal_length)
|
697 |
+
nz_signal = AudioSignal(audio_data, signal.sample_rate)
|
698 |
+
nz_signal.normalize(db)
|
699 |
+
return {"nz_signal": nz_signal}
|
700 |
+
|
701 |
+
def _transform(self, signal, nz_signal):
|
702 |
+
# Clone bg_signal so that transform can be repeatedly applied
|
703 |
+
# to different signals with the same effect.
|
704 |
+
return signal + nz_signal
|
705 |
+
|
706 |
+
|
707 |
+
class BackgroundNoise(BaseTransform):
|
708 |
+
"""Adds background noise from audio specified by a set of CSV files.
|
709 |
+
A valid CSV file looks like, and is typically generated by
|
710 |
+
:py:func:`audiotools.data.preprocess.create_csv`:
|
711 |
+
|
712 |
+
.. csv-table::
|
713 |
+
:header: path
|
714 |
+
|
715 |
+
room_tone/m6_script2_clean.wav
|
716 |
+
room_tone/m6_script2_cleanraw.wav
|
717 |
+
room_tone/m6_script2_ipad_balcony1.wav
|
718 |
+
room_tone/m6_script2_ipad_bedroom1.wav
|
719 |
+
room_tone/m6_script2_ipad_confroom1.wav
|
720 |
+
room_tone/m6_script2_ipad_confroom2.wav
|
721 |
+
room_tone/m6_script2_ipad_livingroom1.wav
|
722 |
+
room_tone/m6_script2_ipad_office1.wav
|
723 |
+
|
724 |
+
.. note::
|
725 |
+
All paths are relative to an environment variable called ``PATH_TO_DATA``,
|
726 |
+
so that CSV files are portable across machines where data may be
|
727 |
+
located in different places.
|
728 |
+
|
729 |
+
This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
|
730 |
+
and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the
|
731 |
+
hood.
|
732 |
+
|
733 |
+
Parameters
|
734 |
+
----------
|
735 |
+
snr : tuple, optional
|
736 |
+
Signal-to-noise ratio, by default ("uniform", 10.0, 30.0)
|
737 |
+
sources : List[str], optional
|
738 |
+
Sources containing folders, or CSVs with paths to audio files,
|
739 |
+
by default None
|
740 |
+
weights : List[float], optional
|
741 |
+
Weights to sample audio files from each source, by default None
|
742 |
+
eq_amount : tuple, optional
|
743 |
+
Amount of equalization to apply, by default ("const", 1.0)
|
744 |
+
n_bands : int, optional
|
745 |
+
Number of bands in equalizer, by default 3
|
746 |
+
name : str, optional
|
747 |
+
Name of this transform, used to identify it in the dictionary
|
748 |
+
produced by ``self.instantiate``, by default None
|
749 |
+
prob : float, optional
|
750 |
+
Probability of applying this transform, by default 1.0
|
751 |
+
loudness_cutoff : float, optional
|
752 |
+
Loudness cutoff when loading from audio files, by default None
|
753 |
+
"""
|
754 |
+
|
755 |
+
def __init__(
|
756 |
+
self,
|
757 |
+
snr: tuple = ("uniform", 10.0, 30.0),
|
758 |
+
sources: List[str] = None,
|
759 |
+
weights: List[float] = None,
|
760 |
+
eq_amount: tuple = ("const", 1.0),
|
761 |
+
n_bands: int = 3,
|
762 |
+
name: str = None,
|
763 |
+
prob: float = 1.0,
|
764 |
+
loudness_cutoff: float = None,
|
765 |
+
):
|
766 |
+
super().__init__(name=name, prob=prob)
|
767 |
+
|
768 |
+
self.snr = snr
|
769 |
+
self.eq_amount = eq_amount
|
770 |
+
self.n_bands = n_bands
|
771 |
+
self.loader = AudioLoader(sources, weights)
|
772 |
+
self.loudness_cutoff = loudness_cutoff
|
773 |
+
|
774 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal):
|
775 |
+
eq_amount = util.sample_from_dist(self.eq_amount, state)
|
776 |
+
eq = -eq_amount * state.rand(self.n_bands)
|
777 |
+
snr = util.sample_from_dist(self.snr, state)
|
778 |
+
|
779 |
+
bg_signal = self.loader(
|
780 |
+
state,
|
781 |
+
signal.sample_rate,
|
782 |
+
duration=signal.signal_duration,
|
783 |
+
loudness_cutoff=self.loudness_cutoff,
|
784 |
+
num_channels=signal.num_channels,
|
785 |
+
)["signal"]
|
786 |
+
|
787 |
+
return {"eq": eq, "bg_signal": bg_signal, "snr": snr}
|
788 |
+
|
789 |
+
def _transform(self, signal, bg_signal, snr, eq):
|
790 |
+
# Clone bg_signal so that transform can be repeatedly applied
|
791 |
+
# to different signals with the same effect.
|
792 |
+
return signal.mix(bg_signal.clone(), snr, eq)
|
793 |
+
|
794 |
+
|
795 |
+
class CrossTalk(BaseTransform):
|
796 |
+
"""Adds crosstalk between speakers, whose audio is drawn from a CSV file
|
797 |
+
that was produced via :py:func:`audiotools.data.preprocess.create_csv`.
|
798 |
+
|
799 |
+
This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
|
800 |
+
under the hood.
|
801 |
+
|
802 |
+
Parameters
|
803 |
+
----------
|
804 |
+
snr : tuple, optional
|
805 |
+
How loud cross-talk speaker is relative to original signal in dB,
|
806 |
+
by default ("uniform", 0.0, 10.0)
|
807 |
+
sources : List[str], optional
|
808 |
+
Sources containing folders, or CSVs with paths to audio files,
|
809 |
+
by default None
|
810 |
+
weights : List[float], optional
|
811 |
+
Weights to sample audio files from each source, by default None
|
812 |
+
name : str, optional
|
813 |
+
Name of this transform, used to identify it in the dictionary
|
814 |
+
produced by ``self.instantiate``, by default None
|
815 |
+
prob : float, optional
|
816 |
+
Probability of applying this transform, by default 1.0
|
817 |
+
loudness_cutoff : float, optional
|
818 |
+
Loudness cutoff when loading from audio files, by default -40
|
819 |
+
"""
|
820 |
+
|
821 |
+
def __init__(
|
822 |
+
self,
|
823 |
+
snr: tuple = ("uniform", 0.0, 10.0),
|
824 |
+
sources: List[str] = None,
|
825 |
+
weights: List[float] = None,
|
826 |
+
name: str = None,
|
827 |
+
prob: float = 1.0,
|
828 |
+
loudness_cutoff: float = -40,
|
829 |
+
):
|
830 |
+
super().__init__(name=name, prob=prob)
|
831 |
+
|
832 |
+
self.snr = snr
|
833 |
+
self.loader = AudioLoader(sources, weights)
|
834 |
+
self.loudness_cutoff = loudness_cutoff
|
835 |
+
|
836 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal):
|
837 |
+
snr = util.sample_from_dist(self.snr, state)
|
838 |
+
crosstalk_signal = self.loader(
|
839 |
+
state,
|
840 |
+
signal.sample_rate,
|
841 |
+
duration=signal.signal_duration,
|
842 |
+
loudness_cutoff=self.loudness_cutoff,
|
843 |
+
num_channels=signal.num_channels,
|
844 |
+
)["signal"]
|
845 |
+
|
846 |
+
return {"crosstalk_signal": crosstalk_signal, "snr": snr}
|
847 |
+
|
848 |
+
def _transform(self, signal, crosstalk_signal, snr):
|
849 |
+
# Clone bg_signal so that transform can be repeatedly applied
|
850 |
+
# to different signals with the same effect.
|
851 |
+
loudness = signal.loudness()
|
852 |
+
mix = signal.mix(crosstalk_signal.clone(), snr)
|
853 |
+
mix.normalize(loudness)
|
854 |
+
return mix
|
855 |
+
|
856 |
+
|
857 |
+
class RoomImpulseResponse(BaseTransform):
|
858 |
+
"""Convolves signal with a room impulse response, at a specified
|
859 |
+
direct-to-reverberant ratio, with equalization applied. Room impulse
|
860 |
+
response data is drawn from a CSV file that was produced via
|
861 |
+
:py:func:`audiotools.data.preprocess.create_csv`.
|
862 |
+
|
863 |
+
This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir`
|
864 |
+
under the hood.
|
865 |
+
|
866 |
+
Parameters
|
867 |
+
----------
|
868 |
+
drr : tuple, optional
|
869 |
+
_description_, by default ("uniform", 0.0, 30.0)
|
870 |
+
sources : List[str], optional
|
871 |
+
Sources containing folders, or CSVs with paths to audio files,
|
872 |
+
by default None
|
873 |
+
weights : List[float], optional
|
874 |
+
Weights to sample audio files from each source, by default None
|
875 |
+
eq_amount : tuple, optional
|
876 |
+
Amount of equalization to apply, by default ("const", 1.0)
|
877 |
+
n_bands : int, optional
|
878 |
+
Number of bands in equalizer, by default 6
|
879 |
+
name : str, optional
|
880 |
+
Name of this transform, used to identify it in the dictionary
|
881 |
+
produced by ``self.instantiate``, by default None
|
882 |
+
prob : float, optional
|
883 |
+
Probability of applying this transform, by default 1.0
|
884 |
+
use_original_phase : bool, optional
|
885 |
+
Whether or not to use the original phase, by default False
|
886 |
+
offset : float, optional
|
887 |
+
Offset from each impulse response file to use, by default 0.0
|
888 |
+
duration : float, optional
|
889 |
+
Duration of each impulse response, by default 1.0
|
890 |
+
"""
|
891 |
+
|
892 |
+
def __init__(
|
893 |
+
self,
|
894 |
+
drr: tuple = ("uniform", 0.0, 30.0),
|
895 |
+
sources: List[str] = None,
|
896 |
+
weights: List[float] = None,
|
897 |
+
eq_amount: tuple = ("const", 1.0),
|
898 |
+
n_bands: int = 6,
|
899 |
+
name: str = None,
|
900 |
+
prob: float = 1.0,
|
901 |
+
use_original_phase: bool = False,
|
902 |
+
offset: float = 0.0,
|
903 |
+
duration: float = 1.0,
|
904 |
+
):
|
905 |
+
super().__init__(name=name, prob=prob)
|
906 |
+
|
907 |
+
self.drr = drr
|
908 |
+
self.eq_amount = eq_amount
|
909 |
+
self.n_bands = n_bands
|
910 |
+
self.use_original_phase = use_original_phase
|
911 |
+
|
912 |
+
self.loader = AudioLoader(sources, weights)
|
913 |
+
self.offset = offset
|
914 |
+
self.duration = duration
|
915 |
+
|
916 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal = None):
|
917 |
+
eq_amount = util.sample_from_dist(self.eq_amount, state)
|
918 |
+
eq = -eq_amount * state.rand(self.n_bands)
|
919 |
+
drr = util.sample_from_dist(self.drr, state)
|
920 |
+
|
921 |
+
ir_signal = self.loader(
|
922 |
+
state,
|
923 |
+
signal.sample_rate,
|
924 |
+
offset=self.offset,
|
925 |
+
duration=self.duration,
|
926 |
+
loudness_cutoff=None,
|
927 |
+
num_channels=signal.num_channels,
|
928 |
+
)["signal"]
|
929 |
+
ir_signal.zero_pad_to(signal.sample_rate)
|
930 |
+
|
931 |
+
return {"eq": eq, "ir_signal": ir_signal, "drr": drr}
|
932 |
+
|
933 |
+
def _transform(self, signal, ir_signal, drr, eq):
|
934 |
+
# Clone ir_signal so that transform can be repeatedly applied
|
935 |
+
# to different signals with the same effect.
|
936 |
+
return signal.apply_ir(
|
937 |
+
ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase
|
938 |
+
)
|
939 |
+
|
940 |
+
|
941 |
+
class VolumeChange(BaseTransform):
|
942 |
+
"""Changes the volume of the input signal.
|
943 |
+
|
944 |
+
Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
|
945 |
+
|
946 |
+
Parameters
|
947 |
+
----------
|
948 |
+
db : tuple, optional
|
949 |
+
Change in volume in decibels, by default ("uniform", -12.0, 0.0)
|
950 |
+
name : str, optional
|
951 |
+
Name of this transform, used to identify it in the dictionary
|
952 |
+
produced by ``self.instantiate``, by default None
|
953 |
+
prob : float, optional
|
954 |
+
Probability of applying this transform, by default 1.0
|
955 |
+
"""
|
956 |
+
|
957 |
+
def __init__(
|
958 |
+
self,
|
959 |
+
db: tuple = ("uniform", -12.0, 0.0),
|
960 |
+
name: str = None,
|
961 |
+
prob: float = 1.0,
|
962 |
+
):
|
963 |
+
super().__init__(name=name, prob=prob)
|
964 |
+
self.db = db
|
965 |
+
|
966 |
+
def _instantiate(self, state: RandomState):
|
967 |
+
return {"db": util.sample_from_dist(self.db, state)}
|
968 |
+
|
969 |
+
def _transform(self, signal, db):
|
970 |
+
return signal.volume_change(db)
|
971 |
+
|
972 |
+
|
973 |
+
class VolumeNorm(BaseTransform):
|
974 |
+
"""Normalizes the volume of the excerpt to a specified decibel.
|
975 |
+
|
976 |
+
Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`.
|
977 |
+
|
978 |
+
Parameters
|
979 |
+
----------
|
980 |
+
db : tuple, optional
|
981 |
+
dB to normalize signal to, by default ("const", -24)
|
982 |
+
name : str, optional
|
983 |
+
Name of this transform, used to identify it in the dictionary
|
984 |
+
produced by ``self.instantiate``, by default None
|
985 |
+
prob : float, optional
|
986 |
+
Probability of applying this transform, by default 1.0
|
987 |
+
"""
|
988 |
+
|
989 |
+
def __init__(
|
990 |
+
self,
|
991 |
+
db: tuple = ("const", -24),
|
992 |
+
name: str = None,
|
993 |
+
prob: float = 1.0,
|
994 |
+
):
|
995 |
+
super().__init__(name=name, prob=prob)
|
996 |
+
|
997 |
+
self.db = db
|
998 |
+
|
999 |
+
def _instantiate(self, state: RandomState):
|
1000 |
+
return {"db": util.sample_from_dist(self.db, state)}
|
1001 |
+
|
1002 |
+
def _transform(self, signal, db):
|
1003 |
+
return signal.normalize(db)
|
1004 |
+
|
1005 |
+
|
1006 |
+
class GlobalVolumeNorm(BaseTransform):
|
1007 |
+
"""Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this
|
1008 |
+
transform also normalizes the volume of a signal, but it uses
|
1009 |
+
the volume of the entire audio file the loaded excerpt comes from,
|
1010 |
+
rather than the volume of just the excerpt. The volume of the
|
1011 |
+
entire audio file is expected in ``signal.metadata["loudness"]``.
|
1012 |
+
If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv`
|
1013 |
+
with ``loudness = True``, like the following:
|
1014 |
+
|
1015 |
+
.. csv-table::
|
1016 |
+
:header: path,loudness
|
1017 |
+
|
1018 |
+
daps/produced/f1_script1_produced.wav,-16.299999237060547
|
1019 |
+
daps/produced/f1_script2_produced.wav,-16.600000381469727
|
1020 |
+
daps/produced/f1_script3_produced.wav,-17.299999237060547
|
1021 |
+
daps/produced/f1_script4_produced.wav,-16.100000381469727
|
1022 |
+
daps/produced/f1_script5_produced.wav,-16.700000762939453
|
1023 |
+
daps/produced/f3_script1_produced.wav,-16.5
|
1024 |
+
|
1025 |
+
The ``AudioLoader`` will automatically load the loudness column into
|
1026 |
+
the metadata of the signal.
|
1027 |
+
|
1028 |
+
Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
|
1029 |
+
|
1030 |
+
Parameters
|
1031 |
+
----------
|
1032 |
+
db : tuple, optional
|
1033 |
+
dB to normalize signal to, by default ("const", -24)
|
1034 |
+
name : str, optional
|
1035 |
+
Name of this transform, used to identify it in the dictionary
|
1036 |
+
produced by ``self.instantiate``, by default None
|
1037 |
+
prob : float, optional
|
1038 |
+
Probability of applying this transform, by default 1.0
|
1039 |
+
"""
|
1040 |
+
|
1041 |
+
def __init__(
|
1042 |
+
self,
|
1043 |
+
db: tuple = ("const", -24),
|
1044 |
+
name: str = None,
|
1045 |
+
prob: float = 1.0,
|
1046 |
+
):
|
1047 |
+
super().__init__(name=name, prob=prob)
|
1048 |
+
|
1049 |
+
self.db = db
|
1050 |
+
|
1051 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal):
|
1052 |
+
if "loudness" not in signal.metadata:
|
1053 |
+
db_change = 0.0
|
1054 |
+
elif float(signal.metadata["loudness"]) == float("-inf"):
|
1055 |
+
db_change = 0.0
|
1056 |
+
else:
|
1057 |
+
db = util.sample_from_dist(self.db, state)
|
1058 |
+
db_change = db - float(signal.metadata["loudness"])
|
1059 |
+
|
1060 |
+
return {"db": db_change}
|
1061 |
+
|
1062 |
+
def _transform(self, signal, db):
|
1063 |
+
return signal.volume_change(db)
|
1064 |
+
|
1065 |
+
|
1066 |
+
class Silence(BaseTransform):
|
1067 |
+
"""Zeros out the signal with some probability.
|
1068 |
+
|
1069 |
+
Parameters
|
1070 |
+
----------
|
1071 |
+
name : str, optional
|
1072 |
+
Name of this transform, used to identify it in the dictionary
|
1073 |
+
produced by ``self.instantiate``, by default None
|
1074 |
+
prob : float, optional
|
1075 |
+
Probability of applying this transform, by default 0.1
|
1076 |
+
"""
|
1077 |
+
|
1078 |
+
def __init__(self, name: str = None, prob: float = 0.1):
|
1079 |
+
super().__init__(name=name, prob=prob)
|
1080 |
+
|
1081 |
+
def _transform(self, signal):
|
1082 |
+
_loudness = signal._loudness
|
1083 |
+
signal = AudioSignal(
|
1084 |
+
torch.zeros_like(signal.audio_data),
|
1085 |
+
sample_rate=signal.sample_rate,
|
1086 |
+
stft_params=signal.stft_params,
|
1087 |
+
)
|
1088 |
+
# So that the amound of noise added is as if it wasn't silenced.
|
1089 |
+
# TODO: improve this hack
|
1090 |
+
signal._loudness = _loudness
|
1091 |
+
|
1092 |
+
return signal
|
1093 |
+
|
1094 |
+
|
1095 |
+
class LowPass(BaseTransform):
|
1096 |
+
"""Applies a LowPass filter.
|
1097 |
+
|
1098 |
+
Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`.
|
1099 |
+
|
1100 |
+
Parameters
|
1101 |
+
----------
|
1102 |
+
cutoff : tuple, optional
|
1103 |
+
Cutoff frequency distribution,
|
1104 |
+
by default ``("choice", [4000, 8000, 16000])``
|
1105 |
+
zeros : int, optional
|
1106 |
+
Number of zero-crossings in filter, argument to
|
1107 |
+
``julius.LowPassFilters``, by default 51
|
1108 |
+
name : str, optional
|
1109 |
+
Name of this transform, used to identify it in the dictionary
|
1110 |
+
produced by ``self.instantiate``, by default None
|
1111 |
+
prob : float, optional
|
1112 |
+
Probability of applying this transform, by default 1.0
|
1113 |
+
"""
|
1114 |
+
|
1115 |
+
def __init__(
|
1116 |
+
self,
|
1117 |
+
cutoff: tuple = ("choice", [4000, 8000, 16000]),
|
1118 |
+
zeros: int = 51,
|
1119 |
+
name: str = None,
|
1120 |
+
prob: float = 1,
|
1121 |
+
):
|
1122 |
+
super().__init__(name=name, prob=prob)
|
1123 |
+
|
1124 |
+
self.cutoff = cutoff
|
1125 |
+
self.zeros = zeros
|
1126 |
+
|
1127 |
+
def _instantiate(self, state: RandomState):
|
1128 |
+
return {"cutoff": util.sample_from_dist(self.cutoff, state)}
|
1129 |
+
|
1130 |
+
def _transform(self, signal, cutoff):
|
1131 |
+
return signal.low_pass(cutoff, zeros=self.zeros)
|
1132 |
+
|
1133 |
+
|
1134 |
+
class HighPass(BaseTransform):
|
1135 |
+
"""Applies a HighPass filter.
|
1136 |
+
|
1137 |
+
Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`.
|
1138 |
+
|
1139 |
+
Parameters
|
1140 |
+
----------
|
1141 |
+
cutoff : tuple, optional
|
1142 |
+
Cutoff frequency distribution,
|
1143 |
+
by default ``("choice", [50, 100, 250, 500, 1000])``
|
1144 |
+
zeros : int, optional
|
1145 |
+
Number of zero-crossings in filter, argument to
|
1146 |
+
``julius.LowPassFilters``, by default 51
|
1147 |
+
name : str, optional
|
1148 |
+
Name of this transform, used to identify it in the dictionary
|
1149 |
+
produced by ``self.instantiate``, by default None
|
1150 |
+
prob : float, optional
|
1151 |
+
Probability of applying this transform, by default 1.0
|
1152 |
+
"""
|
1153 |
+
|
1154 |
+
def __init__(
|
1155 |
+
self,
|
1156 |
+
cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]),
|
1157 |
+
zeros: int = 51,
|
1158 |
+
name: str = None,
|
1159 |
+
prob: float = 1,
|
1160 |
+
):
|
1161 |
+
super().__init__(name=name, prob=prob)
|
1162 |
+
|
1163 |
+
self.cutoff = cutoff
|
1164 |
+
self.zeros = zeros
|
1165 |
+
|
1166 |
+
def _instantiate(self, state: RandomState):
|
1167 |
+
return {"cutoff": util.sample_from_dist(self.cutoff, state)}
|
1168 |
+
|
1169 |
+
def _transform(self, signal, cutoff):
|
1170 |
+
return signal.high_pass(cutoff, zeros=self.zeros)
|
1171 |
+
|
1172 |
+
|
1173 |
+
class RescaleAudio(BaseTransform):
|
1174 |
+
"""Rescales the audio so it is in between ``-val`` and ``val``
|
1175 |
+
only if the original audio exceeds those bounds. Useful if
|
1176 |
+
transforms have caused the audio to clip.
|
1177 |
+
|
1178 |
+
Uses :py:func:`audiotools.core.effects.EffectMixin.ensure_max_of_audio`.
|
1179 |
+
|
1180 |
+
Parameters
|
1181 |
+
----------
|
1182 |
+
val : float, optional
|
1183 |
+
Max absolute value of signal, by default 1.0
|
1184 |
+
name : str, optional
|
1185 |
+
Name of this transform, used to identify it in the dictionary
|
1186 |
+
produced by ``self.instantiate``, by default None
|
1187 |
+
prob : float, optional
|
1188 |
+
Probability of applying this transform, by default 1.0
|
1189 |
+
"""
|
1190 |
+
|
1191 |
+
def __init__(self, val: float = 1.0, name: str = None, prob: float = 1):
|
1192 |
+
super().__init__(name=name, prob=prob)
|
1193 |
+
|
1194 |
+
self.val = val
|
1195 |
+
|
1196 |
+
def _transform(self, signal):
|
1197 |
+
return signal.ensure_max_of_audio(self.val)
|
1198 |
+
|
1199 |
+
|
1200 |
+
class ShiftPhase(SpectralTransform):
|
1201 |
+
"""Shifts the phase of the audio.
|
1202 |
+
|
1203 |
+
Uses :py:func:`audiotools.core.dsp.DSPMixin.shift)phase`.
|
1204 |
+
|
1205 |
+
Parameters
|
1206 |
+
----------
|
1207 |
+
shift : tuple, optional
|
1208 |
+
How much to shift phase by, by default ("uniform", -np.pi, np.pi)
|
1209 |
+
name : str, optional
|
1210 |
+
Name of this transform, used to identify it in the dictionary
|
1211 |
+
produced by ``self.instantiate``, by default None
|
1212 |
+
prob : float, optional
|
1213 |
+
Probability of applying this transform, by default 1.0
|
1214 |
+
"""
|
1215 |
+
|
1216 |
+
def __init__(
|
1217 |
+
self,
|
1218 |
+
shift: tuple = ("uniform", -np.pi, np.pi),
|
1219 |
+
name: str = None,
|
1220 |
+
prob: float = 1,
|
1221 |
+
):
|
1222 |
+
super().__init__(name=name, prob=prob)
|
1223 |
+
self.shift = shift
|
1224 |
+
|
1225 |
+
def _instantiate(self, state: RandomState):
|
1226 |
+
return {"shift": util.sample_from_dist(self.shift, state)}
|
1227 |
+
|
1228 |
+
def _transform(self, signal, shift):
|
1229 |
+
return signal.shift_phase(shift)
|
1230 |
+
|
1231 |
+
|
1232 |
+
class InvertPhase(ShiftPhase):
|
1233 |
+
"""Inverts the phase of the audio.
|
1234 |
+
|
1235 |
+
Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
|
1236 |
+
|
1237 |
+
Parameters
|
1238 |
+
----------
|
1239 |
+
name : str, optional
|
1240 |
+
Name of this transform, used to identify it in the dictionary
|
1241 |
+
produced by ``self.instantiate``, by default None
|
1242 |
+
prob : float, optional
|
1243 |
+
Probability of applying this transform, by default 1.0
|
1244 |
+
"""
|
1245 |
+
|
1246 |
+
def __init__(self, name: str = None, prob: float = 1):
|
1247 |
+
super().__init__(shift=("const", np.pi), name=name, prob=prob)
|
1248 |
+
|
1249 |
+
|
1250 |
+
class CorruptPhase(SpectralTransform):
|
1251 |
+
"""Corrupts the phase of the audio.
|
1252 |
+
|
1253 |
+
Uses :py:func:`audiotools.core.dsp.DSPMixin.corrupt_phase`.
|
1254 |
+
|
1255 |
+
Parameters
|
1256 |
+
----------
|
1257 |
+
scale : tuple, optional
|
1258 |
+
How much to corrupt phase by, by default ("uniform", 0, np.pi)
|
1259 |
+
name : str, optional
|
1260 |
+
Name of this transform, used to identify it in the dictionary
|
1261 |
+
produced by ``self.instantiate``, by default None
|
1262 |
+
prob : float, optional
|
1263 |
+
Probability of applying this transform, by default 1.0
|
1264 |
+
"""
|
1265 |
+
|
1266 |
+
def __init__(
|
1267 |
+
self, scale: tuple = ("uniform", 0, np.pi), name: str = None, prob: float = 1
|
1268 |
+
):
|
1269 |
+
super().__init__(name=name, prob=prob)
|
1270 |
+
self.scale = scale
|
1271 |
+
|
1272 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal = None):
|
1273 |
+
scale = util.sample_from_dist(self.scale, state)
|
1274 |
+
corruption = state.normal(scale=scale, size=signal.phase.shape[1:])
|
1275 |
+
return {"corruption": corruption.astype("float32")}
|
1276 |
+
|
1277 |
+
def _transform(self, signal, corruption):
|
1278 |
+
return signal.shift_phase(shift=corruption)
|
1279 |
+
|
1280 |
+
|
1281 |
+
class FrequencyMask(SpectralTransform):
|
1282 |
+
"""Masks a band of frequencies at a center frequency
|
1283 |
+
from the audio.
|
1284 |
+
|
1285 |
+
Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`.
|
1286 |
+
|
1287 |
+
Parameters
|
1288 |
+
----------
|
1289 |
+
f_center : tuple, optional
|
1290 |
+
Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
|
1291 |
+
f_width : tuple, optional
|
1292 |
+
Width of zero'd out band, by default ("const", 0.1)
|
1293 |
+
name : str, optional
|
1294 |
+
Name of this transform, used to identify it in the dictionary
|
1295 |
+
produced by ``self.instantiate``, by default None
|
1296 |
+
prob : float, optional
|
1297 |
+
Probability of applying this transform, by default 1.0
|
1298 |
+
"""
|
1299 |
+
|
1300 |
+
def __init__(
|
1301 |
+
self,
|
1302 |
+
f_center: tuple = ("uniform", 0.0, 1.0),
|
1303 |
+
f_width: tuple = ("const", 0.1),
|
1304 |
+
name: str = None,
|
1305 |
+
prob: float = 1,
|
1306 |
+
):
|
1307 |
+
super().__init__(name=name, prob=prob)
|
1308 |
+
self.f_center = f_center
|
1309 |
+
self.f_width = f_width
|
1310 |
+
|
1311 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal):
|
1312 |
+
f_center = util.sample_from_dist(self.f_center, state)
|
1313 |
+
f_width = util.sample_from_dist(self.f_width, state)
|
1314 |
+
|
1315 |
+
fmin = max(f_center - (f_width / 2), 0.0)
|
1316 |
+
fmax = min(f_center + (f_width / 2), 1.0)
|
1317 |
+
|
1318 |
+
fmin_hz = (signal.sample_rate / 2) * fmin
|
1319 |
+
fmax_hz = (signal.sample_rate / 2) * fmax
|
1320 |
+
|
1321 |
+
return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz}
|
1322 |
+
|
1323 |
+
def _transform(self, signal, fmin_hz: float, fmax_hz: float):
|
1324 |
+
return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
|
1325 |
+
|
1326 |
+
|
1327 |
+
class TimeMask(SpectralTransform):
|
1328 |
+
"""Masks out contiguous time-steps from signal.
|
1329 |
+
|
1330 |
+
Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`.
|
1331 |
+
|
1332 |
+
Parameters
|
1333 |
+
----------
|
1334 |
+
t_center : tuple, optional
|
1335 |
+
Center time in terms of 0.0 and 1.0 (duration of signal),
|
1336 |
+
by default ("uniform", 0.0, 1.0)
|
1337 |
+
t_width : tuple, optional
|
1338 |
+
Width of dropped out portion, by default ("const", 0.025)
|
1339 |
+
name : str, optional
|
1340 |
+
Name of this transform, used to identify it in the dictionary
|
1341 |
+
produced by ``self.instantiate``, by default None
|
1342 |
+
prob : float, optional
|
1343 |
+
Probability of applying this transform, by default 1.0
|
1344 |
+
"""
|
1345 |
+
|
1346 |
+
def __init__(
|
1347 |
+
self,
|
1348 |
+
t_center: tuple = ("uniform", 0.0, 1.0),
|
1349 |
+
t_width: tuple = ("const", 0.025),
|
1350 |
+
name: str = None,
|
1351 |
+
prob: float = 1,
|
1352 |
+
):
|
1353 |
+
super().__init__(name=name, prob=prob)
|
1354 |
+
self.t_center = t_center
|
1355 |
+
self.t_width = t_width
|
1356 |
+
|
1357 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal):
|
1358 |
+
t_center = util.sample_from_dist(self.t_center, state)
|
1359 |
+
t_width = util.sample_from_dist(self.t_width, state)
|
1360 |
+
|
1361 |
+
tmin = max(t_center - (t_width / 2), 0.0)
|
1362 |
+
tmax = min(t_center + (t_width / 2), 1.0)
|
1363 |
+
|
1364 |
+
tmin_s = signal.signal_duration * tmin
|
1365 |
+
tmax_s = signal.signal_duration * tmax
|
1366 |
+
return {"tmin_s": tmin_s, "tmax_s": tmax_s}
|
1367 |
+
|
1368 |
+
def _transform(self, signal, tmin_s: float, tmax_s: float):
|
1369 |
+
return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s)
|
1370 |
+
|
1371 |
+
|
1372 |
+
class MaskLowMagnitudes(SpectralTransform):
|
1373 |
+
"""Masks low magnitude regions out of signal.
|
1374 |
+
|
1375 |
+
Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_low_magnitudes`.
|
1376 |
+
|
1377 |
+
Parameters
|
1378 |
+
----------
|
1379 |
+
db_cutoff : tuple, optional
|
1380 |
+
Decibel value for which things below it will be masked away,
|
1381 |
+
by default ("uniform", -10, 10)
|
1382 |
+
name : str, optional
|
1383 |
+
Name of this transform, used to identify it in the dictionary
|
1384 |
+
produced by ``self.instantiate``, by default None
|
1385 |
+
prob : float, optional
|
1386 |
+
Probability of applying this transform, by default 1.0
|
1387 |
+
"""
|
1388 |
+
|
1389 |
+
def __init__(
|
1390 |
+
self,
|
1391 |
+
db_cutoff: tuple = ("uniform", -10, 10),
|
1392 |
+
name: str = None,
|
1393 |
+
prob: float = 1,
|
1394 |
+
):
|
1395 |
+
super().__init__(name=name, prob=prob)
|
1396 |
+
self.db_cutoff = db_cutoff
|
1397 |
+
|
1398 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal = None):
|
1399 |
+
return {"db_cutoff": util.sample_from_dist(self.db_cutoff, state)}
|
1400 |
+
|
1401 |
+
def _transform(self, signal, db_cutoff: float):
|
1402 |
+
return signal.mask_low_magnitudes(db_cutoff)
|
1403 |
+
|
1404 |
+
|
1405 |
+
class Smoothing(BaseTransform):
|
1406 |
+
"""Convolves the signal with a smoothing window.
|
1407 |
+
|
1408 |
+
Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`.
|
1409 |
+
|
1410 |
+
Parameters
|
1411 |
+
----------
|
1412 |
+
window_type : tuple, optional
|
1413 |
+
Type of window to use, by default ("const", "average")
|
1414 |
+
window_length : tuple, optional
|
1415 |
+
Length of smoothing window, by
|
1416 |
+
default ("choice", [8, 16, 32, 64, 128, 256, 512])
|
1417 |
+
name : str, optional
|
1418 |
+
Name of this transform, used to identify it in the dictionary
|
1419 |
+
produced by ``self.instantiate``, by default None
|
1420 |
+
prob : float, optional
|
1421 |
+
Probability of applying this transform, by default 1.0
|
1422 |
+
"""
|
1423 |
+
|
1424 |
+
def __init__(
|
1425 |
+
self,
|
1426 |
+
window_type: tuple = ("const", "average"),
|
1427 |
+
window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]),
|
1428 |
+
name: str = None,
|
1429 |
+
prob: float = 1,
|
1430 |
+
):
|
1431 |
+
super().__init__(name=name, prob=prob)
|
1432 |
+
self.window_type = window_type
|
1433 |
+
self.window_length = window_length
|
1434 |
+
|
1435 |
+
def _instantiate(self, state: RandomState, signal: AudioSignal = None):
|
1436 |
+
window_type = util.sample_from_dist(self.window_type, state)
|
1437 |
+
window_length = util.sample_from_dist(self.window_length, state)
|
1438 |
+
window = signal.get_window(
|
1439 |
+
window_type=window_type, window_length=window_length, device="cpu"
|
1440 |
+
)
|
1441 |
+
return {"window": AudioSignal(window, signal.sample_rate)}
|
1442 |
+
|
1443 |
+
def _transform(self, signal, window):
|
1444 |
+
sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values
|
1445 |
+
sscale[sscale == 0.0] = 1.0
|
1446 |
+
|
1447 |
+
out = signal.convolve(window)
|
1448 |
+
|
1449 |
+
oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values
|
1450 |
+
oscale[oscale == 0.0] = 1.0
|
1451 |
+
|
1452 |
+
out = out * (sscale / oscale)
|
1453 |
+
return out
|
1454 |
+
|
1455 |
+
|
1456 |
+
class TimeNoise(TimeMask):
|
1457 |
+
"""Similar to :py:func:`audiotools.data.transforms.TimeMask`, but
|
1458 |
+
replaces with noise instead of zeros.
|
1459 |
+
|
1460 |
+
Parameters
|
1461 |
+
----------
|
1462 |
+
t_center : tuple, optional
|
1463 |
+
Center time in terms of 0.0 and 1.0 (duration of signal),
|
1464 |
+
by default ("uniform", 0.0, 1.0)
|
1465 |
+
t_width : tuple, optional
|
1466 |
+
Width of dropped out portion, by default ("const", 0.025)
|
1467 |
+
name : str, optional
|
1468 |
+
Name of this transform, used to identify it in the dictionary
|
1469 |
+
produced by ``self.instantiate``, by default None
|
1470 |
+
prob : float, optional
|
1471 |
+
Probability of applying this transform, by default 1.0
|
1472 |
+
"""
|
1473 |
+
|
1474 |
+
def __init__(
|
1475 |
+
self,
|
1476 |
+
t_center: tuple = ("uniform", 0.0, 1.0),
|
1477 |
+
t_width: tuple = ("const", 0.025),
|
1478 |
+
name: str = None,
|
1479 |
+
prob: float = 1,
|
1480 |
+
):
|
1481 |
+
super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob)
|
1482 |
+
|
1483 |
+
def _transform(self, signal, tmin_s: float, tmax_s: float):
|
1484 |
+
signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0)
|
1485 |
+
mag, phase = signal.magnitude, signal.phase
|
1486 |
+
|
1487 |
+
mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
|
1488 |
+
mask = (mag == 0.0) * (phase == 0.0)
|
1489 |
+
|
1490 |
+
mag[mask] = mag_r[mask]
|
1491 |
+
phase[mask] = phase_r[mask]
|
1492 |
+
|
1493 |
+
signal.magnitude = mag
|
1494 |
+
signal.phase = phase
|
1495 |
+
return signal
|
1496 |
+
|
1497 |
+
|
1498 |
+
class FrequencyNoise(FrequencyMask):
|
1499 |
+
"""Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but
|
1500 |
+
replaces with noise instead of zeros.
|
1501 |
+
|
1502 |
+
Parameters
|
1503 |
+
----------
|
1504 |
+
f_center : tuple, optional
|
1505 |
+
Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
|
1506 |
+
f_width : tuple, optional
|
1507 |
+
Width of zero'd out band, by default ("const", 0.1)
|
1508 |
+
name : str, optional
|
1509 |
+
Name of this transform, used to identify it in the dictionary
|
1510 |
+
produced by ``self.instantiate``, by default None
|
1511 |
+
prob : float, optional
|
1512 |
+
Probability of applying this transform, by default 1.0
|
1513 |
+
"""
|
1514 |
+
|
1515 |
+
def __init__(
|
1516 |
+
self,
|
1517 |
+
f_center: tuple = ("uniform", 0.0, 1.0),
|
1518 |
+
f_width: tuple = ("const", 0.1),
|
1519 |
+
name: str = None,
|
1520 |
+
prob: float = 1,
|
1521 |
+
):
|
1522 |
+
super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob)
|
1523 |
+
|
1524 |
+
def _transform(self, signal, fmin_hz: float, fmax_hz: float):
|
1525 |
+
signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
|
1526 |
+
mag, phase = signal.magnitude, signal.phase
|
1527 |
+
|
1528 |
+
mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
|
1529 |
+
mask = (mag == 0.0) * (phase == 0.0)
|
1530 |
+
|
1531 |
+
mag[mask] = mag_r[mask]
|
1532 |
+
phase[mask] = phase_r[mask]
|
1533 |
+
|
1534 |
+
signal.magnitude = mag
|
1535 |
+
signal.phase = phase
|
1536 |
+
return signal
|
1537 |
+
|
1538 |
+
|
1539 |
+
class SpectralDenoising(Equalizer):
|
1540 |
+
"""Applies denoising algorithm detailed in
|
1541 |
+
:py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`,
|
1542 |
+
using a randomly generated noise signal for denoising.
|
1543 |
+
|
1544 |
+
Parameters
|
1545 |
+
----------
|
1546 |
+
eq_amount : tuple, optional
|
1547 |
+
Amount of eq to apply to noise signal, by default ("const", 1.0)
|
1548 |
+
denoise_amount : tuple, optional
|
1549 |
+
Amount to denoise by, by default ("uniform", 0.8, 1.0)
|
1550 |
+
nz_volume : float, optional
|
1551 |
+
Volume of noise to denoise with, by default -40
|
1552 |
+
n_bands : int, optional
|
1553 |
+
Number of bands in equalizer, by default 6
|
1554 |
+
n_freq : int, optional
|
1555 |
+
Number of frequency bins to smooth by, by default 3
|
1556 |
+
n_time : int, optional
|
1557 |
+
Number of time bins to smooth by, by default 5
|
1558 |
+
name : str, optional
|
1559 |
+
Name of this transform, used to identify it in the dictionary
|
1560 |
+
produced by ``self.instantiate``, by default None
|
1561 |
+
prob : float, optional
|
1562 |
+
Probability of applying this transform, by default 1.0
|
1563 |
+
"""
|
1564 |
+
|
1565 |
+
def __init__(
|
1566 |
+
self,
|
1567 |
+
eq_amount: tuple = ("const", 1.0),
|
1568 |
+
denoise_amount: tuple = ("uniform", 0.8, 1.0),
|
1569 |
+
nz_volume: float = -40,
|
1570 |
+
n_bands: int = 6,
|
1571 |
+
n_freq: int = 3,
|
1572 |
+
n_time: int = 5,
|
1573 |
+
name: str = None,
|
1574 |
+
prob: float = 1,
|
1575 |
+
):
|
1576 |
+
super().__init__(eq_amount=eq_amount, n_bands=n_bands, name=name, prob=prob)
|
1577 |
+
|
1578 |
+
self.nz_volume = nz_volume
|
1579 |
+
self.denoise_amount = denoise_amount
|
1580 |
+
self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time)
|
1581 |
+
|
1582 |
+
def _transform(self, signal, nz, eq, denoise_amount):
|
1583 |
+
nz = nz.normalize(self.nz_volume).equalizer(eq)
|
1584 |
+
self.spectral_gate = self.spectral_gate.to(signal.device)
|
1585 |
+
signal = self.spectral_gate(signal, nz, denoise_amount)
|
1586 |
+
return signal
|
1587 |
+
|
1588 |
+
def _instantiate(self, state: RandomState):
|
1589 |
+
kwargs = super()._instantiate(state)
|
1590 |
+
kwargs["denoise_amount"] = util.sample_from_dist(self.denoise_amount, state)
|
1591 |
+
kwargs["nz"] = AudioSignal(state.randn(22050), 44100)
|
1592 |
+
return kwargs
|
audiotools/metrics/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Functions for comparing AudioSignal objects to one another.
|
3 |
+
""" # fmt: skip
|
4 |
+
from . import distance
|
5 |
+
from . import quality
|
6 |
+
from . import spectral
|
audiotools/metrics/distance.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
from .. import AudioSignal
|
5 |
+
|
6 |
+
|
7 |
+
class L1Loss(nn.L1Loss):
|
8 |
+
"""L1 Loss between AudioSignals. Defaults
|
9 |
+
to comparing ``audio_data``, but any
|
10 |
+
attribute of an AudioSignal can be used.
|
11 |
+
|
12 |
+
Parameters
|
13 |
+
----------
|
14 |
+
attribute : str, optional
|
15 |
+
Attribute of signal to compare, defaults to ``audio_data``.
|
16 |
+
weight : float, optional
|
17 |
+
Weight of this loss, defaults to 1.0.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
|
21 |
+
self.attribute = attribute
|
22 |
+
self.weight = weight
|
23 |
+
super().__init__(**kwargs)
|
24 |
+
|
25 |
+
def forward(self, x: AudioSignal, y: AudioSignal):
|
26 |
+
"""
|
27 |
+
Parameters
|
28 |
+
----------
|
29 |
+
x : AudioSignal
|
30 |
+
Estimate AudioSignal
|
31 |
+
y : AudioSignal
|
32 |
+
Reference AudioSignal
|
33 |
+
|
34 |
+
Returns
|
35 |
+
-------
|
36 |
+
torch.Tensor
|
37 |
+
L1 loss between AudioSignal attributes.
|
38 |
+
"""
|
39 |
+
if isinstance(x, AudioSignal):
|
40 |
+
x = getattr(x, self.attribute)
|
41 |
+
y = getattr(y, self.attribute)
|
42 |
+
return super().forward(x, y)
|
43 |
+
|
44 |
+
|
45 |
+
class SISDRLoss(nn.Module):
|
46 |
+
"""
|
47 |
+
Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
|
48 |
+
of estimated and reference audio signals or aligned features.
|
49 |
+
|
50 |
+
Parameters
|
51 |
+
----------
|
52 |
+
scaling : int, optional
|
53 |
+
Whether to use scale-invariant (True) or
|
54 |
+
signal-to-noise ratio (False), by default True
|
55 |
+
reduction : str, optional
|
56 |
+
How to reduce across the batch (either 'mean',
|
57 |
+
'sum', or none).], by default ' mean'
|
58 |
+
zero_mean : int, optional
|
59 |
+
Zero mean the references and estimates before
|
60 |
+
computing the loss, by default True
|
61 |
+
clip_min : int, optional
|
62 |
+
The minimum possible loss value. Helps network
|
63 |
+
to not focus on making already good examples better, by default None
|
64 |
+
weight : float, optional
|
65 |
+
Weight of this loss, defaults to 1.0.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
scaling: int = True,
|
71 |
+
reduction: str = "mean",
|
72 |
+
zero_mean: int = True,
|
73 |
+
clip_min: int = None,
|
74 |
+
weight: float = 1.0,
|
75 |
+
):
|
76 |
+
self.scaling = scaling
|
77 |
+
self.reduction = reduction
|
78 |
+
self.zero_mean = zero_mean
|
79 |
+
self.clip_min = clip_min
|
80 |
+
self.weight = weight
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
def forward(self, x: AudioSignal, y: AudioSignal):
|
84 |
+
eps = 1e-8
|
85 |
+
# nb, nc, nt
|
86 |
+
if isinstance(x, AudioSignal):
|
87 |
+
references = x.audio_data
|
88 |
+
estimates = y.audio_data
|
89 |
+
else:
|
90 |
+
references = x
|
91 |
+
estimates = y
|
92 |
+
|
93 |
+
nb = references.shape[0]
|
94 |
+
references = references.reshape(nb, 1, -1).permute(0, 2, 1)
|
95 |
+
estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
|
96 |
+
|
97 |
+
# samples now on axis 1
|
98 |
+
if self.zero_mean:
|
99 |
+
mean_reference = references.mean(dim=1, keepdim=True)
|
100 |
+
mean_estimate = estimates.mean(dim=1, keepdim=True)
|
101 |
+
else:
|
102 |
+
mean_reference = 0
|
103 |
+
mean_estimate = 0
|
104 |
+
|
105 |
+
_references = references - mean_reference
|
106 |
+
_estimates = estimates - mean_estimate
|
107 |
+
|
108 |
+
references_projection = (_references**2).sum(dim=-2) + eps
|
109 |
+
references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
|
110 |
+
|
111 |
+
scale = (
|
112 |
+
(references_on_estimates / references_projection).unsqueeze(1)
|
113 |
+
if self.scaling
|
114 |
+
else 1
|
115 |
+
)
|
116 |
+
|
117 |
+
e_true = scale * _references
|
118 |
+
e_res = _estimates - e_true
|
119 |
+
|
120 |
+
signal = (e_true**2).sum(dim=1)
|
121 |
+
noise = (e_res**2).sum(dim=1)
|
122 |
+
sdr = -10 * torch.log10(signal / noise + eps)
|
123 |
+
|
124 |
+
if self.clip_min is not None:
|
125 |
+
sdr = torch.clamp(sdr, min=self.clip_min)
|
126 |
+
|
127 |
+
if self.reduction == "mean":
|
128 |
+
sdr = sdr.mean()
|
129 |
+
elif self.reduction == "sum":
|
130 |
+
sdr = sdr.sum()
|
131 |
+
return sdr
|
audiotools/metrics/quality.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from .. import AudioSignal
|
7 |
+
|
8 |
+
|
9 |
+
def stoi(
|
10 |
+
estimates: AudioSignal,
|
11 |
+
references: AudioSignal,
|
12 |
+
extended: int = False,
|
13 |
+
):
|
14 |
+
"""Short term objective intelligibility
|
15 |
+
Computes the STOI (See [1][2]) of a denoised signal compared to a clean
|
16 |
+
signal, The output is expected to have a monotonic relation with the
|
17 |
+
subjective speech-intelligibility, where a higher score denotes better
|
18 |
+
speech intelligibility. Uses pystoi under the hood.
|
19 |
+
|
20 |
+
Parameters
|
21 |
+
----------
|
22 |
+
estimates : AudioSignal
|
23 |
+
Denoised speech
|
24 |
+
references : AudioSignal
|
25 |
+
Clean original speech
|
26 |
+
extended : int, optional
|
27 |
+
Boolean, whether to use the extended STOI described in [3], by default False
|
28 |
+
|
29 |
+
Returns
|
30 |
+
-------
|
31 |
+
Tensor[float]
|
32 |
+
Short time objective intelligibility measure between clean and
|
33 |
+
denoised speech
|
34 |
+
|
35 |
+
References
|
36 |
+
----------
|
37 |
+
1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
|
38 |
+
Objective Intelligibility Measure for Time-Frequency Weighted Noisy
|
39 |
+
Speech', ICASSP 2010, Texas, Dallas.
|
40 |
+
2. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
|
41 |
+
Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
|
42 |
+
IEEE Transactions on Audio, Speech, and Language Processing, 2011.
|
43 |
+
3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
|
44 |
+
Intelligibility of Speech Masked by Modulated Noise Maskers',
|
45 |
+
IEEE Transactions on Audio, Speech and Language Processing, 2016.
|
46 |
+
"""
|
47 |
+
import pystoi
|
48 |
+
|
49 |
+
estimates = estimates.clone().to_mono()
|
50 |
+
references = references.clone().to_mono()
|
51 |
+
|
52 |
+
stois = []
|
53 |
+
for i in range(estimates.batch_size):
|
54 |
+
_stoi = pystoi.stoi(
|
55 |
+
references.audio_data[i, 0].detach().cpu().numpy(),
|
56 |
+
estimates.audio_data[i, 0].detach().cpu().numpy(),
|
57 |
+
references.sample_rate,
|
58 |
+
extended=extended,
|
59 |
+
)
|
60 |
+
stois.append(_stoi)
|
61 |
+
return torch.from_numpy(np.array(stois))
|
62 |
+
|
63 |
+
|
64 |
+
def pesq(
|
65 |
+
estimates: AudioSignal,
|
66 |
+
references: AudioSignal,
|
67 |
+
mode: str = "wb",
|
68 |
+
target_sr: float = 16000,
|
69 |
+
):
|
70 |
+
"""_summary_
|
71 |
+
|
72 |
+
Parameters
|
73 |
+
----------
|
74 |
+
estimates : AudioSignal
|
75 |
+
Degraded AudioSignal
|
76 |
+
references : AudioSignal
|
77 |
+
Reference AudioSignal
|
78 |
+
mode : str, optional
|
79 |
+
'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
|
80 |
+
target_sr : float, optional
|
81 |
+
Target sample rate, by default 16000
|
82 |
+
|
83 |
+
Returns
|
84 |
+
-------
|
85 |
+
Tensor[float]
|
86 |
+
PESQ score: P.862.2 Prediction (MOS-LQO)
|
87 |
+
"""
|
88 |
+
from pesq import pesq as pesq_fn
|
89 |
+
|
90 |
+
estimates = estimates.clone().to_mono().resample(target_sr)
|
91 |
+
references = references.clone().to_mono().resample(target_sr)
|
92 |
+
|
93 |
+
pesqs = []
|
94 |
+
for i in range(estimates.batch_size):
|
95 |
+
_pesq = pesq_fn(
|
96 |
+
estimates.sample_rate,
|
97 |
+
references.audio_data[i, 0].detach().cpu().numpy(),
|
98 |
+
estimates.audio_data[i, 0].detach().cpu().numpy(),
|
99 |
+
mode,
|
100 |
+
)
|
101 |
+
pesqs.append(_pesq)
|
102 |
+
return torch.from_numpy(np.array(pesqs))
|
103 |
+
|
104 |
+
|
105 |
+
def visqol(
|
106 |
+
estimates: AudioSignal,
|
107 |
+
references: AudioSignal,
|
108 |
+
mode: str = "audio",
|
109 |
+
): # pragma: no cover
|
110 |
+
"""ViSQOL score.
|
111 |
+
|
112 |
+
Parameters
|
113 |
+
----------
|
114 |
+
estimates : AudioSignal
|
115 |
+
Degraded AudioSignal
|
116 |
+
references : AudioSignal
|
117 |
+
Reference AudioSignal
|
118 |
+
mode : str, optional
|
119 |
+
'audio' or 'speech', by default 'audio'
|
120 |
+
|
121 |
+
Returns
|
122 |
+
-------
|
123 |
+
Tensor[float]
|
124 |
+
ViSQOL score (MOS-LQO)
|
125 |
+
"""
|
126 |
+
from visqol import visqol_lib_py
|
127 |
+
from visqol.pb2 import visqol_config_pb2
|
128 |
+
from visqol.pb2 import similarity_result_pb2
|
129 |
+
|
130 |
+
config = visqol_config_pb2.VisqolConfig()
|
131 |
+
if mode == "audio":
|
132 |
+
target_sr = 48000
|
133 |
+
config.options.use_speech_scoring = False
|
134 |
+
svr_model_path = "libsvm_nu_svr_model.txt"
|
135 |
+
elif mode == "speech":
|
136 |
+
target_sr = 16000
|
137 |
+
config.options.use_speech_scoring = True
|
138 |
+
svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
|
139 |
+
else:
|
140 |
+
raise ValueError(f"Unrecognized mode: {mode}")
|
141 |
+
config.audio.sample_rate = target_sr
|
142 |
+
config.options.svr_model_path = os.path.join(
|
143 |
+
os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path
|
144 |
+
)
|
145 |
+
|
146 |
+
api = visqol_lib_py.VisqolApi()
|
147 |
+
api.Create(config)
|
148 |
+
|
149 |
+
estimates = estimates.clone().to_mono().resample(target_sr)
|
150 |
+
references = references.clone().to_mono().resample(target_sr)
|
151 |
+
|
152 |
+
visqols = []
|
153 |
+
for i in range(estimates.batch_size):
|
154 |
+
_visqol = api.Measure(
|
155 |
+
references.audio_data[i, 0].detach().cpu().numpy().astype(float),
|
156 |
+
estimates.audio_data[i, 0].detach().cpu().numpy().astype(float),
|
157 |
+
)
|
158 |
+
visqols.append(_visqol.moslqo)
|
159 |
+
return torch.from_numpy(np.array(visqols))
|
audiotools/metrics/spectral.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from .. import AudioSignal
|
8 |
+
from .. import STFTParams
|
9 |
+
|
10 |
+
|
11 |
+
class MultiScaleSTFTLoss(nn.Module):
|
12 |
+
"""Computes the multi-scale STFT loss from [1].
|
13 |
+
|
14 |
+
Parameters
|
15 |
+
----------
|
16 |
+
window_lengths : List[int], optional
|
17 |
+
Length of each window of each STFT, by default [2048, 512]
|
18 |
+
loss_fn : typing.Callable, optional
|
19 |
+
How to compare each loss, by default nn.L1Loss()
|
20 |
+
clamp_eps : float, optional
|
21 |
+
Clamp on the log magnitude, below, by default 1e-5
|
22 |
+
mag_weight : float, optional
|
23 |
+
Weight of raw magnitude portion of loss, by default 1.0
|
24 |
+
log_weight : float, optional
|
25 |
+
Weight of log magnitude portion of loss, by default 1.0
|
26 |
+
pow : float, optional
|
27 |
+
Power to raise magnitude to before taking log, by default 2.0
|
28 |
+
weight : float, optional
|
29 |
+
Weight of this loss, by default 1.0
|
30 |
+
match_stride : bool, optional
|
31 |
+
Whether to match the stride of convolutional layers, by default False
|
32 |
+
|
33 |
+
References
|
34 |
+
----------
|
35 |
+
|
36 |
+
1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
|
37 |
+
"DDSP: Differentiable Digital Signal Processing."
|
38 |
+
International Conference on Learning Representations. 2019.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
window_lengths: List[int] = [2048, 512],
|
44 |
+
loss_fn: typing.Callable = nn.L1Loss(),
|
45 |
+
clamp_eps: float = 1e-5,
|
46 |
+
mag_weight: float = 1.0,
|
47 |
+
log_weight: float = 1.0,
|
48 |
+
pow: float = 2.0,
|
49 |
+
weight: float = 1.0,
|
50 |
+
match_stride: bool = False,
|
51 |
+
window_type: str = None,
|
52 |
+
):
|
53 |
+
super().__init__()
|
54 |
+
self.stft_params = [
|
55 |
+
STFTParams(
|
56 |
+
window_length=w,
|
57 |
+
hop_length=w // 4,
|
58 |
+
match_stride=match_stride,
|
59 |
+
window_type=window_type,
|
60 |
+
)
|
61 |
+
for w in window_lengths
|
62 |
+
]
|
63 |
+
self.loss_fn = loss_fn
|
64 |
+
self.log_weight = log_weight
|
65 |
+
self.mag_weight = mag_weight
|
66 |
+
self.clamp_eps = clamp_eps
|
67 |
+
self.weight = weight
|
68 |
+
self.pow = pow
|
69 |
+
|
70 |
+
def forward(self, x: AudioSignal, y: AudioSignal):
|
71 |
+
"""Computes multi-scale STFT between an estimate and a reference
|
72 |
+
signal.
|
73 |
+
|
74 |
+
Parameters
|
75 |
+
----------
|
76 |
+
x : AudioSignal
|
77 |
+
Estimate signal
|
78 |
+
y : AudioSignal
|
79 |
+
Reference signal
|
80 |
+
|
81 |
+
Returns
|
82 |
+
-------
|
83 |
+
torch.Tensor
|
84 |
+
Multi-scale STFT loss.
|
85 |
+
"""
|
86 |
+
loss = 0.0
|
87 |
+
for s in self.stft_params:
|
88 |
+
x.stft(s.window_length, s.hop_length, s.window_type)
|
89 |
+
y.stft(s.window_length, s.hop_length, s.window_type)
|
90 |
+
loss += self.log_weight * self.loss_fn(
|
91 |
+
x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
|
92 |
+
y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
|
93 |
+
)
|
94 |
+
loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
|
95 |
+
return loss
|
96 |
+
|
97 |
+
|
98 |
+
class MelSpectrogramLoss(nn.Module):
|
99 |
+
"""Compute distance between mel spectrograms. Can be used
|
100 |
+
in a multi-scale way.
|
101 |
+
|
102 |
+
Parameters
|
103 |
+
----------
|
104 |
+
n_mels : List[int]
|
105 |
+
Number of mels per STFT, by default [150, 80],
|
106 |
+
window_lengths : List[int], optional
|
107 |
+
Length of each window of each STFT, by default [2048, 512]
|
108 |
+
loss_fn : typing.Callable, optional
|
109 |
+
How to compare each loss, by default nn.L1Loss()
|
110 |
+
clamp_eps : float, optional
|
111 |
+
Clamp on the log magnitude, below, by default 1e-5
|
112 |
+
mag_weight : float, optional
|
113 |
+
Weight of raw magnitude portion of loss, by default 1.0
|
114 |
+
log_weight : float, optional
|
115 |
+
Weight of log magnitude portion of loss, by default 1.0
|
116 |
+
pow : float, optional
|
117 |
+
Power to raise magnitude to before taking log, by default 2.0
|
118 |
+
weight : float, optional
|
119 |
+
Weight of this loss, by default 1.0
|
120 |
+
match_stride : bool, optional
|
121 |
+
Whether to match the stride of convolutional layers, by default False
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
n_mels: List[int] = [150, 80],
|
127 |
+
window_lengths: List[int] = [2048, 512],
|
128 |
+
loss_fn: typing.Callable = nn.L1Loss(),
|
129 |
+
clamp_eps: float = 1e-5,
|
130 |
+
mag_weight: float = 1.0,
|
131 |
+
log_weight: float = 1.0,
|
132 |
+
pow: float = 2.0,
|
133 |
+
weight: float = 1.0,
|
134 |
+
match_stride: bool = False,
|
135 |
+
mel_fmin: List[float] = [0.0, 0.0],
|
136 |
+
mel_fmax: List[float] = [None, None],
|
137 |
+
window_type: str = None,
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
self.stft_params = [
|
141 |
+
STFTParams(
|
142 |
+
window_length=w,
|
143 |
+
hop_length=w // 4,
|
144 |
+
match_stride=match_stride,
|
145 |
+
window_type=window_type,
|
146 |
+
)
|
147 |
+
for w in window_lengths
|
148 |
+
]
|
149 |
+
self.n_mels = n_mels
|
150 |
+
self.loss_fn = loss_fn
|
151 |
+
self.clamp_eps = clamp_eps
|
152 |
+
self.log_weight = log_weight
|
153 |
+
self.mag_weight = mag_weight
|
154 |
+
self.weight = weight
|
155 |
+
self.mel_fmin = mel_fmin
|
156 |
+
self.mel_fmax = mel_fmax
|
157 |
+
self.pow = pow
|
158 |
+
|
159 |
+
def forward(self, x: AudioSignal, y: AudioSignal):
|
160 |
+
"""Computes mel loss between an estimate and a reference
|
161 |
+
signal.
|
162 |
+
|
163 |
+
Parameters
|
164 |
+
----------
|
165 |
+
x : AudioSignal
|
166 |
+
Estimate signal
|
167 |
+
y : AudioSignal
|
168 |
+
Reference signal
|
169 |
+
|
170 |
+
Returns
|
171 |
+
-------
|
172 |
+
torch.Tensor
|
173 |
+
Mel loss.
|
174 |
+
"""
|
175 |
+
loss = 0.0
|
176 |
+
for n_mels, fmin, fmax, s in zip(
|
177 |
+
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
|
178 |
+
):
|
179 |
+
kwargs = {
|
180 |
+
"window_length": s.window_length,
|
181 |
+
"hop_length": s.hop_length,
|
182 |
+
"window_type": s.window_type,
|
183 |
+
}
|
184 |
+
x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
|
185 |
+
y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
|
186 |
+
|
187 |
+
loss += self.log_weight * self.loss_fn(
|
188 |
+
x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
|
189 |
+
y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
|
190 |
+
)
|
191 |
+
loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
|
192 |
+
return loss
|
193 |
+
|
194 |
+
|
195 |
+
class PhaseLoss(nn.Module):
|
196 |
+
"""Difference between phase spectrograms.
|
197 |
+
|
198 |
+
Parameters
|
199 |
+
----------
|
200 |
+
window_length : int, optional
|
201 |
+
Length of STFT window, by default 2048
|
202 |
+
hop_length : int, optional
|
203 |
+
Hop length of STFT window, by default 512
|
204 |
+
weight : float, optional
|
205 |
+
Weight of loss, by default 1.0
|
206 |
+
"""
|
207 |
+
|
208 |
+
def __init__(
|
209 |
+
self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
|
210 |
+
):
|
211 |
+
super().__init__()
|
212 |
+
|
213 |
+
self.weight = weight
|
214 |
+
self.stft_params = STFTParams(window_length, hop_length)
|
215 |
+
|
216 |
+
def forward(self, x: AudioSignal, y: AudioSignal):
|
217 |
+
"""Computes phase loss between an estimate and a reference
|
218 |
+
signal.
|
219 |
+
|
220 |
+
Parameters
|
221 |
+
----------
|
222 |
+
x : AudioSignal
|
223 |
+
Estimate signal
|
224 |
+
y : AudioSignal
|
225 |
+
Reference signal
|
226 |
+
|
227 |
+
Returns
|
228 |
+
-------
|
229 |
+
torch.Tensor
|
230 |
+
Phase loss.
|
231 |
+
"""
|
232 |
+
s = self.stft_params
|
233 |
+
x.stft(s.window_length, s.hop_length, s.window_type)
|
234 |
+
y.stft(s.window_length, s.hop_length, s.window_type)
|
235 |
+
|
236 |
+
# Take circular difference
|
237 |
+
diff = x.phase - y.phase
|
238 |
+
diff[diff < -np.pi] += 2 * np.pi
|
239 |
+
diff[diff > np.pi] -= -2 * np.pi
|
240 |
+
|
241 |
+
# Scale true magnitude to weights in [0, 1]
|
242 |
+
x_min, x_max = x.magnitude.min(), x.magnitude.max()
|
243 |
+
weights = (x.magnitude - x_min) / (x_max - x_min)
|
244 |
+
|
245 |
+
# Take weighted mean of all phase errors
|
246 |
+
loss = ((weights * diff) ** 2).mean()
|
247 |
+
return loss
|
audiotools/ml/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import decorators
|
2 |
+
from . import layers
|
3 |
+
from .accelerator import Accelerator
|
4 |
+
from .experiment import Experiment
|
5 |
+
from .layers import BaseModel
|
audiotools/ml/accelerator.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import typing
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.distributed as dist
|
6 |
+
from torch.nn.parallel import DataParallel
|
7 |
+
from torch.nn.parallel import DistributedDataParallel
|
8 |
+
|
9 |
+
from ..data.datasets import ResumableDistributedSampler as DistributedSampler
|
10 |
+
from ..data.datasets import ResumableSequentialSampler as SequentialSampler
|
11 |
+
|
12 |
+
|
13 |
+
class Accelerator: # pragma: no cover
|
14 |
+
"""This class is used to prepare models and dataloaders for
|
15 |
+
usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
|
16 |
+
prepare the respective objects. In the case of models, they are moved to
|
17 |
+
the appropriate GPU and SyncBatchNorm is applied to them. In the case of
|
18 |
+
dataloaders, a sampler is created and the dataloader is initialized with
|
19 |
+
that sampler.
|
20 |
+
|
21 |
+
If the world size is 1, prepare_model and prepare_dataloader are
|
22 |
+
no-ops. If the environment variable ``LOCAL_RANK`` is not set, then the
|
23 |
+
script was launched without ``torchrun``, and ``DataParallel``
|
24 |
+
will be used instead of ``DistributedDataParallel`` (not recommended), if
|
25 |
+
the world size (number of GPUs) is greater than 1.
|
26 |
+
|
27 |
+
Parameters
|
28 |
+
----------
|
29 |
+
amp : bool, optional
|
30 |
+
Whether or not to enable automatic mixed precision, by default False
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, amp: bool = False):
|
34 |
+
local_rank = os.getenv("LOCAL_RANK", None)
|
35 |
+
self.world_size = torch.cuda.device_count()
|
36 |
+
|
37 |
+
self.use_ddp = self.world_size > 1 and local_rank is not None
|
38 |
+
self.use_dp = self.world_size > 1 and local_rank is None
|
39 |
+
self.device = "cpu" if self.world_size == 0 else "cuda"
|
40 |
+
|
41 |
+
if self.use_ddp:
|
42 |
+
local_rank = int(local_rank)
|
43 |
+
dist.init_process_group(
|
44 |
+
"nccl",
|
45 |
+
init_method="env://",
|
46 |
+
world_size=self.world_size,
|
47 |
+
rank=local_rank,
|
48 |
+
)
|
49 |
+
|
50 |
+
self.local_rank = 0 if local_rank is None else local_rank
|
51 |
+
self.amp = amp
|
52 |
+
|
53 |
+
class DummyScaler:
|
54 |
+
def __init__(self):
|
55 |
+
pass
|
56 |
+
|
57 |
+
def step(self, optimizer):
|
58 |
+
optimizer.step()
|
59 |
+
|
60 |
+
def scale(self, loss):
|
61 |
+
return loss
|
62 |
+
|
63 |
+
def unscale_(self, optimizer):
|
64 |
+
return optimizer
|
65 |
+
|
66 |
+
def update(self):
|
67 |
+
pass
|
68 |
+
|
69 |
+
self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
|
70 |
+
self.device_ctx = (
|
71 |
+
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
72 |
+
)
|
73 |
+
|
74 |
+
def __enter__(self):
|
75 |
+
if self.device_ctx is not None:
|
76 |
+
self.device_ctx.__enter__()
|
77 |
+
return self
|
78 |
+
|
79 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
80 |
+
if self.device_ctx is not None:
|
81 |
+
self.device_ctx.__exit__(exc_type, exc_value, traceback)
|
82 |
+
|
83 |
+
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
84 |
+
"""Prepares model for DDP or DP. The model is moved to
|
85 |
+
the device of the correct rank.
|
86 |
+
|
87 |
+
Parameters
|
88 |
+
----------
|
89 |
+
model : torch.nn.Module
|
90 |
+
Model that is converted for DDP or DP.
|
91 |
+
|
92 |
+
Returns
|
93 |
+
-------
|
94 |
+
torch.nn.Module
|
95 |
+
Wrapped model, or original model if DDP and DP are turned off.
|
96 |
+
"""
|
97 |
+
model = model.to(self.device)
|
98 |
+
if self.use_ddp:
|
99 |
+
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
100 |
+
model = DistributedDataParallel(
|
101 |
+
model, device_ids=[self.local_rank], **kwargs
|
102 |
+
)
|
103 |
+
elif self.use_dp:
|
104 |
+
model = DataParallel(model, **kwargs)
|
105 |
+
return model
|
106 |
+
|
107 |
+
# Automatic mixed-precision utilities
|
108 |
+
def autocast(self, *args, **kwargs):
|
109 |
+
"""Context manager for autocasting. Arguments
|
110 |
+
go to ``torch.cuda.amp.autocast``.
|
111 |
+
"""
|
112 |
+
return torch.cuda.amp.autocast(self.amp, *args, **kwargs)
|
113 |
+
|
114 |
+
def backward(self, loss: torch.Tensor):
|
115 |
+
"""Backwards pass, after scaling the loss if ``amp`` is
|
116 |
+
enabled.
|
117 |
+
|
118 |
+
Parameters
|
119 |
+
----------
|
120 |
+
loss : torch.Tensor
|
121 |
+
Loss value.
|
122 |
+
"""
|
123 |
+
self.scaler.scale(loss).backward()
|
124 |
+
|
125 |
+
def step(self, optimizer: torch.optim.Optimizer):
|
126 |
+
"""Steps the optimizer, using a ``scaler`` if ``amp`` is
|
127 |
+
enabled.
|
128 |
+
|
129 |
+
Parameters
|
130 |
+
----------
|
131 |
+
optimizer : torch.optim.Optimizer
|
132 |
+
Optimizer to step forward.
|
133 |
+
"""
|
134 |
+
self.scaler.step(optimizer)
|
135 |
+
|
136 |
+
def update(self):
|
137 |
+
"""Updates the scale factor."""
|
138 |
+
self.scaler.update()
|
139 |
+
|
140 |
+
def prepare_dataloader(
|
141 |
+
self, dataset: typing.Iterable, start_idx: int = None, **kwargs
|
142 |
+
):
|
143 |
+
"""Wraps a dataset with a DataLoader, using the correct sampler if DDP is
|
144 |
+
enabled.
|
145 |
+
|
146 |
+
Parameters
|
147 |
+
----------
|
148 |
+
dataset : typing.Iterable
|
149 |
+
Dataset to build Dataloader around.
|
150 |
+
start_idx : int, optional
|
151 |
+
Start index of sampler, useful if resuming from some epoch,
|
152 |
+
by default None
|
153 |
+
|
154 |
+
Returns
|
155 |
+
-------
|
156 |
+
_type_
|
157 |
+
_description_
|
158 |
+
"""
|
159 |
+
|
160 |
+
if self.use_ddp:
|
161 |
+
sampler = DistributedSampler(
|
162 |
+
dataset,
|
163 |
+
start_idx,
|
164 |
+
num_replicas=self.world_size,
|
165 |
+
rank=self.local_rank,
|
166 |
+
)
|
167 |
+
if "num_workers" in kwargs:
|
168 |
+
kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
|
169 |
+
kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
|
170 |
+
else:
|
171 |
+
sampler = SequentialSampler(dataset, start_idx)
|
172 |
+
|
173 |
+
dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
|
174 |
+
return dataloader
|
175 |
+
|
176 |
+
@staticmethod
|
177 |
+
def unwrap(model):
|
178 |
+
"""Unwraps the model if it was wrapped in DDP or DP, otherwise
|
179 |
+
just returns the model. Use this to unwrap the model returned by
|
180 |
+
:py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
|
181 |
+
"""
|
182 |
+
if hasattr(model, "module"):
|
183 |
+
return model.module
|
184 |
+
return model
|
audiotools/ml/decorators.py
ADDED
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from collections import defaultdict
|
5 |
+
from functools import wraps
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.distributed as dist
|
9 |
+
from rich import box
|
10 |
+
from rich.console import Console
|
11 |
+
from rich.console import Group
|
12 |
+
from rich.live import Live
|
13 |
+
from rich.markdown import Markdown
|
14 |
+
from rich.padding import Padding
|
15 |
+
from rich.panel import Panel
|
16 |
+
from rich.progress import BarColumn
|
17 |
+
from rich.progress import Progress
|
18 |
+
from rich.progress import SpinnerColumn
|
19 |
+
from rich.progress import TimeElapsedColumn
|
20 |
+
from rich.progress import TimeRemainingColumn
|
21 |
+
from rich.rule import Rule
|
22 |
+
from rich.table import Table
|
23 |
+
from torch.utils.tensorboard import SummaryWriter
|
24 |
+
|
25 |
+
|
26 |
+
# This is here so that the history can be pickled.
|
27 |
+
def default_list():
|
28 |
+
return []
|
29 |
+
|
30 |
+
|
31 |
+
class Mean:
|
32 |
+
"""Keeps track of the running mean, along with the latest
|
33 |
+
value.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self):
|
37 |
+
self.reset()
|
38 |
+
|
39 |
+
def __call__(self):
|
40 |
+
mean = self.total / max(self.count, 1)
|
41 |
+
return mean
|
42 |
+
|
43 |
+
def reset(self):
|
44 |
+
self.count = 0
|
45 |
+
self.total = 0
|
46 |
+
|
47 |
+
def update(self, val):
|
48 |
+
if math.isfinite(val):
|
49 |
+
self.count += 1
|
50 |
+
self.total += val
|
51 |
+
|
52 |
+
|
53 |
+
def when(condition):
|
54 |
+
"""Runs a function only when the condition is met. The condition is
|
55 |
+
a function that is run.
|
56 |
+
|
57 |
+
Parameters
|
58 |
+
----------
|
59 |
+
condition : Callable
|
60 |
+
Function to run to check whether or not to run the decorated
|
61 |
+
function.
|
62 |
+
|
63 |
+
Example
|
64 |
+
-------
|
65 |
+
Checkpoint only runs every 100 iterations, and only if the
|
66 |
+
local rank is 0.
|
67 |
+
|
68 |
+
>>> i = 0
|
69 |
+
>>> rank = 0
|
70 |
+
>>>
|
71 |
+
>>> @when(lambda: i % 100 == 0 and rank == 0)
|
72 |
+
>>> def checkpoint():
|
73 |
+
>>> print("Saving to /runs/exp1")
|
74 |
+
>>>
|
75 |
+
>>> for i in range(1000):
|
76 |
+
>>> checkpoint()
|
77 |
+
|
78 |
+
"""
|
79 |
+
|
80 |
+
def decorator(fn):
|
81 |
+
@wraps(fn)
|
82 |
+
def decorated(*args, **kwargs):
|
83 |
+
if condition():
|
84 |
+
return fn(*args, **kwargs)
|
85 |
+
|
86 |
+
return decorated
|
87 |
+
|
88 |
+
return decorator
|
89 |
+
|
90 |
+
|
91 |
+
def timer(prefix: str = "time"):
|
92 |
+
"""Adds execution time to the output dictionary of the decorated
|
93 |
+
function. The function decorated by this must output a dictionary.
|
94 |
+
The key added will follow the form "[prefix]/[name_of_function]"
|
95 |
+
|
96 |
+
Parameters
|
97 |
+
----------
|
98 |
+
prefix : str, optional
|
99 |
+
The key added will follow the form "[prefix]/[name_of_function]",
|
100 |
+
by default "time".
|
101 |
+
"""
|
102 |
+
|
103 |
+
def decorator(fn):
|
104 |
+
@wraps(fn)
|
105 |
+
def decorated(*args, **kwargs):
|
106 |
+
s = time.perf_counter()
|
107 |
+
output = fn(*args, **kwargs)
|
108 |
+
assert isinstance(output, dict)
|
109 |
+
e = time.perf_counter()
|
110 |
+
output[f"{prefix}/{fn.__name__}"] = e - s
|
111 |
+
return output
|
112 |
+
|
113 |
+
return decorated
|
114 |
+
|
115 |
+
return decorator
|
116 |
+
|
117 |
+
|
118 |
+
class Tracker:
|
119 |
+
"""
|
120 |
+
A tracker class that helps to monitor the progress of training and logging the metrics.
|
121 |
+
|
122 |
+
Attributes
|
123 |
+
----------
|
124 |
+
metrics : dict
|
125 |
+
A dictionary containing the metrics for each label.
|
126 |
+
history : dict
|
127 |
+
A dictionary containing the history of metrics for each label.
|
128 |
+
writer : SummaryWriter
|
129 |
+
A SummaryWriter object for logging the metrics.
|
130 |
+
rank : int
|
131 |
+
The rank of the current process.
|
132 |
+
step : int
|
133 |
+
The current step of the training.
|
134 |
+
tasks : dict
|
135 |
+
A dictionary containing the progress bars and tables for each label.
|
136 |
+
pbar : Progress
|
137 |
+
A progress bar object for displaying the progress.
|
138 |
+
consoles : list
|
139 |
+
A list of console objects for logging.
|
140 |
+
live : Live
|
141 |
+
A Live object for updating the display live.
|
142 |
+
|
143 |
+
Methods
|
144 |
+
-------
|
145 |
+
print(msg: str)
|
146 |
+
Prints the given message to all consoles.
|
147 |
+
update(label: str, fn_name: str)
|
148 |
+
Updates the progress bar and table for the given label.
|
149 |
+
done(label: str, title: str)
|
150 |
+
Resets the progress bar and table for the given label and prints the final result.
|
151 |
+
track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
|
152 |
+
A decorator for tracking the progress and metrics of a function.
|
153 |
+
log(label: str, value_type: str = "value", history: bool = True)
|
154 |
+
A decorator for logging the metrics of a function.
|
155 |
+
is_best(label: str, key: str) -> bool
|
156 |
+
Checks if the latest value of the given key in the label is the best so far.
|
157 |
+
state_dict() -> dict
|
158 |
+
Returns a dictionary containing the state of the tracker.
|
159 |
+
load_state_dict(state_dict: dict) -> Tracker
|
160 |
+
Loads the state of the tracker from the given state dictionary.
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
writer: SummaryWriter = None,
|
166 |
+
log_file: str = None,
|
167 |
+
rank: int = 0,
|
168 |
+
console_width: int = 100,
|
169 |
+
step: int = 0,
|
170 |
+
):
|
171 |
+
"""
|
172 |
+
Initializes the Tracker object.
|
173 |
+
|
174 |
+
Parameters
|
175 |
+
----------
|
176 |
+
writer : SummaryWriter, optional
|
177 |
+
A SummaryWriter object for logging the metrics, by default None.
|
178 |
+
log_file : str, optional
|
179 |
+
The path to the log file, by default None.
|
180 |
+
rank : int, optional
|
181 |
+
The rank of the current process, by default 0.
|
182 |
+
console_width : int, optional
|
183 |
+
The width of the console, by default 100.
|
184 |
+
step : int, optional
|
185 |
+
The current step of the training, by default 0.
|
186 |
+
"""
|
187 |
+
self.metrics = {}
|
188 |
+
self.history = {}
|
189 |
+
self.writer = writer
|
190 |
+
self.rank = rank
|
191 |
+
self.step = step
|
192 |
+
|
193 |
+
# Create progress bars etc.
|
194 |
+
self.tasks = {}
|
195 |
+
self.pbar = Progress(
|
196 |
+
SpinnerColumn(),
|
197 |
+
"[progress.description]{task.description}",
|
198 |
+
"{task.completed}/{task.total}",
|
199 |
+
BarColumn(),
|
200 |
+
TimeElapsedColumn(),
|
201 |
+
"/",
|
202 |
+
TimeRemainingColumn(),
|
203 |
+
)
|
204 |
+
self.consoles = [Console(width=console_width)]
|
205 |
+
self.live = Live(console=self.consoles[0], refresh_per_second=10)
|
206 |
+
if log_file is not None:
|
207 |
+
self.consoles.append(Console(width=console_width, file=open(log_file, "a")))
|
208 |
+
|
209 |
+
def print(self, msg):
|
210 |
+
"""
|
211 |
+
Prints the given message to all consoles.
|
212 |
+
|
213 |
+
Parameters
|
214 |
+
----------
|
215 |
+
msg : str
|
216 |
+
The message to be printed.
|
217 |
+
"""
|
218 |
+
if self.rank == 0:
|
219 |
+
for c in self.consoles:
|
220 |
+
c.log(msg)
|
221 |
+
|
222 |
+
def update(self, label, fn_name):
|
223 |
+
"""
|
224 |
+
Updates the progress bar and table for the given label.
|
225 |
+
|
226 |
+
Parameters
|
227 |
+
----------
|
228 |
+
label : str
|
229 |
+
The label of the progress bar and table to be updated.
|
230 |
+
fn_name : str
|
231 |
+
The name of the function associated with the label.
|
232 |
+
"""
|
233 |
+
if self.rank == 0:
|
234 |
+
self.pbar.advance(self.tasks[label]["pbar"])
|
235 |
+
|
236 |
+
# Create table
|
237 |
+
table = Table(title=label, expand=True, box=box.MINIMAL)
|
238 |
+
table.add_column("key", style="cyan")
|
239 |
+
table.add_column("value", style="bright_blue")
|
240 |
+
table.add_column("mean", style="bright_green")
|
241 |
+
|
242 |
+
keys = self.metrics[label]["value"].keys()
|
243 |
+
for k in keys:
|
244 |
+
value = self.metrics[label]["value"][k]
|
245 |
+
mean = self.metrics[label]["mean"][k]()
|
246 |
+
table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
|
247 |
+
|
248 |
+
self.tasks[label]["table"] = table
|
249 |
+
tables = [t["table"] for t in self.tasks.values()]
|
250 |
+
group = Group(*tables, self.pbar)
|
251 |
+
self.live.update(
|
252 |
+
Group(
|
253 |
+
Padding("", (0, 0)),
|
254 |
+
Rule(f"[italic]{fn_name}()", style="white"),
|
255 |
+
Padding("", (0, 0)),
|
256 |
+
Panel.fit(
|
257 |
+
group, padding=(0, 5), title="[b]Progress", border_style="blue"
|
258 |
+
),
|
259 |
+
)
|
260 |
+
)
|
261 |
+
|
262 |
+
def done(self, label: str, title: str):
|
263 |
+
"""
|
264 |
+
Resets the progress bar and table for the given label and prints the final result.
|
265 |
+
|
266 |
+
Parameters
|
267 |
+
----------
|
268 |
+
label : str
|
269 |
+
The label of the progress bar and table to be reset.
|
270 |
+
title : str
|
271 |
+
The title to be displayed when printing the final result.
|
272 |
+
"""
|
273 |
+
for label in self.metrics:
|
274 |
+
for v in self.metrics[label]["mean"].values():
|
275 |
+
v.reset()
|
276 |
+
|
277 |
+
if self.rank == 0:
|
278 |
+
self.pbar.reset(self.tasks[label]["pbar"])
|
279 |
+
tables = [t["table"] for t in self.tasks.values()]
|
280 |
+
group = Group(Markdown(f"# {title}"), *tables, self.pbar)
|
281 |
+
self.print(group)
|
282 |
+
|
283 |
+
def track(
|
284 |
+
self,
|
285 |
+
label: str,
|
286 |
+
length: int,
|
287 |
+
completed: int = 0,
|
288 |
+
op: dist.ReduceOp = dist.ReduceOp.AVG,
|
289 |
+
ddp_active: bool = "LOCAL_RANK" in os.environ,
|
290 |
+
):
|
291 |
+
"""
|
292 |
+
A decorator for tracking the progress and metrics of a function.
|
293 |
+
|
294 |
+
Parameters
|
295 |
+
----------
|
296 |
+
label : str
|
297 |
+
The label to be associated with the progress and metrics.
|
298 |
+
length : int
|
299 |
+
The total number of iterations to be completed.
|
300 |
+
completed : int, optional
|
301 |
+
The number of iterations already completed, by default 0.
|
302 |
+
op : dist.ReduceOp, optional
|
303 |
+
The reduce operation to be used, by default dist.ReduceOp.AVG.
|
304 |
+
ddp_active : bool, optional
|
305 |
+
Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
|
306 |
+
"""
|
307 |
+
self.tasks[label] = {
|
308 |
+
"pbar": self.pbar.add_task(
|
309 |
+
f"[white]Iteration ({label})", total=length, completed=completed
|
310 |
+
),
|
311 |
+
"table": Table(),
|
312 |
+
}
|
313 |
+
self.metrics[label] = {
|
314 |
+
"value": defaultdict(),
|
315 |
+
"mean": defaultdict(lambda: Mean()),
|
316 |
+
}
|
317 |
+
|
318 |
+
def decorator(fn):
|
319 |
+
@wraps(fn)
|
320 |
+
def decorated(*args, **kwargs):
|
321 |
+
output = fn(*args, **kwargs)
|
322 |
+
if not isinstance(output, dict):
|
323 |
+
self.update(label, fn.__name__)
|
324 |
+
return output
|
325 |
+
# Collect across all DDP processes
|
326 |
+
scalar_keys = []
|
327 |
+
for k, v in output.items():
|
328 |
+
if isinstance(v, (int, float)):
|
329 |
+
v = torch.tensor([v])
|
330 |
+
if not torch.is_tensor(v):
|
331 |
+
continue
|
332 |
+
if ddp_active and v.is_cuda: # pragma: no cover
|
333 |
+
dist.all_reduce(v, op=op)
|
334 |
+
output[k] = v.detach()
|
335 |
+
if torch.numel(v) == 1:
|
336 |
+
scalar_keys.append(k)
|
337 |
+
output[k] = v.item()
|
338 |
+
|
339 |
+
# Save the outputs to tracker
|
340 |
+
for k, v in output.items():
|
341 |
+
if k not in scalar_keys:
|
342 |
+
continue
|
343 |
+
self.metrics[label]["value"][k] = v
|
344 |
+
# Update the running mean
|
345 |
+
self.metrics[label]["mean"][k].update(v)
|
346 |
+
|
347 |
+
self.update(label, fn.__name__)
|
348 |
+
return output
|
349 |
+
|
350 |
+
return decorated
|
351 |
+
|
352 |
+
return decorator
|
353 |
+
|
354 |
+
def log(self, label: str, value_type: str = "value", history: bool = True):
|
355 |
+
"""
|
356 |
+
A decorator for logging the metrics of a function.
|
357 |
+
|
358 |
+
Parameters
|
359 |
+
----------
|
360 |
+
label : str
|
361 |
+
The label to be associated with the logging.
|
362 |
+
value_type : str, optional
|
363 |
+
The type of value to be logged, by default "value".
|
364 |
+
history : bool, optional
|
365 |
+
Whether to save the history of the metrics, by default True.
|
366 |
+
"""
|
367 |
+
assert value_type in ["mean", "value"]
|
368 |
+
if history:
|
369 |
+
if label not in self.history:
|
370 |
+
self.history[label] = defaultdict(default_list)
|
371 |
+
|
372 |
+
def decorator(fn):
|
373 |
+
@wraps(fn)
|
374 |
+
def decorated(*args, **kwargs):
|
375 |
+
output = fn(*args, **kwargs)
|
376 |
+
if self.rank == 0:
|
377 |
+
nonlocal value_type, label
|
378 |
+
metrics = self.metrics[label][value_type]
|
379 |
+
for k, v in metrics.items():
|
380 |
+
v = v() if isinstance(v, Mean) else v
|
381 |
+
if self.writer is not None:
|
382 |
+
self.writer.add_scalar(f"{k}/{label}", v, self.step)
|
383 |
+
if label in self.history:
|
384 |
+
self.history[label][k].append(v)
|
385 |
+
|
386 |
+
if label in self.history:
|
387 |
+
self.history[label]["step"].append(self.step)
|
388 |
+
|
389 |
+
return output
|
390 |
+
|
391 |
+
return decorated
|
392 |
+
|
393 |
+
return decorator
|
394 |
+
|
395 |
+
def is_best(self, label, key):
|
396 |
+
"""
|
397 |
+
Checks if the latest value of the given key in the label is the best so far.
|
398 |
+
|
399 |
+
Parameters
|
400 |
+
----------
|
401 |
+
label : str
|
402 |
+
The label of the metrics to be checked.
|
403 |
+
key : str
|
404 |
+
The key of the metric to be checked.
|
405 |
+
|
406 |
+
Returns
|
407 |
+
-------
|
408 |
+
bool
|
409 |
+
True if the latest value is the best so far, otherwise False.
|
410 |
+
"""
|
411 |
+
return self.history[label][key][-1] == min(self.history[label][key])
|
412 |
+
|
413 |
+
def state_dict(self):
|
414 |
+
"""
|
415 |
+
Returns a dictionary containing the state of the tracker.
|
416 |
+
|
417 |
+
Returns
|
418 |
+
-------
|
419 |
+
dict
|
420 |
+
A dictionary containing the history and step of the tracker.
|
421 |
+
"""
|
422 |
+
return {"history": self.history, "step": self.step}
|
423 |
+
|
424 |
+
def load_state_dict(self, state_dict):
|
425 |
+
"""
|
426 |
+
Loads the state of the tracker from the given state dictionary.
|
427 |
+
|
428 |
+
Parameters
|
429 |
+
----------
|
430 |
+
state_dict : dict
|
431 |
+
A dictionary containing the history and step of the tracker.
|
432 |
+
|
433 |
+
Returns
|
434 |
+
-------
|
435 |
+
Tracker
|
436 |
+
The tracker object with the loaded state.
|
437 |
+
"""
|
438 |
+
self.history = state_dict["history"]
|
439 |
+
self.step = state_dict["step"]
|
440 |
+
return self
|
audiotools/ml/experiment.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Useful class for Experiment tracking, and ensuring code is
|
3 |
+
saved alongside files.
|
4 |
+
""" # fmt: skip
|
5 |
+
import datetime
|
6 |
+
import os
|
7 |
+
import shlex
|
8 |
+
import shutil
|
9 |
+
import subprocess
|
10 |
+
import typing
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import randomname
|
14 |
+
|
15 |
+
|
16 |
+
class Experiment:
|
17 |
+
"""This class contains utilities for managing experiments.
|
18 |
+
It is a context manager, that when you enter it, changes
|
19 |
+
your directory to a specified experiment folder (which
|
20 |
+
optionally can have an automatically generated experiment
|
21 |
+
name, or a specified one), and changes the CUDA device used
|
22 |
+
to the specified device (or devices).
|
23 |
+
|
24 |
+
Parameters
|
25 |
+
----------
|
26 |
+
exp_directory : str
|
27 |
+
Folder where all experiments are saved, by default "runs/".
|
28 |
+
exp_name : str, optional
|
29 |
+
Name of the experiment, by default uses the current time, date, and
|
30 |
+
hostname to save.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
exp_directory: str = "runs/",
|
36 |
+
exp_name: str = None,
|
37 |
+
):
|
38 |
+
if exp_name is None:
|
39 |
+
exp_name = self.generate_exp_name()
|
40 |
+
exp_dir = Path(exp_directory) / exp_name
|
41 |
+
exp_dir.mkdir(parents=True, exist_ok=True)
|
42 |
+
|
43 |
+
self.exp_dir = exp_dir
|
44 |
+
self.exp_name = exp_name
|
45 |
+
self.git_tracked_files = (
|
46 |
+
subprocess.check_output(
|
47 |
+
shlex.split("git ls-tree --full-tree --name-only -r HEAD")
|
48 |
+
)
|
49 |
+
.decode("utf-8")
|
50 |
+
.splitlines()
|
51 |
+
)
|
52 |
+
self.parent_directory = Path(".").absolute()
|
53 |
+
|
54 |
+
def __enter__(self):
|
55 |
+
self.prev_dir = os.getcwd()
|
56 |
+
os.chdir(self.exp_dir)
|
57 |
+
return self
|
58 |
+
|
59 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
60 |
+
os.chdir(self.prev_dir)
|
61 |
+
|
62 |
+
@staticmethod
|
63 |
+
def generate_exp_name():
|
64 |
+
"""Generates a random experiment name based on the date
|
65 |
+
and a randomly generated adjective-noun tuple.
|
66 |
+
|
67 |
+
Returns
|
68 |
+
-------
|
69 |
+
str
|
70 |
+
Randomly generated experiment name.
|
71 |
+
"""
|
72 |
+
date = datetime.datetime.now().strftime("%y%m%d")
|
73 |
+
name = f"{date}-{randomname.get_name()}"
|
74 |
+
return name
|
75 |
+
|
76 |
+
def snapshot(self, filter_fn: typing.Callable = lambda f: True):
|
77 |
+
"""Captures a full snapshot of all the files tracked by git at the time
|
78 |
+
the experiment is run. It also captures the diff against the committed
|
79 |
+
code as a separate file.
|
80 |
+
|
81 |
+
Parameters
|
82 |
+
----------
|
83 |
+
filter_fn : typing.Callable, optional
|
84 |
+
Function that can be used to exclude some files
|
85 |
+
from the snapshot, by default accepts all files
|
86 |
+
"""
|
87 |
+
for f in self.git_tracked_files:
|
88 |
+
if filter_fn(f):
|
89 |
+
Path(f).parent.mkdir(parents=True, exist_ok=True)
|
90 |
+
shutil.copyfile(self.parent_directory / f, f)
|
audiotools/ml/layers/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import BaseModel
|
2 |
+
from .spectral_gate import SpectralGate
|
audiotools/ml/layers/base.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import shutil
|
3 |
+
import tempfile
|
4 |
+
import typing
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
|
11 |
+
class BaseModel(nn.Module):
|
12 |
+
"""This is a class that adds useful save/load functionality to a
|
13 |
+
``torch.nn.Module`` object. ``BaseModel`` objects can be saved
|
14 |
+
as ``torch.package`` easily, making them super easy to port between
|
15 |
+
machines without requiring a ton of dependencies. Files can also be
|
16 |
+
saved as just weights, in the standard way.
|
17 |
+
|
18 |
+
>>> class Model(ml.BaseModel):
|
19 |
+
>>> def __init__(self, arg1: float = 1.0):
|
20 |
+
>>> super().__init__()
|
21 |
+
>>> self.arg1 = arg1
|
22 |
+
>>> self.linear = nn.Linear(1, 1)
|
23 |
+
>>>
|
24 |
+
>>> def forward(self, x):
|
25 |
+
>>> return self.linear(x)
|
26 |
+
>>>
|
27 |
+
>>> model1 = Model()
|
28 |
+
>>>
|
29 |
+
>>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
|
30 |
+
>>> model1.save(
|
31 |
+
>>> f.name,
|
32 |
+
>>> )
|
33 |
+
>>> model2 = Model.load(f.name)
|
34 |
+
>>> out2 = seed_and_run(model2, x)
|
35 |
+
>>> assert torch.allclose(out1, out2)
|
36 |
+
>>>
|
37 |
+
>>> model1.save(f.name, package=True)
|
38 |
+
>>> model2 = Model.load(f.name)
|
39 |
+
>>> model2.save(f.name, package=False)
|
40 |
+
>>> model3 = Model.load(f.name)
|
41 |
+
>>> out3 = seed_and_run(model3, x)
|
42 |
+
>>>
|
43 |
+
>>> with tempfile.TemporaryDirectory() as d:
|
44 |
+
>>> model1.save_to_folder(d, {"data": 1.0})
|
45 |
+
>>> Model.load_from_folder(d)
|
46 |
+
|
47 |
+
"""
|
48 |
+
|
49 |
+
EXTERN = [
|
50 |
+
"audiotools.**",
|
51 |
+
"tqdm",
|
52 |
+
"__main__",
|
53 |
+
"numpy.**",
|
54 |
+
"julius.**",
|
55 |
+
"torchaudio.**",
|
56 |
+
"scipy.**",
|
57 |
+
"einops",
|
58 |
+
]
|
59 |
+
"""Names of libraries that are external to the torch.package saving mechanism.
|
60 |
+
Source code from these libraries will not be packaged into the model. This can
|
61 |
+
be edited by the user of this class by editing ``model.EXTERN``."""
|
62 |
+
INTERN = []
|
63 |
+
"""Names of libraries that are internal to the torch.package saving mechanism.
|
64 |
+
Source code from these libraries will be saved alongside the model."""
|
65 |
+
|
66 |
+
def save(
|
67 |
+
self,
|
68 |
+
path: str,
|
69 |
+
metadata: dict = None,
|
70 |
+
package: bool = True,
|
71 |
+
intern: list = [],
|
72 |
+
extern: list = [],
|
73 |
+
mock: list = [],
|
74 |
+
):
|
75 |
+
"""Saves the model, either as a torch package, or just as
|
76 |
+
weights, alongside some specified metadata.
|
77 |
+
|
78 |
+
Parameters
|
79 |
+
----------
|
80 |
+
path : str
|
81 |
+
Path to save model to.
|
82 |
+
metadata : dict, optional
|
83 |
+
Any metadata to save alongside the model,
|
84 |
+
by default None
|
85 |
+
package : bool, optional
|
86 |
+
Whether to use ``torch.package`` to save the model in
|
87 |
+
a format that is portable, by default True
|
88 |
+
intern : list, optional
|
89 |
+
List of additional libraries that are internal
|
90 |
+
to the model, used with torch.package, by default []
|
91 |
+
extern : list, optional
|
92 |
+
List of additional libraries that are external to
|
93 |
+
the model, used with torch.package, by default []
|
94 |
+
mock : list, optional
|
95 |
+
List of libraries to mock, used with torch.package,
|
96 |
+
by default []
|
97 |
+
|
98 |
+
Returns
|
99 |
+
-------
|
100 |
+
str
|
101 |
+
Path to saved model.
|
102 |
+
"""
|
103 |
+
sig = inspect.signature(self.__class__)
|
104 |
+
args = {}
|
105 |
+
|
106 |
+
for key, val in sig.parameters.items():
|
107 |
+
arg_val = val.default
|
108 |
+
if arg_val is not inspect.Parameter.empty:
|
109 |
+
args[key] = arg_val
|
110 |
+
|
111 |
+
# Look up attibutes in self, and if any of them are in args,
|
112 |
+
# overwrite them in args.
|
113 |
+
for attribute in dir(self):
|
114 |
+
if attribute in args:
|
115 |
+
args[attribute] = getattr(self, attribute)
|
116 |
+
|
117 |
+
metadata = {} if metadata is None else metadata
|
118 |
+
metadata["kwargs"] = args
|
119 |
+
if not hasattr(self, "metadata"):
|
120 |
+
self.metadata = {}
|
121 |
+
self.metadata.update(metadata)
|
122 |
+
|
123 |
+
if not package:
|
124 |
+
state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
|
125 |
+
torch.save(state_dict, path)
|
126 |
+
else:
|
127 |
+
self._save_package(path, intern=intern, extern=extern, mock=mock)
|
128 |
+
|
129 |
+
return path
|
130 |
+
|
131 |
+
@property
|
132 |
+
def device(self):
|
133 |
+
"""Gets the device the model is on by looking at the device of
|
134 |
+
the first parameter. May not be valid if model is split across
|
135 |
+
multiple devices.
|
136 |
+
"""
|
137 |
+
return list(self.parameters())[0].device
|
138 |
+
|
139 |
+
@classmethod
|
140 |
+
def load(
|
141 |
+
cls,
|
142 |
+
location: str,
|
143 |
+
*args,
|
144 |
+
package_name: str = None,
|
145 |
+
strict: bool = False,
|
146 |
+
**kwargs,
|
147 |
+
):
|
148 |
+
"""Load model from a path. Tries first to load as a package, and if
|
149 |
+
that fails, tries to load as weights. The arguments to the class are
|
150 |
+
specified inside the model weights file.
|
151 |
+
|
152 |
+
Parameters
|
153 |
+
----------
|
154 |
+
location : str
|
155 |
+
Path to file.
|
156 |
+
package_name : str, optional
|
157 |
+
Name of package, by default ``cls.__name__``.
|
158 |
+
strict : bool, optional
|
159 |
+
Ignore unmatched keys, by default False
|
160 |
+
kwargs : dict
|
161 |
+
Additional keyword arguments to the model instantiation, if
|
162 |
+
not loading from package.
|
163 |
+
|
164 |
+
Returns
|
165 |
+
-------
|
166 |
+
BaseModel
|
167 |
+
A model that inherits from BaseModel.
|
168 |
+
"""
|
169 |
+
try:
|
170 |
+
model = cls._load_package(location, package_name=package_name)
|
171 |
+
except:
|
172 |
+
model_dict = torch.load(location, "cpu")
|
173 |
+
metadata = model_dict["metadata"]
|
174 |
+
metadata["kwargs"].update(kwargs)
|
175 |
+
|
176 |
+
sig = inspect.signature(cls)
|
177 |
+
class_keys = list(sig.parameters.keys())
|
178 |
+
for k in list(metadata["kwargs"].keys()):
|
179 |
+
if k not in class_keys:
|
180 |
+
metadata["kwargs"].pop(k)
|
181 |
+
|
182 |
+
model = cls(*args, **metadata["kwargs"])
|
183 |
+
model.load_state_dict(model_dict["state_dict"], strict=strict)
|
184 |
+
model.metadata = metadata
|
185 |
+
|
186 |
+
return model
|
187 |
+
|
188 |
+
def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
|
189 |
+
package_name = type(self).__name__
|
190 |
+
resource_name = f"{type(self).__name__}.pth"
|
191 |
+
|
192 |
+
# Below is for loading and re-saving a package.
|
193 |
+
if hasattr(self, "importer"):
|
194 |
+
kwargs["importer"] = (self.importer, torch.package.sys_importer)
|
195 |
+
del self.importer
|
196 |
+
|
197 |
+
# Why do we use a tempfile, you ask?
|
198 |
+
# It's so we can load a packaged model and then re-save
|
199 |
+
# it to the same location. torch.package throws an
|
200 |
+
# error if it's loading and writing to the same
|
201 |
+
# file (this is undocumented).
|
202 |
+
with tempfile.NamedTemporaryFile(suffix=".pth") as f:
|
203 |
+
with torch.package.PackageExporter(f.name, **kwargs) as exp:
|
204 |
+
exp.intern(self.INTERN + intern)
|
205 |
+
exp.mock(mock)
|
206 |
+
exp.extern(self.EXTERN + extern)
|
207 |
+
exp.save_pickle(package_name, resource_name, self)
|
208 |
+
|
209 |
+
if hasattr(self, "metadata"):
|
210 |
+
exp.save_pickle(
|
211 |
+
package_name, f"{package_name}.metadata", self.metadata
|
212 |
+
)
|
213 |
+
|
214 |
+
shutil.copyfile(f.name, path)
|
215 |
+
|
216 |
+
# Must reset the importer back to `self` if it existed
|
217 |
+
# so that you can save the model again!
|
218 |
+
if "importer" in kwargs:
|
219 |
+
self.importer = kwargs["importer"][0]
|
220 |
+
return path
|
221 |
+
|
222 |
+
@classmethod
|
223 |
+
def _load_package(cls, path, package_name=None):
|
224 |
+
package_name = cls.__name__ if package_name is None else package_name
|
225 |
+
resource_name = f"{package_name}.pth"
|
226 |
+
|
227 |
+
imp = torch.package.PackageImporter(path)
|
228 |
+
model = imp.load_pickle(package_name, resource_name, "cpu")
|
229 |
+
try:
|
230 |
+
model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
|
231 |
+
except: # pragma: no cover
|
232 |
+
pass
|
233 |
+
model.importer = imp
|
234 |
+
|
235 |
+
return model
|
236 |
+
|
237 |
+
def save_to_folder(
|
238 |
+
self,
|
239 |
+
folder: typing.Union[str, Path],
|
240 |
+
extra_data: dict = None,
|
241 |
+
package: bool = True,
|
242 |
+
):
|
243 |
+
"""Dumps a model into a folder, as both a package
|
244 |
+
and as weights, as well as anything specified in
|
245 |
+
``extra_data``. ``extra_data`` is a dictionary of other
|
246 |
+
pickleable files, with the keys being the paths
|
247 |
+
to save them in. The model is saved under a subfolder
|
248 |
+
specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
|
249 |
+
if the model name was ``Generator``).
|
250 |
+
|
251 |
+
>>> with tempfile.TemporaryDirectory() as d:
|
252 |
+
>>> extra_data = {
|
253 |
+
>>> "optimizer.pth": optimizer.state_dict()
|
254 |
+
>>> }
|
255 |
+
>>> model.save_to_folder(d, extra_data)
|
256 |
+
>>> Model.load_from_folder(d)
|
257 |
+
|
258 |
+
Parameters
|
259 |
+
----------
|
260 |
+
folder : typing.Union[str, Path]
|
261 |
+
_description_
|
262 |
+
extra_data : dict, optional
|
263 |
+
_description_, by default None
|
264 |
+
|
265 |
+
Returns
|
266 |
+
-------
|
267 |
+
str
|
268 |
+
Path to folder
|
269 |
+
"""
|
270 |
+
extra_data = {} if extra_data is None else extra_data
|
271 |
+
model_name = type(self).__name__.lower()
|
272 |
+
target_base = Path(f"{folder}/{model_name}/")
|
273 |
+
target_base.mkdir(exist_ok=True, parents=True)
|
274 |
+
|
275 |
+
if package:
|
276 |
+
package_path = target_base / f"package.pth"
|
277 |
+
self.save(package_path)
|
278 |
+
|
279 |
+
weights_path = target_base / f"weights.pth"
|
280 |
+
self.save(weights_path, package=False)
|
281 |
+
|
282 |
+
for path, obj in extra_data.items():
|
283 |
+
torch.save(obj, target_base / path)
|
284 |
+
|
285 |
+
return target_base
|
286 |
+
|
287 |
+
@classmethod
|
288 |
+
def load_from_folder(
|
289 |
+
cls,
|
290 |
+
folder: typing.Union[str, Path],
|
291 |
+
package: bool = True,
|
292 |
+
strict: bool = False,
|
293 |
+
**kwargs,
|
294 |
+
):
|
295 |
+
"""Loads the model from a folder generated by
|
296 |
+
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
|
297 |
+
Like that function, this one looks for a subfolder that has
|
298 |
+
the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
|
299 |
+
model name was ``Generator``).
|
300 |
+
|
301 |
+
Parameters
|
302 |
+
----------
|
303 |
+
folder : typing.Union[str, Path]
|
304 |
+
_description_
|
305 |
+
package : bool, optional
|
306 |
+
Whether to use ``torch.package`` to load the model,
|
307 |
+
loading the model from ``package.pth``.
|
308 |
+
strict : bool, optional
|
309 |
+
Ignore unmatched keys, by default False
|
310 |
+
|
311 |
+
Returns
|
312 |
+
-------
|
313 |
+
tuple
|
314 |
+
tuple of model and extra data as saved by
|
315 |
+
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
|
316 |
+
"""
|
317 |
+
folder = Path(folder) / cls.__name__.lower()
|
318 |
+
model_pth = "package.pth" if package else "weights.pth"
|
319 |
+
model_pth = folder / model_pth
|
320 |
+
|
321 |
+
model = cls.load(model_pth, strict=strict)
|
322 |
+
extra_data = {}
|
323 |
+
excluded = ["package.pth", "weights.pth"]
|
324 |
+
files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded]
|
325 |
+
for f in files:
|
326 |
+
extra_data[f.name] = torch.load(f, **kwargs)
|
327 |
+
|
328 |
+
return model, extra_data
|
audiotools/ml/layers/spectral_gate.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
from ...core import AudioSignal
|
6 |
+
from ...core import STFTParams
|
7 |
+
from ...core import util
|
8 |
+
|
9 |
+
|
10 |
+
class SpectralGate(nn.Module):
|
11 |
+
"""Spectral gating algorithm for noise reduction,
|
12 |
+
as in Audacity/Ocenaudio. The steps are as follows:
|
13 |
+
|
14 |
+
1. An FFT is calculated over the noise audio clip
|
15 |
+
2. Statistics are calculated over FFT of the the noise
|
16 |
+
(in frequency)
|
17 |
+
3. A threshold is calculated based upon the statistics
|
18 |
+
of the noise (and the desired sensitivity of the algorithm)
|
19 |
+
4. An FFT is calculated over the signal
|
20 |
+
5. A mask is determined by comparing the signal FFT to the
|
21 |
+
threshold
|
22 |
+
6. The mask is smoothed with a filter over frequency and time
|
23 |
+
7. The mask is appled to the FFT of the signal, and is inverted
|
24 |
+
|
25 |
+
Implementation inspired by Tim Sainburg's noisereduce:
|
26 |
+
|
27 |
+
https://timsainburg.com/noise-reduction-python.html
|
28 |
+
|
29 |
+
Parameters
|
30 |
+
----------
|
31 |
+
n_freq : int, optional
|
32 |
+
Number of frequency bins to smooth by, by default 3
|
33 |
+
n_time : int, optional
|
34 |
+
Number of time bins to smooth by, by default 5
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, n_freq: int = 3, n_time: int = 5):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
smoothing_filter = torch.outer(
|
41 |
+
torch.cat(
|
42 |
+
[
|
43 |
+
torch.linspace(0, 1, n_freq + 2)[:-1],
|
44 |
+
torch.linspace(1, 0, n_freq + 2),
|
45 |
+
]
|
46 |
+
)[..., 1:-1],
|
47 |
+
torch.cat(
|
48 |
+
[
|
49 |
+
torch.linspace(0, 1, n_time + 2)[:-1],
|
50 |
+
torch.linspace(1, 0, n_time + 2),
|
51 |
+
]
|
52 |
+
)[..., 1:-1],
|
53 |
+
)
|
54 |
+
smoothing_filter = smoothing_filter / smoothing_filter.sum()
|
55 |
+
smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
|
56 |
+
self.register_buffer("smoothing_filter", smoothing_filter)
|
57 |
+
|
58 |
+
def forward(
|
59 |
+
self,
|
60 |
+
audio_signal: AudioSignal,
|
61 |
+
nz_signal: AudioSignal,
|
62 |
+
denoise_amount: float = 1.0,
|
63 |
+
n_std: float = 3.0,
|
64 |
+
win_length: int = 2048,
|
65 |
+
hop_length: int = 512,
|
66 |
+
):
|
67 |
+
"""Perform noise reduction.
|
68 |
+
|
69 |
+
Parameters
|
70 |
+
----------
|
71 |
+
audio_signal : AudioSignal
|
72 |
+
Audio signal that noise will be removed from.
|
73 |
+
nz_signal : AudioSignal, optional
|
74 |
+
Noise signal to compute noise statistics from.
|
75 |
+
denoise_amount : float, optional
|
76 |
+
Amount to denoise by, by default 1.0
|
77 |
+
n_std : float, optional
|
78 |
+
Number of standard deviations above which to consider
|
79 |
+
noise, by default 3.0
|
80 |
+
win_length : int, optional
|
81 |
+
Length of window for STFT, by default 2048
|
82 |
+
hop_length : int, optional
|
83 |
+
Hop length for STFT, by default 512
|
84 |
+
|
85 |
+
Returns
|
86 |
+
-------
|
87 |
+
AudioSignal
|
88 |
+
Denoised audio signal.
|
89 |
+
"""
|
90 |
+
stft_params = STFTParams(win_length, hop_length, "sqrt_hann")
|
91 |
+
|
92 |
+
audio_signal = audio_signal.clone()
|
93 |
+
audio_signal.stft_data = None
|
94 |
+
audio_signal.stft_params = stft_params
|
95 |
+
|
96 |
+
nz_signal = nz_signal.clone()
|
97 |
+
nz_signal.stft_params = stft_params
|
98 |
+
|
99 |
+
nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
|
100 |
+
nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
|
101 |
+
nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)
|
102 |
+
|
103 |
+
nz_thresh = nz_freq_mean + nz_freq_std * n_std
|
104 |
+
|
105 |
+
stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
|
106 |
+
nb, nac, nf, nt = stft_db.shape
|
107 |
+
db_thresh = nz_thresh.expand(nb, nac, -1, nt)
|
108 |
+
|
109 |
+
stft_mask = (stft_db < db_thresh).float()
|
110 |
+
shape = stft_mask.shape
|
111 |
+
|
112 |
+
stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
|
113 |
+
pad_tuple = (
|
114 |
+
self.smoothing_filter.shape[-2] // 2,
|
115 |
+
self.smoothing_filter.shape[-1] // 2,
|
116 |
+
)
|
117 |
+
stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple)
|
118 |
+
stft_mask = stft_mask.reshape(*shape)
|
119 |
+
stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to(
|
120 |
+
audio_signal.device
|
121 |
+
)
|
122 |
+
stft_mask = 1 - stft_mask
|
123 |
+
|
124 |
+
audio_signal.stft_data *= stft_mask
|
125 |
+
audio_signal.istft()
|
126 |
+
|
127 |
+
return audio_signal
|
audiotools/post.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tempfile
|
2 |
+
import typing
|
3 |
+
import zipfile
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import markdown2 as md
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import torch
|
9 |
+
from IPython.display import HTML
|
10 |
+
|
11 |
+
|
12 |
+
def audio_table(
|
13 |
+
audio_dict: dict,
|
14 |
+
first_column: str = None,
|
15 |
+
format_fn: typing.Callable = None,
|
16 |
+
**kwargs,
|
17 |
+
): # pragma: no cover
|
18 |
+
"""Embeds an audio table into HTML, or as the output cell
|
19 |
+
in a notebook.
|
20 |
+
|
21 |
+
Parameters
|
22 |
+
----------
|
23 |
+
audio_dict : dict
|
24 |
+
Dictionary of data to embed.
|
25 |
+
first_column : str, optional
|
26 |
+
The label for the first column of the table, by default None
|
27 |
+
format_fn : typing.Callable, optional
|
28 |
+
How to format the data, by default None
|
29 |
+
|
30 |
+
Returns
|
31 |
+
-------
|
32 |
+
str
|
33 |
+
Table as a string
|
34 |
+
|
35 |
+
Examples
|
36 |
+
--------
|
37 |
+
|
38 |
+
>>> audio_dict = {}
|
39 |
+
>>> for i in range(signal_batch.batch_size):
|
40 |
+
>>> audio_dict[i] = {
|
41 |
+
>>> "input": signal_batch[i],
|
42 |
+
>>> "output": output_batch[i]
|
43 |
+
>>> }
|
44 |
+
>>> audiotools.post.audio_zip(audio_dict)
|
45 |
+
|
46 |
+
"""
|
47 |
+
from audiotools import AudioSignal
|
48 |
+
|
49 |
+
output = []
|
50 |
+
columns = None
|
51 |
+
|
52 |
+
def _default_format_fn(label, x, **kwargs):
|
53 |
+
if torch.is_tensor(x):
|
54 |
+
x = x.tolist()
|
55 |
+
|
56 |
+
if x is None:
|
57 |
+
return "."
|
58 |
+
elif isinstance(x, AudioSignal):
|
59 |
+
return x.embed(display=False, return_html=True, **kwargs)
|
60 |
+
else:
|
61 |
+
return str(x)
|
62 |
+
|
63 |
+
if format_fn is None:
|
64 |
+
format_fn = _default_format_fn
|
65 |
+
|
66 |
+
if first_column is None:
|
67 |
+
first_column = "."
|
68 |
+
|
69 |
+
for k, v in audio_dict.items():
|
70 |
+
if not isinstance(v, dict):
|
71 |
+
v = {"Audio": v}
|
72 |
+
|
73 |
+
v_keys = list(v.keys())
|
74 |
+
if columns is None:
|
75 |
+
columns = [first_column] + v_keys
|
76 |
+
output.append(" | ".join(columns))
|
77 |
+
|
78 |
+
layout = "|---" + len(v_keys) * "|:-:"
|
79 |
+
output.append(layout)
|
80 |
+
|
81 |
+
formatted_audio = []
|
82 |
+
for col in columns[1:]:
|
83 |
+
formatted_audio.append(format_fn(col, v[col], **kwargs))
|
84 |
+
|
85 |
+
row = f"| {k} | "
|
86 |
+
row += " | ".join(formatted_audio)
|
87 |
+
output.append(row)
|
88 |
+
|
89 |
+
output = "\n" + "\n".join(output)
|
90 |
+
return output
|
91 |
+
|
92 |
+
|
93 |
+
def in_notebook(): # pragma: no cover
|
94 |
+
"""Determines if code is running in a notebook.
|
95 |
+
|
96 |
+
Returns
|
97 |
+
-------
|
98 |
+
bool
|
99 |
+
Whether or not this is running in a notebook.
|
100 |
+
"""
|
101 |
+
try:
|
102 |
+
from IPython import get_ipython
|
103 |
+
|
104 |
+
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
|
105 |
+
return False
|
106 |
+
except ImportError:
|
107 |
+
return False
|
108 |
+
except AttributeError:
|
109 |
+
return False
|
110 |
+
return True
|
111 |
+
|
112 |
+
|
113 |
+
def disp(obj, **kwargs): # pragma: no cover
|
114 |
+
"""Displays an object, depending on if its in a notebook
|
115 |
+
or not.
|
116 |
+
|
117 |
+
Parameters
|
118 |
+
----------
|
119 |
+
obj : typing.Any
|
120 |
+
Any object to display.
|
121 |
+
|
122 |
+
"""
|
123 |
+
from audiotools import AudioSignal
|
124 |
+
|
125 |
+
IN_NOTEBOOK = in_notebook()
|
126 |
+
|
127 |
+
if isinstance(obj, AudioSignal):
|
128 |
+
audio_elem = obj.embed(display=False, return_html=True)
|
129 |
+
if IN_NOTEBOOK:
|
130 |
+
return HTML(audio_elem)
|
131 |
+
else:
|
132 |
+
print(audio_elem)
|
133 |
+
if isinstance(obj, dict):
|
134 |
+
table = audio_table(obj, **kwargs)
|
135 |
+
if IN_NOTEBOOK:
|
136 |
+
return HTML(md.markdown(table, extras=["tables"]))
|
137 |
+
else:
|
138 |
+
print(table)
|
139 |
+
if isinstance(obj, plt.Figure):
|
140 |
+
plt.show()
|
audiotools/preference.py
ADDED
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##############################################################
|
2 |
+
### Tools for creating preference tests (MUSHRA, ABX, etc) ###
|
3 |
+
##############################################################
|
4 |
+
import copy
|
5 |
+
import csv
|
6 |
+
import random
|
7 |
+
import sys
|
8 |
+
import traceback
|
9 |
+
from collections import defaultdict
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import List
|
12 |
+
|
13 |
+
import gradio as gr
|
14 |
+
|
15 |
+
from audiotools.core.util import find_audio
|
16 |
+
|
17 |
+
################################################################
|
18 |
+
### Logic for audio player, and adding audio / play buttons. ###
|
19 |
+
################################################################
|
20 |
+
|
21 |
+
WAVESURFER = """<div id="waveform"></div><div id="wave-timeline"></div>"""
|
22 |
+
|
23 |
+
CUSTOM_CSS = """
|
24 |
+
.gradio-container {
|
25 |
+
max-width: 840px !important;
|
26 |
+
}
|
27 |
+
region.wavesurfer-region:before {
|
28 |
+
content: attr(data-region-label);
|
29 |
+
}
|
30 |
+
|
31 |
+
block {
|
32 |
+
min-width: 0 !important;
|
33 |
+
}
|
34 |
+
|
35 |
+
#wave-timeline {
|
36 |
+
background-color: rgba(0, 0, 0, 0.8);
|
37 |
+
}
|
38 |
+
|
39 |
+
.head.svelte-1cl284s {
|
40 |
+
display: none;
|
41 |
+
}
|
42 |
+
"""
|
43 |
+
|
44 |
+
load_wavesurfer_js = """
|
45 |
+
function load_wavesurfer() {
|
46 |
+
function load_script(url) {
|
47 |
+
const script = document.createElement('script');
|
48 |
+
script.src = url;
|
49 |
+
document.body.appendChild(script);
|
50 |
+
|
51 |
+
return new Promise((res, rej) => {
|
52 |
+
script.onload = function() {
|
53 |
+
res();
|
54 |
+
}
|
55 |
+
script.onerror = function () {
|
56 |
+
rej();
|
57 |
+
}
|
58 |
+
});
|
59 |
+
}
|
60 |
+
|
61 |
+
function create_wavesurfer() {
|
62 |
+
var options = {
|
63 |
+
container: '#waveform',
|
64 |
+
waveColor: '#F2F2F2', // Set a darker wave color
|
65 |
+
progressColor: 'white', // Set a slightly lighter progress color
|
66 |
+
loaderColor: 'white', // Set a slightly lighter loader color
|
67 |
+
cursorColor: 'black', // Set a slightly lighter cursor color
|
68 |
+
backgroundColor: '#00AAFF', // Set a black background color
|
69 |
+
barWidth: 4,
|
70 |
+
barRadius: 3,
|
71 |
+
barHeight: 1, // the height of the wave
|
72 |
+
plugins: [
|
73 |
+
WaveSurfer.regions.create({
|
74 |
+
regionsMinLength: 0.0,
|
75 |
+
dragSelection: {
|
76 |
+
slop: 5
|
77 |
+
},
|
78 |
+
color: 'hsla(200, 50%, 70%, 0.4)',
|
79 |
+
}),
|
80 |
+
WaveSurfer.timeline.create({
|
81 |
+
container: "#wave-timeline",
|
82 |
+
primaryLabelInterval: 5.0,
|
83 |
+
secondaryLabelInterval: 1.0,
|
84 |
+
primaryFontColor: '#F2F2F2',
|
85 |
+
secondaryFontColor: '#F2F2F2',
|
86 |
+
}),
|
87 |
+
]
|
88 |
+
};
|
89 |
+
wavesurfer = WaveSurfer.create(options);
|
90 |
+
wavesurfer.on('region-created', region => {
|
91 |
+
wavesurfer.regions.clear();
|
92 |
+
});
|
93 |
+
wavesurfer.on('finish', function () {
|
94 |
+
var loop = document.getElementById("loop-button").textContent.includes("ON");
|
95 |
+
if (loop) {
|
96 |
+
wavesurfer.play();
|
97 |
+
}
|
98 |
+
else {
|
99 |
+
var button_elements = document.getElementsByClassName('playpause')
|
100 |
+
var buttons = Array.from(button_elements);
|
101 |
+
|
102 |
+
for (let j = 0; j < buttons.length; j++) {
|
103 |
+
buttons[j].classList.remove("primary");
|
104 |
+
buttons[j].classList.add("secondary");
|
105 |
+
buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
|
106 |
+
}
|
107 |
+
}
|
108 |
+
});
|
109 |
+
|
110 |
+
wavesurfer.on('region-out', function () {
|
111 |
+
var loop = document.getElementById("loop-button").textContent.includes("ON");
|
112 |
+
if (!loop) {
|
113 |
+
var button_elements = document.getElementsByClassName('playpause')
|
114 |
+
var buttons = Array.from(button_elements);
|
115 |
+
|
116 |
+
for (let j = 0; j < buttons.length; j++) {
|
117 |
+
buttons[j].classList.remove("primary");
|
118 |
+
buttons[j].classList.add("secondary");
|
119 |
+
buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
|
120 |
+
}
|
121 |
+
wavesurfer.pause();
|
122 |
+
}
|
123 |
+
});
|
124 |
+
|
125 |
+
console.log("Created WaveSurfer object.")
|
126 |
+
}
|
127 |
+
|
128 |
+
load_script('https://unpkg.com/wavesurfer.js@6.6.4')
|
129 |
+
.then(() => {
|
130 |
+
load_script("https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.timeline.min.js")
|
131 |
+
.then(() => {
|
132 |
+
load_script('https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.regions.min.js')
|
133 |
+
.then(() => {
|
134 |
+
console.log("Loaded regions");
|
135 |
+
create_wavesurfer();
|
136 |
+
document.getElementById("start-survey").click();
|
137 |
+
})
|
138 |
+
})
|
139 |
+
});
|
140 |
+
}
|
141 |
+
"""
|
142 |
+
|
143 |
+
play = lambda i: """
|
144 |
+
function play() {
|
145 |
+
var audio_elements = document.getElementsByTagName('audio');
|
146 |
+
var button_elements = document.getElementsByClassName('playpause')
|
147 |
+
|
148 |
+
var audio_array = Array.from(audio_elements);
|
149 |
+
var buttons = Array.from(button_elements);
|
150 |
+
|
151 |
+
var src_link = audio_array[{i}].getAttribute("src");
|
152 |
+
console.log(src_link);
|
153 |
+
|
154 |
+
var loop = document.getElementById("loop-button").textContent.includes("ON");
|
155 |
+
var playing = buttons[{i}].textContent.includes("Stop");
|
156 |
+
|
157 |
+
for (let j = 0; j < buttons.length; j++) {
|
158 |
+
if (j != {i} || playing) {
|
159 |
+
buttons[j].classList.remove("primary");
|
160 |
+
buttons[j].classList.add("secondary");
|
161 |
+
buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
|
162 |
+
}
|
163 |
+
else {
|
164 |
+
buttons[j].classList.remove("secondary");
|
165 |
+
buttons[j].classList.add("primary");
|
166 |
+
buttons[j].textContent = buttons[j].textContent.replace("Play", "Stop")
|
167 |
+
}
|
168 |
+
}
|
169 |
+
|
170 |
+
if (playing) {
|
171 |
+
wavesurfer.pause();
|
172 |
+
wavesurfer.seekTo(0.0);
|
173 |
+
}
|
174 |
+
else {
|
175 |
+
wavesurfer.load(src_link);
|
176 |
+
wavesurfer.on('ready', function () {
|
177 |
+
var region = Object.values(wavesurfer.regions.list)[0];
|
178 |
+
|
179 |
+
if (region != null) {
|
180 |
+
region.loop = loop;
|
181 |
+
region.play();
|
182 |
+
} else {
|
183 |
+
wavesurfer.play();
|
184 |
+
}
|
185 |
+
});
|
186 |
+
}
|
187 |
+
}
|
188 |
+
""".replace(
|
189 |
+
"{i}", str(i)
|
190 |
+
)
|
191 |
+
|
192 |
+
clear_regions = """
|
193 |
+
function clear_regions() {
|
194 |
+
wavesurfer.clearRegions();
|
195 |
+
}
|
196 |
+
"""
|
197 |
+
|
198 |
+
reset_player = """
|
199 |
+
function reset_player() {
|
200 |
+
wavesurfer.clearRegions();
|
201 |
+
wavesurfer.pause();
|
202 |
+
wavesurfer.seekTo(0.0);
|
203 |
+
|
204 |
+
var button_elements = document.getElementsByClassName('playpause')
|
205 |
+
var buttons = Array.from(button_elements);
|
206 |
+
|
207 |
+
for (let j = 0; j < buttons.length; j++) {
|
208 |
+
buttons[j].classList.remove("primary");
|
209 |
+
buttons[j].classList.add("secondary");
|
210 |
+
buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
|
211 |
+
}
|
212 |
+
}
|
213 |
+
"""
|
214 |
+
|
215 |
+
loop_region = """
|
216 |
+
function loop_region() {
|
217 |
+
var element = document.getElementById("loop-button");
|
218 |
+
var loop = element.textContent.includes("OFF");
|
219 |
+
console.log(loop);
|
220 |
+
|
221 |
+
try {
|
222 |
+
var region = Object.values(wavesurfer.regions.list)[0];
|
223 |
+
region.loop = loop;
|
224 |
+
} catch {}
|
225 |
+
|
226 |
+
if (loop) {
|
227 |
+
element.classList.remove("secondary");
|
228 |
+
element.classList.add("primary");
|
229 |
+
element.textContent = "Looping ON";
|
230 |
+
} else {
|
231 |
+
element.classList.remove("primary");
|
232 |
+
element.classList.add("secondary");
|
233 |
+
element.textContent = "Looping OFF";
|
234 |
+
}
|
235 |
+
}
|
236 |
+
"""
|
237 |
+
|
238 |
+
|
239 |
+
class Player:
|
240 |
+
def __init__(self, app):
|
241 |
+
self.app = app
|
242 |
+
|
243 |
+
self.app.load(_js=load_wavesurfer_js)
|
244 |
+
self.app.css = CUSTOM_CSS
|
245 |
+
|
246 |
+
self.wavs = []
|
247 |
+
self.position = 0
|
248 |
+
|
249 |
+
def create(self):
|
250 |
+
gr.HTML(WAVESURFER)
|
251 |
+
gr.Markdown(
|
252 |
+
"Click and drag on the waveform above to select a region for playback. "
|
253 |
+
"Once created, the region can be moved around and resized. "
|
254 |
+
"Clear the regions using the button below. Hit play on one of the buttons below to start!"
|
255 |
+
)
|
256 |
+
|
257 |
+
with gr.Row():
|
258 |
+
clear = gr.Button("Clear region")
|
259 |
+
loop = gr.Button("Looping OFF", elem_id="loop-button")
|
260 |
+
|
261 |
+
loop.click(None, _js=loop_region)
|
262 |
+
clear.click(None, _js=clear_regions)
|
263 |
+
|
264 |
+
gr.HTML("<hr>")
|
265 |
+
|
266 |
+
def add(self, name: str = "Play"):
|
267 |
+
i = self.position
|
268 |
+
self.wavs.append(
|
269 |
+
{
|
270 |
+
"audio": gr.Audio(visible=False),
|
271 |
+
"button": gr.Button(name, elem_classes=["playpause"]),
|
272 |
+
"position": i,
|
273 |
+
}
|
274 |
+
)
|
275 |
+
self.wavs[-1]["button"].click(None, _js=play(i))
|
276 |
+
self.position += 1
|
277 |
+
return self.wavs[-1]
|
278 |
+
|
279 |
+
def to_list(self):
|
280 |
+
return [x["audio"] for x in self.wavs]
|
281 |
+
|
282 |
+
|
283 |
+
############################################################
|
284 |
+
### Keeping track of users, and CSS for the progress bar ###
|
285 |
+
############################################################
|
286 |
+
|
287 |
+
load_tracker = lambda name: """
|
288 |
+
function load_name() {
|
289 |
+
function setCookie(name, value, exp_days) {
|
290 |
+
var d = new Date();
|
291 |
+
d.setTime(d.getTime() + (exp_days*24*60*60*1000));
|
292 |
+
var expires = "expires=" + d.toGMTString();
|
293 |
+
document.cookie = name + "=" + value + ";" + expires + ";path=/";
|
294 |
+
}
|
295 |
+
|
296 |
+
function getCookie(name) {
|
297 |
+
var cname = name + "=";
|
298 |
+
var decodedCookie = decodeURIComponent(document.cookie);
|
299 |
+
var ca = decodedCookie.split(';');
|
300 |
+
for(var i = 0; i < ca.length; i++){
|
301 |
+
var c = ca[i];
|
302 |
+
while(c.charAt(0) == ' '){
|
303 |
+
c = c.substring(1);
|
304 |
+
}
|
305 |
+
if(c.indexOf(cname) == 0){
|
306 |
+
return c.substring(cname.length, c.length);
|
307 |
+
}
|
308 |
+
}
|
309 |
+
return "";
|
310 |
+
}
|
311 |
+
|
312 |
+
name = getCookie("{name}");
|
313 |
+
if (name == "") {
|
314 |
+
name = Math.random().toString(36).slice(2);
|
315 |
+
console.log(name);
|
316 |
+
setCookie("name", name, 30);
|
317 |
+
}
|
318 |
+
name = getCookie("{name}");
|
319 |
+
return name;
|
320 |
+
}
|
321 |
+
""".replace(
|
322 |
+
"{name}", name
|
323 |
+
)
|
324 |
+
|
325 |
+
# Progress bar
|
326 |
+
|
327 |
+
progress_template = """
|
328 |
+
<!DOCTYPE html>
|
329 |
+
<html>
|
330 |
+
<head>
|
331 |
+
<title>Progress Bar</title>
|
332 |
+
<style>
|
333 |
+
.progress-bar {
|
334 |
+
background-color: #ddd;
|
335 |
+
border-radius: 4px;
|
336 |
+
height: 30px;
|
337 |
+
width: 100%;
|
338 |
+
position: relative;
|
339 |
+
}
|
340 |
+
|
341 |
+
.progress {
|
342 |
+
background-color: #00AAFF;
|
343 |
+
border-radius: 4px;
|
344 |
+
height: 100%;
|
345 |
+
width: {PROGRESS}%; /* Change this value to control the progress */
|
346 |
+
}
|
347 |
+
|
348 |
+
.progress-text {
|
349 |
+
position: absolute;
|
350 |
+
top: 50%;
|
351 |
+
left: 50%;
|
352 |
+
transform: translate(-50%, -50%);
|
353 |
+
font-size: 18px;
|
354 |
+
font-family: Arial, sans-serif;
|
355 |
+
font-weight: bold;
|
356 |
+
color: #333 !important;
|
357 |
+
text-shadow: 1px 1px #fff;
|
358 |
+
}
|
359 |
+
</style>
|
360 |
+
</head>
|
361 |
+
<body>
|
362 |
+
<div class="progress-bar">
|
363 |
+
<div class="progress"></div>
|
364 |
+
<div class="progress-text">{TEXT}</div>
|
365 |
+
</div>
|
366 |
+
</body>
|
367 |
+
</html>
|
368 |
+
"""
|
369 |
+
|
370 |
+
|
371 |
+
def create_tracker(app, cookie_name="name"):
|
372 |
+
user = gr.Text(label="user", interactive=True, visible=False, elem_id="user")
|
373 |
+
app.load(_js=load_tracker(cookie_name), outputs=user)
|
374 |
+
return user
|
375 |
+
|
376 |
+
|
377 |
+
#################################################################
|
378 |
+
### CSS and HTML for labeling sliders for both ABX and MUSHRA ###
|
379 |
+
#################################################################
|
380 |
+
|
381 |
+
slider_abx = """
|
382 |
+
<!DOCTYPE html>
|
383 |
+
<html>
|
384 |
+
<head>
|
385 |
+
<meta charset="UTF-8">
|
386 |
+
<title>Labels Example</title>
|
387 |
+
<style>
|
388 |
+
body {
|
389 |
+
margin: 0;
|
390 |
+
padding: 0;
|
391 |
+
}
|
392 |
+
|
393 |
+
.labels-container {
|
394 |
+
display: flex;
|
395 |
+
justify-content: space-between;
|
396 |
+
align-items: center;
|
397 |
+
width: 100%;
|
398 |
+
height: 40px;
|
399 |
+
padding: 0px 12px 0px;
|
400 |
+
}
|
401 |
+
|
402 |
+
.label {
|
403 |
+
display: flex;
|
404 |
+
justify-content: center;
|
405 |
+
align-items: center;
|
406 |
+
width: 33%;
|
407 |
+
height: 100%;
|
408 |
+
font-weight: bold;
|
409 |
+
text-transform: uppercase;
|
410 |
+
padding: 10px;
|
411 |
+
font-family: Arial, sans-serif;
|
412 |
+
font-size: 16px;
|
413 |
+
font-weight: 700;
|
414 |
+
letter-spacing: 1px;
|
415 |
+
line-height: 1.5;
|
416 |
+
}
|
417 |
+
|
418 |
+
.label-a {
|
419 |
+
background-color: #00AAFF;
|
420 |
+
color: #333 !important;
|
421 |
+
}
|
422 |
+
|
423 |
+
.label-tie {
|
424 |
+
background-color: #f97316;
|
425 |
+
color: #333 !important;
|
426 |
+
}
|
427 |
+
|
428 |
+
.label-b {
|
429 |
+
background-color: #00AAFF;
|
430 |
+
color: #333 !important;
|
431 |
+
}
|
432 |
+
</style>
|
433 |
+
</head>
|
434 |
+
<body>
|
435 |
+
<div class="labels-container">
|
436 |
+
<div class="label label-a">Prefer A</div>
|
437 |
+
<div class="label label-tie">Toss-up</div>
|
438 |
+
<div class="label label-b">Prefer B</div>
|
439 |
+
</div>
|
440 |
+
</body>
|
441 |
+
</html>
|
442 |
+
"""
|
443 |
+
|
444 |
+
slider_mushra = """
|
445 |
+
<!DOCTYPE html>
|
446 |
+
<html>
|
447 |
+
<head>
|
448 |
+
<meta charset="UTF-8">
|
449 |
+
<title>Labels Example</title>
|
450 |
+
<style>
|
451 |
+
body {
|
452 |
+
margin: 0;
|
453 |
+
padding: 0;
|
454 |
+
}
|
455 |
+
|
456 |
+
.labels-container {
|
457 |
+
display: flex;
|
458 |
+
justify-content: space-between;
|
459 |
+
align-items: center;
|
460 |
+
width: 100%;
|
461 |
+
height: 30px;
|
462 |
+
padding: 10px;
|
463 |
+
}
|
464 |
+
|
465 |
+
.label {
|
466 |
+
display: flex;
|
467 |
+
justify-content: center;
|
468 |
+
align-items: center;
|
469 |
+
width: 20%;
|
470 |
+
height: 100%;
|
471 |
+
font-weight: bold;
|
472 |
+
text-transform: uppercase;
|
473 |
+
padding: 10px;
|
474 |
+
font-family: Arial, sans-serif;
|
475 |
+
font-size: 13.5px;
|
476 |
+
font-weight: 700;
|
477 |
+
line-height: 1.5;
|
478 |
+
}
|
479 |
+
|
480 |
+
.label-bad {
|
481 |
+
background-color: #ff5555;
|
482 |
+
color: #333 !important;
|
483 |
+
}
|
484 |
+
|
485 |
+
.label-poor {
|
486 |
+
background-color: #ffa500;
|
487 |
+
color: #333 !important;
|
488 |
+
}
|
489 |
+
|
490 |
+
.label-fair {
|
491 |
+
background-color: #ffd700;
|
492 |
+
color: #333 !important;
|
493 |
+
}
|
494 |
+
|
495 |
+
.label-good {
|
496 |
+
background-color: #97d997;
|
497 |
+
color: #333 !important;
|
498 |
+
}
|
499 |
+
|
500 |
+
.label-excellent {
|
501 |
+
background-color: #04c822;
|
502 |
+
color: #333 !important;
|
503 |
+
}
|
504 |
+
</style>
|
505 |
+
</head>
|
506 |
+
<body>
|
507 |
+
<div class="labels-container">
|
508 |
+
<div class="label label-bad">bad</div>
|
509 |
+
<div class="label label-poor">poor</div>
|
510 |
+
<div class="label label-fair">fair</div>
|
511 |
+
<div class="label label-good">good</div>
|
512 |
+
<div class="label label-excellent">excellent</div>
|
513 |
+
</div>
|
514 |
+
</body>
|
515 |
+
</html>
|
516 |
+
"""
|
517 |
+
|
518 |
+
#########################################################
|
519 |
+
### Handling loading audio and tracking session state ###
|
520 |
+
#########################################################
|
521 |
+
|
522 |
+
|
523 |
+
class Samples:
|
524 |
+
def __init__(self, folder: str, shuffle: bool = True, n_samples: int = None):
|
525 |
+
files = find_audio(folder)
|
526 |
+
samples = defaultdict(lambda: defaultdict())
|
527 |
+
|
528 |
+
for f in files:
|
529 |
+
condition = f.parent.stem
|
530 |
+
samples[f.name][condition] = f
|
531 |
+
|
532 |
+
self.samples = samples
|
533 |
+
self.names = list(samples.keys())
|
534 |
+
self.filtered = False
|
535 |
+
self.current = 0
|
536 |
+
|
537 |
+
if shuffle:
|
538 |
+
random.shuffle(self.names)
|
539 |
+
|
540 |
+
self.n_samples = len(self.names) if n_samples is None else n_samples
|
541 |
+
|
542 |
+
def get_updates(self, idx, order):
|
543 |
+
key = self.names[idx]
|
544 |
+
return [gr.update(value=str(self.samples[key][o])) for o in order]
|
545 |
+
|
546 |
+
def progress(self):
|
547 |
+
try:
|
548 |
+
pct = self.current / len(self) * 100
|
549 |
+
except: # pragma: no cover
|
550 |
+
pct = 100
|
551 |
+
text = f"On {self.current} / {len(self)} samples"
|
552 |
+
pbar = (
|
553 |
+
copy.copy(progress_template)
|
554 |
+
.replace("{PROGRESS}", str(pct))
|
555 |
+
.replace("{TEXT}", str(text))
|
556 |
+
)
|
557 |
+
return gr.update(value=pbar)
|
558 |
+
|
559 |
+
def __len__(self):
|
560 |
+
return self.n_samples
|
561 |
+
|
562 |
+
def filter_completed(self, user, save_path):
|
563 |
+
if not self.filtered:
|
564 |
+
done = []
|
565 |
+
if Path(save_path).exists():
|
566 |
+
with open(save_path, "r") as f:
|
567 |
+
reader = csv.DictReader(f)
|
568 |
+
done = [r["sample"] for r in reader if r["user"] == user]
|
569 |
+
self.names = [k for k in self.names if k not in done]
|
570 |
+
self.names = self.names[: self.n_samples]
|
571 |
+
self.filtered = True # Avoid filtering more than once per session.
|
572 |
+
|
573 |
+
def get_next_sample(self, reference, conditions):
|
574 |
+
random.shuffle(conditions)
|
575 |
+
if reference is not None:
|
576 |
+
self.order = [reference] + conditions
|
577 |
+
else:
|
578 |
+
self.order = conditions
|
579 |
+
|
580 |
+
try:
|
581 |
+
updates = self.get_updates(self.current, self.order)
|
582 |
+
self.current += 1
|
583 |
+
done = gr.update(interactive=True)
|
584 |
+
pbar = self.progress()
|
585 |
+
except:
|
586 |
+
traceback.print_exc()
|
587 |
+
updates = [gr.update() for _ in range(len(self.order))]
|
588 |
+
done = gr.update(value="No more samples!", interactive=False)
|
589 |
+
self.current = len(self)
|
590 |
+
pbar = self.progress()
|
591 |
+
|
592 |
+
return updates, done, pbar
|
593 |
+
|
594 |
+
|
595 |
+
def save_result(result, save_path):
|
596 |
+
with open(save_path, mode="a", newline="") as file:
|
597 |
+
writer = csv.DictWriter(file, fieldnames=sorted(list(result.keys())))
|
598 |
+
if file.tell() == 0:
|
599 |
+
writer.writeheader()
|
600 |
+
writer.writerow(result)
|
src/inference.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
import librosa
|
6 |
+
import numpy as np
|
7 |
+
import soundfile as sf
|
8 |
+
from tqdm import tqdm
|
9 |
+
from .utils import scale_shift_re
|
10 |
+
|
11 |
+
|
12 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
13 |
+
"""
|
14 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
15 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
16 |
+
"""
|
17 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
18 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
19 |
+
# rescale the results from guidance (fixes overexposure)
|
20 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
21 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
22 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
23 |
+
return noise_cfg
|
24 |
+
|
25 |
+
|
26 |
+
@torch.no_grad()
|
27 |
+
def inference(autoencoder, unet, gt, gt_mask,
|
28 |
+
tokenizer, text_encoder,
|
29 |
+
params, noise_scheduler,
|
30 |
+
text_raw, neg_text=None,
|
31 |
+
audio_frames=500,
|
32 |
+
guidance_scale=3, guidance_rescale=0.0,
|
33 |
+
ddim_steps=50, eta=1, random_seed=2024,
|
34 |
+
device='cuda',
|
35 |
+
):
|
36 |
+
if neg_text is None:
|
37 |
+
neg_text = [""]
|
38 |
+
if tokenizer is not None:
|
39 |
+
text_batch = tokenizer(text_raw,
|
40 |
+
max_length=params['text_encoder']['max_length'],
|
41 |
+
padding="max_length", truncation=True, return_tensors="pt")
|
42 |
+
text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool()
|
43 |
+
text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state
|
44 |
+
|
45 |
+
uncond_text_batch = tokenizer(neg_text,
|
46 |
+
max_length=params['text_encoder']['max_length'],
|
47 |
+
padding="max_length", truncation=True, return_tensors="pt")
|
48 |
+
uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool()
|
49 |
+
uncond_text = text_encoder(input_ids=uncond_text,
|
50 |
+
attention_mask=uncond_text_mask).last_hidden_state
|
51 |
+
else:
|
52 |
+
text, text_mask = None, None
|
53 |
+
guidance_scale = None
|
54 |
+
|
55 |
+
codec_dim = params['model']['out_chans']
|
56 |
+
unet.eval()
|
57 |
+
|
58 |
+
if random_seed is not None:
|
59 |
+
generator = torch.Generator(device=device).manual_seed(random_seed)
|
60 |
+
else:
|
61 |
+
generator = torch.Generator(device=device)
|
62 |
+
generator.seed()
|
63 |
+
|
64 |
+
noise_scheduler.set_timesteps(ddim_steps)
|
65 |
+
|
66 |
+
# init noise
|
67 |
+
noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
|
68 |
+
latents = noise
|
69 |
+
|
70 |
+
for t in noise_scheduler.timesteps:
|
71 |
+
latents = noise_scheduler.scale_model_input(latents, t)
|
72 |
+
|
73 |
+
if guidance_scale:
|
74 |
+
|
75 |
+
latents_combined = torch.cat([latents, latents], dim=0)
|
76 |
+
text_combined = torch.cat([text, uncond_text], dim=0)
|
77 |
+
text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
|
78 |
+
|
79 |
+
if gt is not None:
|
80 |
+
gt_combined = torch.cat([gt, gt], dim=0)
|
81 |
+
gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
|
82 |
+
else:
|
83 |
+
gt_combined = None
|
84 |
+
gt_mask_combined = None
|
85 |
+
|
86 |
+
output_combined, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined,
|
87 |
+
cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined)
|
88 |
+
output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
|
89 |
+
|
90 |
+
output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
|
91 |
+
if guidance_rescale > 0.0:
|
92 |
+
output_pred = rescale_noise_cfg(output_pred, output_text,
|
93 |
+
guidance_rescale=guidance_rescale)
|
94 |
+
else:
|
95 |
+
output_pred, mae_mask = unet(latents, t, text, context_mask=text_mask,
|
96 |
+
cls_token=None, gt=gt, mae_mask_infer=gt_mask)
|
97 |
+
|
98 |
+
latents = noise_scheduler.step(model_output=output_pred, timestep=t,
|
99 |
+
sample=latents,
|
100 |
+
eta=eta, generator=generator).prev_sample
|
101 |
+
|
102 |
+
pred = scale_shift_re(latents, params['autoencoder']['scale'],
|
103 |
+
params['autoencoder']['shift'])
|
104 |
+
if gt is not None:
|
105 |
+
pred[~gt_mask] = gt[~gt_mask]
|
106 |
+
pred_wav = autoencoder(embedding=pred)
|
107 |
+
return pred_wav
|
108 |
+
|
109 |
+
|
110 |
+
@torch.no_grad()
|
111 |
+
def eval_udit(autoencoder, unet,
|
112 |
+
tokenizer, text_encoder,
|
113 |
+
params, noise_scheduler,
|
114 |
+
val_df, subset,
|
115 |
+
audio_frames, mae=False,
|
116 |
+
guidance_scale=3, guidance_rescale=0.0,
|
117 |
+
ddim_steps=50, eta=1, random_seed=2023,
|
118 |
+
device='cuda',
|
119 |
+
epoch=0, save_path='logs/eval/', val_num=5):
|
120 |
+
val_df = pd.read_csv(val_df)
|
121 |
+
val_df = val_df[val_df['split'] == subset]
|
122 |
+
if mae:
|
123 |
+
val_df = val_df[val_df['audio_length'] != 0]
|
124 |
+
|
125 |
+
save_path = save_path + str(epoch) + '/'
|
126 |
+
os.makedirs(save_path, exist_ok=True)
|
127 |
+
|
128 |
+
for i in tqdm(range(len(val_df))):
|
129 |
+
row = val_df.iloc[i]
|
130 |
+
text = [row['caption']]
|
131 |
+
if mae:
|
132 |
+
audio_path = params['data']['val_dir'] + str(row['audio_path'])
|
133 |
+
gt, sr = librosa.load(audio_path, sr=params['data']['sr'])
|
134 |
+
gt = gt / (np.max(np.abs(gt)) + 1e-9)
|
135 |
+
sf.write(save_path + text[0] + '_gt.wav', gt, samplerate=params['data']['sr'])
|
136 |
+
num_samples = 10 * sr
|
137 |
+
if len(gt) < num_samples:
|
138 |
+
padding = num_samples - len(gt)
|
139 |
+
gt = np.pad(gt, (0, padding), 'constant')
|
140 |
+
else:
|
141 |
+
gt = gt[:num_samples]
|
142 |
+
gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
|
143 |
+
gt = autoencoder(audio=gt)
|
144 |
+
B, D, L = gt.shape
|
145 |
+
mask_len = int(L * 0.2)
|
146 |
+
gt_mask = torch.zeros(B, D, L).to(device)
|
147 |
+
for _ in range(2):
|
148 |
+
start = random.randint(0, L - mask_len)
|
149 |
+
gt_mask[:, :, start:start + mask_len] = 1
|
150 |
+
gt_mask = gt_mask.bool()
|
151 |
+
else:
|
152 |
+
gt = None
|
153 |
+
gt_mask = None
|
154 |
+
|
155 |
+
pred = inference(autoencoder, unet, gt, gt_mask,
|
156 |
+
tokenizer, text_encoder,
|
157 |
+
params, noise_scheduler,
|
158 |
+
text, neg_text=None,
|
159 |
+
audio_frames=audio_frames,
|
160 |
+
guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
|
161 |
+
ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
|
162 |
+
device=device)
|
163 |
+
|
164 |
+
pred = pred.cpu().numpy().squeeze(0).squeeze(0)
|
165 |
+
|
166 |
+
sf.write(save_path + text[0] + '.wav', pred, samplerate=params['data']['sr'])
|
167 |
+
|
168 |
+
if i + 1 >= val_num:
|
169 |
+
break
|
src/inference_controlnet.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
import librosa
|
6 |
+
import numpy as np
|
7 |
+
import soundfile as sf
|
8 |
+
from tqdm import tqdm
|
9 |
+
from .utils import scale_shift_re
|
10 |
+
|
11 |
+
|
12 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
13 |
+
"""
|
14 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
15 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
16 |
+
"""
|
17 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
18 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
19 |
+
# rescale the results from guidance (fixes overexposure)
|
20 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
21 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
22 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
23 |
+
return noise_cfg
|
24 |
+
|
25 |
+
|
26 |
+
@torch.no_grad()
|
27 |
+
def inference(autoencoder, unet, controlnet,
|
28 |
+
gt, gt_mask, condition,
|
29 |
+
tokenizer, text_encoder,
|
30 |
+
params, noise_scheduler,
|
31 |
+
text_raw, neg_text=None,
|
32 |
+
audio_frames=500,
|
33 |
+
guidance_scale=3, guidance_rescale=0.0,
|
34 |
+
ddim_steps=50, eta=1, random_seed=2024,
|
35 |
+
conditioning_scale=1.0,
|
36 |
+
device='cuda',
|
37 |
+
):
|
38 |
+
if neg_text is None:
|
39 |
+
neg_text = [""]
|
40 |
+
if tokenizer is not None:
|
41 |
+
text_batch = tokenizer(text_raw,
|
42 |
+
max_length=params['text_encoder']['max_length'],
|
43 |
+
padding="max_length", truncation=True, return_tensors="pt")
|
44 |
+
text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool()
|
45 |
+
text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state
|
46 |
+
|
47 |
+
uncond_text_batch = tokenizer(neg_text,
|
48 |
+
max_length=params['text_encoder']['max_length'],
|
49 |
+
padding="max_length", truncation=True, return_tensors="pt")
|
50 |
+
uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool()
|
51 |
+
uncond_text = text_encoder(input_ids=uncond_text,
|
52 |
+
attention_mask=uncond_text_mask).last_hidden_state
|
53 |
+
else:
|
54 |
+
text, text_mask = None, None
|
55 |
+
guidance_scale = None
|
56 |
+
|
57 |
+
codec_dim = params['model']['out_chans']
|
58 |
+
unet.eval()
|
59 |
+
controlnet.eval()
|
60 |
+
|
61 |
+
if random_seed is not None:
|
62 |
+
generator = torch.Generator(device=device).manual_seed(random_seed)
|
63 |
+
else:
|
64 |
+
generator = torch.Generator(device=device)
|
65 |
+
generator.seed()
|
66 |
+
|
67 |
+
noise_scheduler.set_timesteps(ddim_steps)
|
68 |
+
|
69 |
+
# init noise
|
70 |
+
noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
|
71 |
+
latents = noise
|
72 |
+
|
73 |
+
for t in noise_scheduler.timesteps:
|
74 |
+
latents = noise_scheduler.scale_model_input(latents, t)
|
75 |
+
|
76 |
+
if guidance_scale:
|
77 |
+
latents_combined = torch.cat([latents, latents], dim=0)
|
78 |
+
text_combined = torch.cat([text, uncond_text], dim=0)
|
79 |
+
text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
|
80 |
+
condition_combined = torch.cat([condition, condition], dim=0)
|
81 |
+
|
82 |
+
if gt is not None:
|
83 |
+
gt_combined = torch.cat([gt, gt], dim=0)
|
84 |
+
gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
|
85 |
+
else:
|
86 |
+
gt_combined = None
|
87 |
+
gt_mask_combined = None
|
88 |
+
|
89 |
+
x, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined,
|
90 |
+
cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined,
|
91 |
+
forward_model=False)
|
92 |
+
controlnet_skips = controlnet(x, t, text_combined,
|
93 |
+
context_mask=text_mask_combined,
|
94 |
+
cls_token=None,
|
95 |
+
condition=condition_combined,
|
96 |
+
conditioning_scale=conditioning_scale)
|
97 |
+
output_combined = unet.model(x, t, text_combined,
|
98 |
+
context_mask=text_mask_combined,
|
99 |
+
cls_token=None, controlnet_skips=controlnet_skips)
|
100 |
+
|
101 |
+
output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
|
102 |
+
|
103 |
+
output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
|
104 |
+
if guidance_rescale > 0.0:
|
105 |
+
output_pred = rescale_noise_cfg(output_pred, output_text,
|
106 |
+
guidance_rescale=guidance_rescale)
|
107 |
+
else:
|
108 |
+
x, _ = unet(latents, t, text, context_mask=text_mask,
|
109 |
+
cls_token=None, gt=gt, mae_mask_infer=gt_mask,
|
110 |
+
forward_model=False)
|
111 |
+
controlnet_skips = controlnet(x, t, text,
|
112 |
+
context_mask=text_mask,
|
113 |
+
cls_token=None,
|
114 |
+
condition=condition,
|
115 |
+
conditioning_scale=conditioning_scale)
|
116 |
+
output_pred = unet.model(x, t, text,
|
117 |
+
context_mask=text_mask,
|
118 |
+
cls_token=None, controlnet_skips=controlnet_skips)
|
119 |
+
|
120 |
+
latents = noise_scheduler.step(model_output=output_pred, timestep=t,
|
121 |
+
sample=latents,
|
122 |
+
eta=eta, generator=generator).prev_sample
|
123 |
+
|
124 |
+
pred = scale_shift_re(latents, params['autoencoder']['scale'],
|
125 |
+
params['autoencoder']['shift'])
|
126 |
+
if gt is not None:
|
127 |
+
pred[~gt_mask] = gt[~gt_mask]
|
128 |
+
pred_wav = autoencoder(embedding=pred)
|
129 |
+
return pred_wav
|
src/models/.ipynb_checkpoints/blocks-checkpoint.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.utils.checkpoint import checkpoint
|
4 |
+
from .utils.attention import Attention, JointAttention
|
5 |
+
from .utils.modules import unpatchify, FeedForward
|
6 |
+
from .utils.modules import film_modulate
|
7 |
+
|
8 |
+
|
9 |
+
class AdaLN(nn.Module):
|
10 |
+
def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
|
11 |
+
super().__init__()
|
12 |
+
self.ada_mode = ada_mode
|
13 |
+
self.scale_shift_table = None
|
14 |
+
if ada_mode == 'ada':
|
15 |
+
# move nn.silu outside
|
16 |
+
self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
|
17 |
+
elif ada_mode == 'ada_single':
|
18 |
+
# adaln used in pixel-art alpha
|
19 |
+
self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
|
20 |
+
elif ada_mode in ['ada_lora', 'ada_lora_bias']:
|
21 |
+
self.lora_a = nn.Linear(dim, r * 6, bias=False)
|
22 |
+
self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
|
23 |
+
self.scaling = alpha / r
|
24 |
+
if ada_mode == 'ada_lora_bias':
|
25 |
+
# take bias out for consistency
|
26 |
+
self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
|
27 |
+
else:
|
28 |
+
raise NotImplementedError
|
29 |
+
|
30 |
+
def forward(self, time_token=None, time_ada=None):
|
31 |
+
if self.ada_mode == 'ada':
|
32 |
+
assert time_ada is None
|
33 |
+
B = time_token.shape[0]
|
34 |
+
time_ada = self.time_ada(time_token).reshape(B, 6, -1)
|
35 |
+
elif self.ada_mode == 'ada_single':
|
36 |
+
B = time_ada.shape[0]
|
37 |
+
time_ada = time_ada.reshape(B, 6, -1)
|
38 |
+
time_ada = self.scale_shift_table[None] + time_ada
|
39 |
+
elif self.ada_mode in ['ada_lora', 'ada_lora_bias']:
|
40 |
+
B = time_ada.shape[0]
|
41 |
+
time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
|
42 |
+
time_ada = time_ada + time_ada_lora
|
43 |
+
time_ada = time_ada.reshape(B, 6, -1)
|
44 |
+
if self.scale_shift_table is not None:
|
45 |
+
time_ada = self.scale_shift_table[None] + time_ada
|
46 |
+
else:
|
47 |
+
raise NotImplementedError
|
48 |
+
return time_ada
|
49 |
+
|
50 |
+
|
51 |
+
class DiTBlock(nn.Module):
|
52 |
+
"""
|
53 |
+
A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def __init__(self, dim, context_dim=None,
|
57 |
+
num_heads=8, mlp_ratio=4.,
|
58 |
+
qkv_bias=False, qk_scale=None, qk_norm=None,
|
59 |
+
act_layer='gelu', norm_layer=nn.LayerNorm,
|
60 |
+
time_fusion='none',
|
61 |
+
ada_lora_rank=None, ada_lora_alpha=None,
|
62 |
+
skip=False, skip_norm=False,
|
63 |
+
rope_mode='none',
|
64 |
+
context_norm=False,
|
65 |
+
use_checkpoint=False):
|
66 |
+
|
67 |
+
super().__init__()
|
68 |
+
self.norm1 = norm_layer(dim)
|
69 |
+
self.attn = Attention(dim=dim,
|
70 |
+
num_heads=num_heads,
|
71 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
72 |
+
qk_norm=qk_norm,
|
73 |
+
rope_mode=rope_mode)
|
74 |
+
|
75 |
+
if context_dim is not None:
|
76 |
+
self.use_context = True
|
77 |
+
self.cross_attn = Attention(dim=dim,
|
78 |
+
num_heads=num_heads,
|
79 |
+
context_dim=context_dim,
|
80 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
81 |
+
qk_norm=qk_norm,
|
82 |
+
rope_mode='none')
|
83 |
+
self.norm2 = norm_layer(dim)
|
84 |
+
if context_norm:
|
85 |
+
self.norm_context = norm_layer(context_dim)
|
86 |
+
else:
|
87 |
+
self.norm_context = nn.Identity()
|
88 |
+
else:
|
89 |
+
self.use_context = False
|
90 |
+
|
91 |
+
self.norm3 = norm_layer(dim)
|
92 |
+
self.mlp = FeedForward(dim=dim, mult=mlp_ratio,
|
93 |
+
activation_fn=act_layer, dropout=0)
|
94 |
+
|
95 |
+
self.use_adanorm = True if time_fusion != 'token' else False
|
96 |
+
if self.use_adanorm:
|
97 |
+
self.adaln = AdaLN(dim, ada_mode=time_fusion,
|
98 |
+
r=ada_lora_rank, alpha=ada_lora_alpha)
|
99 |
+
if skip:
|
100 |
+
self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
|
101 |
+
self.skip_linear = nn.Linear(2 * dim, dim)
|
102 |
+
else:
|
103 |
+
self.skip_linear = None
|
104 |
+
|
105 |
+
self.use_checkpoint = use_checkpoint
|
106 |
+
|
107 |
+
def forward(self, x, time_token=None, time_ada=None,
|
108 |
+
skip=None, context=None,
|
109 |
+
x_mask=None, context_mask=None, extras=None):
|
110 |
+
if self.use_checkpoint:
|
111 |
+
return checkpoint(self._forward, x,
|
112 |
+
time_token, time_ada, skip, context,
|
113 |
+
x_mask, context_mask, extras,
|
114 |
+
use_reentrant=False)
|
115 |
+
else:
|
116 |
+
return self._forward(x,
|
117 |
+
time_token, time_ada, skip, context,
|
118 |
+
x_mask, context_mask, extras)
|
119 |
+
|
120 |
+
def _forward(self, x, time_token=None, time_ada=None,
|
121 |
+
skip=None, context=None,
|
122 |
+
x_mask=None, context_mask=None, extras=None):
|
123 |
+
B, T, C = x.shape
|
124 |
+
if self.skip_linear is not None:
|
125 |
+
assert skip is not None
|
126 |
+
cat = torch.cat([x, skip], dim=-1)
|
127 |
+
cat = self.skip_norm(cat)
|
128 |
+
x = self.skip_linear(cat)
|
129 |
+
|
130 |
+
if self.use_adanorm:
|
131 |
+
time_ada = self.adaln(time_token, time_ada)
|
132 |
+
(shift_msa, scale_msa, gate_msa,
|
133 |
+
shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
|
134 |
+
|
135 |
+
# self attention
|
136 |
+
if self.use_adanorm:
|
137 |
+
x_norm = film_modulate(self.norm1(x), shift=shift_msa,
|
138 |
+
scale=scale_msa)
|
139 |
+
x = x + (1 - gate_msa) * self.attn(x_norm, context=None,
|
140 |
+
context_mask=x_mask,
|
141 |
+
extras=extras)
|
142 |
+
else:
|
143 |
+
x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask,
|
144 |
+
extras=extras)
|
145 |
+
|
146 |
+
# cross attention
|
147 |
+
if self.use_context:
|
148 |
+
assert context is not None
|
149 |
+
x = x + self.cross_attn(x=self.norm2(x),
|
150 |
+
context=self.norm_context(context),
|
151 |
+
context_mask=context_mask, extras=extras)
|
152 |
+
|
153 |
+
# mlp
|
154 |
+
if self.use_adanorm:
|
155 |
+
x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp)
|
156 |
+
x = x + (1 - gate_mlp) * self.mlp(x_norm)
|
157 |
+
else:
|
158 |
+
x = x + self.mlp(self.norm3(x))
|
159 |
+
|
160 |
+
return x
|
161 |
+
|
162 |
+
|
163 |
+
class JointDiTBlock(nn.Module):
|
164 |
+
"""
|
165 |
+
A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
|
166 |
+
"""
|
167 |
+
|
168 |
+
def __init__(self, dim, context_dim=None,
|
169 |
+
num_heads=8, mlp_ratio=4.,
|
170 |
+
qkv_bias=False, qk_scale=None, qk_norm=None,
|
171 |
+
act_layer='gelu', norm_layer=nn.LayerNorm,
|
172 |
+
time_fusion='none',
|
173 |
+
ada_lora_rank=None, ada_lora_alpha=None,
|
174 |
+
skip=(False, False),
|
175 |
+
rope_mode=False,
|
176 |
+
context_norm=False,
|
177 |
+
use_checkpoint=False,):
|
178 |
+
|
179 |
+
super().__init__()
|
180 |
+
# no cross attention
|
181 |
+
assert context_dim is None
|
182 |
+
self.attn_norm_x = norm_layer(dim)
|
183 |
+
self.attn_norm_c = norm_layer(dim)
|
184 |
+
self.attn = JointAttention(dim=dim,
|
185 |
+
num_heads=num_heads,
|
186 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
187 |
+
qk_norm=qk_norm,
|
188 |
+
rope_mode=rope_mode)
|
189 |
+
self.ffn_norm_x = norm_layer(dim)
|
190 |
+
self.ffn_norm_c = norm_layer(dim)
|
191 |
+
self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio,
|
192 |
+
activation_fn=act_layer, dropout=0)
|
193 |
+
self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio,
|
194 |
+
activation_fn=act_layer, dropout=0)
|
195 |
+
|
196 |
+
# Zero-out the shift table
|
197 |
+
self.use_adanorm = True if time_fusion != 'token' else False
|
198 |
+
if self.use_adanorm:
|
199 |
+
self.adaln = AdaLN(dim, ada_mode=time_fusion,
|
200 |
+
r=ada_lora_rank, alpha=ada_lora_alpha)
|
201 |
+
|
202 |
+
if skip is False:
|
203 |
+
skip_x, skip_c = False, False
|
204 |
+
else:
|
205 |
+
skip_x, skip_c = skip
|
206 |
+
|
207 |
+
self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None
|
208 |
+
self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None
|
209 |
+
|
210 |
+
self.use_checkpoint = use_checkpoint
|
211 |
+
|
212 |
+
def forward(self, x, time_token=None, time_ada=None,
|
213 |
+
skip=None, context=None,
|
214 |
+
x_mask=None, context_mask=None, extras=None):
|
215 |
+
if self.use_checkpoint:
|
216 |
+
return checkpoint(self._forward, x,
|
217 |
+
time_token, time_ada, skip,
|
218 |
+
context, x_mask, context_mask, extras,
|
219 |
+
use_reentrant=False)
|
220 |
+
else:
|
221 |
+
return self._forward(x,
|
222 |
+
time_token, time_ada, skip,
|
223 |
+
context, x_mask, context_mask, extras)
|
224 |
+
|
225 |
+
def _forward(self, x, time_token=None, time_ada=None,
|
226 |
+
skip=None, context=None,
|
227 |
+
x_mask=None, context_mask=None, extras=None):
|
228 |
+
|
229 |
+
assert context is None and context_mask is None
|
230 |
+
|
231 |
+
context, x = x[:, :extras, :], x[:, extras:, :]
|
232 |
+
context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:]
|
233 |
+
|
234 |
+
if skip is not None:
|
235 |
+
skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :]
|
236 |
+
|
237 |
+
B, T, C = x.shape
|
238 |
+
if self.skip_linear_x is not None:
|
239 |
+
x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1))
|
240 |
+
|
241 |
+
if self.skip_linear_c is not None:
|
242 |
+
context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1))
|
243 |
+
|
244 |
+
if self.use_adanorm:
|
245 |
+
time_ada = self.adaln(time_token, time_ada)
|
246 |
+
(shift_msa, scale_msa, gate_msa,
|
247 |
+
shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
|
248 |
+
|
249 |
+
# self attention
|
250 |
+
x_norm = self.attn_norm_x(x)
|
251 |
+
c_norm = self.attn_norm_c(context)
|
252 |
+
if self.use_adanorm:
|
253 |
+
x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa)
|
254 |
+
x_out, c_out = self.attn(x_norm, context=c_norm,
|
255 |
+
x_mask=x_mask, context_mask=context_mask,
|
256 |
+
extras=extras)
|
257 |
+
if self.use_adanorm:
|
258 |
+
x = x + (1 - gate_msa) * x_out
|
259 |
+
else:
|
260 |
+
x = x + x_out
|
261 |
+
context = context + c_out
|
262 |
+
|
263 |
+
# mlp
|
264 |
+
if self.use_adanorm:
|
265 |
+
x_norm = film_modulate(self.ffn_norm_x(x),
|
266 |
+
shift=shift_mlp, scale=scale_mlp)
|
267 |
+
x = x + (1 - gate_mlp) * self.mlp_x(x_norm)
|
268 |
+
else:
|
269 |
+
x = x + self.mlp_x(self.ffn_norm_x(x))
|
270 |
+
|
271 |
+
c_norm = self.ffn_norm_c(context)
|
272 |
+
context = context + self.mlp_c(c_norm)
|
273 |
+
|
274 |
+
return torch.cat((context, x), dim=1)
|
275 |
+
|
276 |
+
|
277 |
+
class FinalBlock(nn.Module):
|
278 |
+
def __init__(self, embed_dim, patch_size, in_chans,
|
279 |
+
img_size,
|
280 |
+
input_type='2d',
|
281 |
+
norm_layer=nn.LayerNorm,
|
282 |
+
use_conv=True,
|
283 |
+
use_adanorm=True):
|
284 |
+
super().__init__()
|
285 |
+
self.in_chans = in_chans
|
286 |
+
self.img_size = img_size
|
287 |
+
self.input_type = input_type
|
288 |
+
|
289 |
+
self.norm = norm_layer(embed_dim)
|
290 |
+
if use_adanorm:
|
291 |
+
self.use_adanorm = True
|
292 |
+
else:
|
293 |
+
self.use_adanorm = False
|
294 |
+
|
295 |
+
if input_type == '2d':
|
296 |
+
self.patch_dim = patch_size ** 2 * in_chans
|
297 |
+
self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
|
298 |
+
if use_conv:
|
299 |
+
self.final_layer = nn.Conv2d(self.in_chans, self.in_chans,
|
300 |
+
3, padding=1)
|
301 |
+
else:
|
302 |
+
self.final_layer = nn.Identity()
|
303 |
+
|
304 |
+
elif input_type == '1d':
|
305 |
+
self.patch_dim = patch_size * in_chans
|
306 |
+
self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
|
307 |
+
if use_conv:
|
308 |
+
self.final_layer = nn.Conv1d(self.in_chans, self.in_chans,
|
309 |
+
3, padding=1)
|
310 |
+
else:
|
311 |
+
self.final_layer = nn.Identity()
|
312 |
+
|
313 |
+
def forward(self, x, time_ada=None, extras=0):
|
314 |
+
B, T, C = x.shape
|
315 |
+
x = x[:, extras:, :]
|
316 |
+
# only handle generation target
|
317 |
+
if self.use_adanorm:
|
318 |
+
shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
|
319 |
+
x = film_modulate(self.norm(x), shift, scale)
|
320 |
+
else:
|
321 |
+
x = self.norm(x)
|
322 |
+
x = self.linear(x)
|
323 |
+
x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
|
324 |
+
x = self.final_layer(x)
|
325 |
+
return x
|
src/models/.ipynb_checkpoints/conditioners-checkpoint.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import repeat
|
5 |
+
import math
|
6 |
+
from .udit import UDiT
|
7 |
+
from .utils.span_mask import compute_mask_indices
|
8 |
+
|
9 |
+
|
10 |
+
class EmbeddingCFG(nn.Module):
|
11 |
+
"""
|
12 |
+
Handles label dropout for classifier-free guidance.
|
13 |
+
"""
|
14 |
+
# todo: support 2D input
|
15 |
+
|
16 |
+
def __init__(self, in_channels):
|
17 |
+
super().__init__()
|
18 |
+
self.cfg_embedding = nn.Parameter(
|
19 |
+
torch.randn(in_channels) / in_channels ** 0.5)
|
20 |
+
|
21 |
+
def token_drop(self, condition, condition_mask, cfg_prob):
|
22 |
+
"""
|
23 |
+
Drops labels to enable classifier-free guidance.
|
24 |
+
"""
|
25 |
+
b, t, device = condition.shape[0], condition.shape[1], condition.device
|
26 |
+
drop_ids = torch.rand(b, device=device) < cfg_prob
|
27 |
+
uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
|
28 |
+
condition = torch.where(drop_ids[:, None, None], uncond, condition)
|
29 |
+
if condition_mask is not None:
|
30 |
+
condition_mask[drop_ids] = False
|
31 |
+
condition_mask[drop_ids, 0] = True
|
32 |
+
|
33 |
+
return condition, condition_mask
|
34 |
+
|
35 |
+
def forward(self, condition, condition_mask, cfg_prob=0.0):
|
36 |
+
if condition_mask is not None:
|
37 |
+
condition_mask = condition_mask.clone()
|
38 |
+
if cfg_prob > 0:
|
39 |
+
condition, condition_mask = self.token_drop(condition,
|
40 |
+
condition_mask,
|
41 |
+
cfg_prob)
|
42 |
+
return condition, condition_mask
|
43 |
+
|
44 |
+
|
45 |
+
class DiscreteCFG(nn.Module):
|
46 |
+
def __init__(self, replace_id=2):
|
47 |
+
super(DiscreteCFG, self).__init__()
|
48 |
+
self.replace_id = replace_id
|
49 |
+
|
50 |
+
def forward(self, context, context_mask, cfg_prob):
|
51 |
+
context = context.clone()
|
52 |
+
if context_mask is not None:
|
53 |
+
context_mask = context_mask.clone()
|
54 |
+
if cfg_prob > 0:
|
55 |
+
cfg_mask = torch.rand(len(context)) < cfg_prob
|
56 |
+
if torch.any(cfg_mask):
|
57 |
+
context[cfg_mask] = 0
|
58 |
+
context[cfg_mask, 0] = self.replace_id
|
59 |
+
if context_mask is not None:
|
60 |
+
context_mask[cfg_mask] = False
|
61 |
+
context_mask[cfg_mask, 0] = True
|
62 |
+
return context, context_mask
|
63 |
+
|
64 |
+
|
65 |
+
class CFGModel(nn.Module):
|
66 |
+
def __init__(self, context_dim, backbone):
|
67 |
+
super().__init__()
|
68 |
+
self.model = backbone
|
69 |
+
self.context_cfg = EmbeddingCFG(context_dim)
|
70 |
+
|
71 |
+
def forward(self, x, timesteps,
|
72 |
+
context, x_mask=None, context_mask=None,
|
73 |
+
cfg_prob=0.0):
|
74 |
+
context = self.context_cfg(context, cfg_prob)
|
75 |
+
x = self.model(x=x, timesteps=timesteps,
|
76 |
+
context=context,
|
77 |
+
x_mask=x_mask, context_mask=context_mask)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class ConcatModel(nn.Module):
|
82 |
+
def __init__(self, backbone, in_dim, stride=[]):
|
83 |
+
super().__init__()
|
84 |
+
self.model = backbone
|
85 |
+
|
86 |
+
self.downsample_layers = nn.ModuleList()
|
87 |
+
for i, s in enumerate(stride):
|
88 |
+
downsample_layer = nn.Conv1d(
|
89 |
+
in_dim,
|
90 |
+
in_dim * 2,
|
91 |
+
kernel_size=2 * s,
|
92 |
+
stride=s,
|
93 |
+
padding=math.ceil(s / 2),
|
94 |
+
)
|
95 |
+
self.downsample_layers.append(downsample_layer)
|
96 |
+
in_dim = in_dim * 2
|
97 |
+
|
98 |
+
self.context_cfg = EmbeddingCFG(in_dim)
|
99 |
+
|
100 |
+
def forward(self, x, timesteps,
|
101 |
+
context, x_mask=None,
|
102 |
+
cfg=False, cfg_prob=0.0):
|
103 |
+
|
104 |
+
# todo: support 2D input
|
105 |
+
# x: B, C, L
|
106 |
+
# context: B, C, L
|
107 |
+
|
108 |
+
for downsample_layer in self.downsample_layers:
|
109 |
+
context = downsample_layer(context)
|
110 |
+
|
111 |
+
context = context.transpose(1, 2)
|
112 |
+
context = self.context_cfg(caption=context,
|
113 |
+
cfg=cfg, cfg_prob=cfg_prob)
|
114 |
+
context = context.transpose(1, 2)
|
115 |
+
|
116 |
+
assert context.shape[-1] == x.shape[-1]
|
117 |
+
x = torch.cat([context, x], dim=1)
|
118 |
+
x = self.model(x=x, timesteps=timesteps,
|
119 |
+
context=None, x_mask=x_mask, context_mask=None)
|
120 |
+
return x
|
121 |
+
|
122 |
+
|
123 |
+
class MaskDiT(nn.Module):
|
124 |
+
def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs):
|
125 |
+
super().__init__()
|
126 |
+
self.model = UDiT(**kwargs)
|
127 |
+
self.mae = mae
|
128 |
+
if self.mae:
|
129 |
+
out_channel = kwargs.pop('out_chans', None)
|
130 |
+
self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
|
131 |
+
self.mae_prob = mae_prob
|
132 |
+
self.mask_ratio = mask_ratio
|
133 |
+
self.mask_span = mask_span
|
134 |
+
|
135 |
+
def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
|
136 |
+
B, D, L = gt.shape
|
137 |
+
if mae_mask_infer is None:
|
138 |
+
# mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
|
139 |
+
mask_ratios = mask_ratios.cpu().numpy()
|
140 |
+
mask = compute_mask_indices(shape=[B, L],
|
141 |
+
padding_mask=None,
|
142 |
+
mask_prob=mask_ratios,
|
143 |
+
mask_length=self.mask_span,
|
144 |
+
mask_type="static",
|
145 |
+
mask_other=0.0,
|
146 |
+
min_masks=1,
|
147 |
+
no_overlap=False,
|
148 |
+
min_space=0,)
|
149 |
+
mask = mask.unsqueeze(1).expand_as(gt)
|
150 |
+
else:
|
151 |
+
mask = mae_mask_infer
|
152 |
+
mask = mask.expand_as(gt)
|
153 |
+
gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
|
154 |
+
return gt, mask.type_as(gt)
|
155 |
+
|
156 |
+
def forward(self, x, timesteps, context,
|
157 |
+
x_mask=None, context_mask=None, cls_token=None,
|
158 |
+
gt=None, mae_mask_infer=None,
|
159 |
+
forward_model=True):
|
160 |
+
# todo: handle controlnet inside
|
161 |
+
mae_mask = torch.ones_like(x)
|
162 |
+
if self.mae:
|
163 |
+
if gt is not None:
|
164 |
+
B, D, L = gt.shape
|
165 |
+
mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
|
166 |
+
gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
|
167 |
+
# apply mae only to the selected batches
|
168 |
+
if mae_mask_infer is None:
|
169 |
+
# determine mae batch
|
170 |
+
mae_batch = torch.rand(B) < self.mae_prob
|
171 |
+
gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
|
172 |
+
mae_mask[~mae_batch] = 1.0
|
173 |
+
else:
|
174 |
+
B, D, L = x.shape
|
175 |
+
gt = self.mask_embed.view(1, D, 1).expand_as(x)
|
176 |
+
x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
|
177 |
+
|
178 |
+
if forward_model:
|
179 |
+
x = self.model(x=x, timesteps=timesteps, context=context,
|
180 |
+
x_mask=x_mask, context_mask=context_mask,
|
181 |
+
cls_token=cls_token)
|
182 |
+
# print(mae_mask[:, 0, :].sum(dim=-1))
|
183 |
+
return x, mae_mask
|
src/models/.ipynb_checkpoints/controlnet-checkpoint.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .utils.modules import PatchEmbed, TimestepEmbedder
|
5 |
+
from .utils.modules import PE_wrapper, RMSNorm
|
6 |
+
from .blocks import DiTBlock, JointDiTBlock
|
7 |
+
from .utils.span_mask import compute_mask_indices
|
8 |
+
|
9 |
+
|
10 |
+
class DiTControlNetEmbed(nn.Module):
|
11 |
+
def __init__(self, in_chans, out_chans, blocks,
|
12 |
+
cond_mask=False, cond_mask_prob=None,
|
13 |
+
cond_mask_ratio=None, cond_mask_span=None):
|
14 |
+
super().__init__()
|
15 |
+
self.conv_in = nn.Conv1d(in_chans, blocks[0], kernel_size=1)
|
16 |
+
|
17 |
+
self.cond_mask = cond_mask
|
18 |
+
if self.cond_mask:
|
19 |
+
self.mask_embed = nn.Parameter(torch.zeros((blocks[0])))
|
20 |
+
self.mask_prob = cond_mask_prob
|
21 |
+
self.mask_ratio = cond_mask_ratio
|
22 |
+
self.mask_span = cond_mask_span
|
23 |
+
blocks[0] = blocks[0] + 1
|
24 |
+
|
25 |
+
conv_blocks = []
|
26 |
+
for i in range(len(blocks) - 1):
|
27 |
+
channel_in = blocks[i]
|
28 |
+
channel_out = blocks[i + 1]
|
29 |
+
block = nn.Sequential(
|
30 |
+
nn.Conv1d(channel_in, channel_in, kernel_size=3, padding=1),
|
31 |
+
nn.SiLU(),
|
32 |
+
nn.Conv1d(channel_in, channel_out, kernel_size=3, padding=1, stride=2),
|
33 |
+
nn.SiLU(),)
|
34 |
+
conv_blocks.append(block)
|
35 |
+
self.blocks = nn.ModuleList(conv_blocks)
|
36 |
+
|
37 |
+
self.conv_out = nn.Conv1d(blocks[-1], out_chans, kernel_size=1)
|
38 |
+
nn.init.zeros_(self.conv_out.weight)
|
39 |
+
nn.init.zeros_(self.conv_out.bias)
|
40 |
+
|
41 |
+
def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
|
42 |
+
B, D, L = gt.shape
|
43 |
+
if mae_mask_infer is None:
|
44 |
+
# mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
|
45 |
+
mask_ratios = mask_ratios.cpu().numpy()
|
46 |
+
mask = compute_mask_indices(shape=[B, L],
|
47 |
+
padding_mask=None,
|
48 |
+
mask_prob=mask_ratios,
|
49 |
+
mask_length=self.mask_span,
|
50 |
+
mask_type="static",
|
51 |
+
mask_other=0.0,
|
52 |
+
min_masks=1,
|
53 |
+
no_overlap=False,
|
54 |
+
min_space=0,)
|
55 |
+
# only apply mask to some batches
|
56 |
+
mask_batch = torch.rand(B) < self.mask_prob
|
57 |
+
mask[~mask_batch] = False
|
58 |
+
mask = mask.unsqueeze(1).expand_as(gt)
|
59 |
+
else:
|
60 |
+
mask = mae_mask_infer
|
61 |
+
mask = mask.expand_as(gt)
|
62 |
+
gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask].type_as(gt)
|
63 |
+
return gt, mask.type_as(gt)
|
64 |
+
|
65 |
+
def forward(self, conditioning, cond_mask_infer=None):
|
66 |
+
embedding = self.conv_in(conditioning)
|
67 |
+
|
68 |
+
if self.cond_mask:
|
69 |
+
B, D, L = embedding.shape
|
70 |
+
if not self.training and cond_mask_infer is None:
|
71 |
+
cond_mask_infer = torch.zeros_like(embedding).bool()
|
72 |
+
mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(embedding.device)
|
73 |
+
embedding, cond_mask = self.random_masking(embedding, mask_ratios, cond_mask_infer)
|
74 |
+
embedding = torch.cat([embedding, cond_mask[:, 0:1, :]], dim=1)
|
75 |
+
|
76 |
+
for block in self.blocks:
|
77 |
+
embedding = block(embedding)
|
78 |
+
|
79 |
+
embedding = self.conv_out(embedding)
|
80 |
+
|
81 |
+
# B, L, C
|
82 |
+
embedding = embedding.transpose(1, 2).contiguous()
|
83 |
+
|
84 |
+
return embedding
|
85 |
+
|
86 |
+
|
87 |
+
class DiTControlNet(nn.Module):
|
88 |
+
def __init__(self,
|
89 |
+
img_size=(224, 224), patch_size=16, in_chans=3,
|
90 |
+
input_type='2d', out_chans=None,
|
91 |
+
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
|
92 |
+
qkv_bias=False, qk_scale=None, qk_norm=None,
|
93 |
+
act_layer='gelu', norm_layer='layernorm',
|
94 |
+
context_norm=False,
|
95 |
+
use_checkpoint=False,
|
96 |
+
# time fusion ada or token
|
97 |
+
time_fusion='token',
|
98 |
+
ada_lora_rank=None, ada_lora_alpha=None,
|
99 |
+
cls_dim=None,
|
100 |
+
# max length is only used for concat
|
101 |
+
context_dim=768, context_fusion='concat',
|
102 |
+
context_max_length=128, context_pe_method='sinu',
|
103 |
+
pe_method='abs', rope_mode='none',
|
104 |
+
use_conv=True,
|
105 |
+
skip=True, skip_norm=True,
|
106 |
+
# controlnet configs
|
107 |
+
cond_in=None, cond_blocks=None,
|
108 |
+
cond_mask=False, cond_mask_prob=None,
|
109 |
+
cond_mask_ratio=None, cond_mask_span=None,
|
110 |
+
**kwargs):
|
111 |
+
super().__init__()
|
112 |
+
self.num_features = self.embed_dim = embed_dim
|
113 |
+
# input
|
114 |
+
self.in_chans = in_chans
|
115 |
+
self.input_type = input_type
|
116 |
+
if self.input_type == '2d':
|
117 |
+
num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
118 |
+
elif self.input_type == '1d':
|
119 |
+
num_patches = img_size // patch_size
|
120 |
+
self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
|
121 |
+
embed_dim=embed_dim, input_type=input_type)
|
122 |
+
out_chans = in_chans if out_chans is None else out_chans
|
123 |
+
self.out_chans = out_chans
|
124 |
+
|
125 |
+
# position embedding
|
126 |
+
self.rope = rope_mode
|
127 |
+
self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
|
128 |
+
length=num_patches)
|
129 |
+
|
130 |
+
print(f'x position embedding: {pe_method}')
|
131 |
+
print(f'rope mode: {self.rope}')
|
132 |
+
|
133 |
+
# time embed
|
134 |
+
self.time_embed = TimestepEmbedder(embed_dim)
|
135 |
+
self.time_fusion = time_fusion
|
136 |
+
self.use_adanorm = False
|
137 |
+
|
138 |
+
# cls embed
|
139 |
+
if cls_dim is not None:
|
140 |
+
self.cls_embed = nn.Sequential(
|
141 |
+
nn.Linear(cls_dim, embed_dim, bias=True),
|
142 |
+
nn.SiLU(),
|
143 |
+
nn.Linear(embed_dim, embed_dim, bias=True),)
|
144 |
+
else:
|
145 |
+
self.cls_embed = None
|
146 |
+
|
147 |
+
# time fusion
|
148 |
+
if time_fusion == 'token':
|
149 |
+
# put token at the beginning of sequence
|
150 |
+
self.extras = 2 if self.cls_embed else 1
|
151 |
+
self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
|
152 |
+
elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
|
153 |
+
self.use_adanorm = True
|
154 |
+
# aviod repetitive silu for each adaln block
|
155 |
+
self.time_act = nn.SiLU()
|
156 |
+
self.extras = 0
|
157 |
+
if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
|
158 |
+
# shared adaln
|
159 |
+
self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
|
160 |
+
else:
|
161 |
+
self.time_ada = None
|
162 |
+
else:
|
163 |
+
raise NotImplementedError
|
164 |
+
print(f'time fusion mode: {self.time_fusion}')
|
165 |
+
|
166 |
+
# context
|
167 |
+
# use a simple projection
|
168 |
+
self.use_context = False
|
169 |
+
self.context_cross = False
|
170 |
+
self.context_max_length = context_max_length
|
171 |
+
self.context_fusion = 'none'
|
172 |
+
if context_dim is not None:
|
173 |
+
self.use_context = True
|
174 |
+
self.context_embed = nn.Sequential(
|
175 |
+
nn.Linear(context_dim, embed_dim, bias=True),
|
176 |
+
nn.SiLU(),
|
177 |
+
nn.Linear(embed_dim, embed_dim, bias=True),)
|
178 |
+
self.context_fusion = context_fusion
|
179 |
+
if context_fusion == 'concat' or context_fusion == 'joint':
|
180 |
+
self.extras += context_max_length
|
181 |
+
self.context_pe = PE_wrapper(dim=embed_dim,
|
182 |
+
method=context_pe_method,
|
183 |
+
length=context_max_length)
|
184 |
+
# no cross attention layers
|
185 |
+
context_dim = None
|
186 |
+
elif context_fusion == 'cross':
|
187 |
+
self.context_pe = PE_wrapper(dim=embed_dim,
|
188 |
+
method=context_pe_method,
|
189 |
+
length=context_max_length)
|
190 |
+
self.context_cross = True
|
191 |
+
context_dim = embed_dim
|
192 |
+
else:
|
193 |
+
raise NotImplementedError
|
194 |
+
print(f'context fusion mode: {context_fusion}')
|
195 |
+
print(f'context position embedding: {context_pe_method}')
|
196 |
+
|
197 |
+
if self.context_fusion == 'joint':
|
198 |
+
Block = JointDiTBlock
|
199 |
+
else:
|
200 |
+
Block = DiTBlock
|
201 |
+
|
202 |
+
# norm layers
|
203 |
+
if norm_layer == 'layernorm':
|
204 |
+
norm_layer = nn.LayerNorm
|
205 |
+
elif norm_layer == 'rmsnorm':
|
206 |
+
norm_layer = RMSNorm
|
207 |
+
else:
|
208 |
+
raise NotImplementedError
|
209 |
+
|
210 |
+
self.in_blocks = nn.ModuleList([
|
211 |
+
Block(
|
212 |
+
dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
|
213 |
+
mlp_ratio=mlp_ratio,
|
214 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
|
215 |
+
act_layer=act_layer, norm_layer=norm_layer,
|
216 |
+
time_fusion=time_fusion,
|
217 |
+
ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
|
218 |
+
skip=False, skip_norm=False,
|
219 |
+
rope_mode=self.rope,
|
220 |
+
context_norm=context_norm,
|
221 |
+
use_checkpoint=use_checkpoint)
|
222 |
+
for _ in range(depth // 2)])
|
223 |
+
|
224 |
+
self.controlnet_pre = DiTControlNetEmbed(in_chans=cond_in, out_chans=embed_dim,
|
225 |
+
blocks=cond_blocks,
|
226 |
+
cond_mask=cond_mask,
|
227 |
+
cond_mask_prob=cond_mask_prob,
|
228 |
+
cond_mask_ratio=cond_mask_ratio,
|
229 |
+
cond_mask_span=cond_mask_span)
|
230 |
+
|
231 |
+
controlnet_zero_blocks = []
|
232 |
+
for i in range(depth // 2):
|
233 |
+
block = nn.Linear(embed_dim, embed_dim)
|
234 |
+
nn.init.zeros_(block.weight)
|
235 |
+
nn.init.zeros_(block.bias)
|
236 |
+
controlnet_zero_blocks.append(block)
|
237 |
+
self.controlnet_zero_blocks = nn.ModuleList(controlnet_zero_blocks)
|
238 |
+
|
239 |
+
print('ControlNet ready \n')
|
240 |
+
|
241 |
+
def set_trainable(self):
|
242 |
+
for param in self.parameters():
|
243 |
+
param.requires_grad = False
|
244 |
+
|
245 |
+
# only train input_proj, blocks, and output_proj
|
246 |
+
for module_name in ['controlnet_pre', 'in_blocks', 'controlnet_zero_blocks']:
|
247 |
+
module = getattr(self, module_name, None)
|
248 |
+
if module is not None:
|
249 |
+
for param in module.parameters():
|
250 |
+
param.requires_grad = True
|
251 |
+
module.train()
|
252 |
+
else:
|
253 |
+
print(f'\n!!!warning missing trainable blocks: {module_name}!!!\n')
|
254 |
+
|
255 |
+
def forward(self, x, timesteps, context,
|
256 |
+
x_mask=None, context_mask=None,
|
257 |
+
cls_token=None,
|
258 |
+
condition=None, cond_mask_infer=None,
|
259 |
+
conditioning_scale=1.0):
|
260 |
+
# make it compatible with int time step during inference
|
261 |
+
if timesteps.dim() == 0:
|
262 |
+
timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
|
263 |
+
|
264 |
+
x = self.patch_embed(x)
|
265 |
+
# add condition to x
|
266 |
+
condition = self.controlnet_pre(condition)
|
267 |
+
x = x + condition
|
268 |
+
x = self.x_pe(x)
|
269 |
+
|
270 |
+
B, L, D = x.shape
|
271 |
+
|
272 |
+
if self.use_context:
|
273 |
+
context_token = self.context_embed(context)
|
274 |
+
context_token = self.context_pe(context_token)
|
275 |
+
if self.context_fusion == 'concat' or self.context_fusion == 'joint':
|
276 |
+
x, x_mask = self._concat_x_context(x=x, context=context_token,
|
277 |
+
x_mask=x_mask,
|
278 |
+
context_mask=context_mask)
|
279 |
+
context_token, context_mask = None, None
|
280 |
+
else:
|
281 |
+
context_token, context_mask = None, None
|
282 |
+
|
283 |
+
time_token = self.time_embed(timesteps)
|
284 |
+
if self.cls_embed:
|
285 |
+
cls_token = self.cls_embed(cls_token)
|
286 |
+
time_ada = None
|
287 |
+
if self.use_adanorm:
|
288 |
+
if self.cls_embed:
|
289 |
+
time_token = time_token + cls_token
|
290 |
+
time_token = self.time_act(time_token)
|
291 |
+
if self.time_ada is not None:
|
292 |
+
time_ada = self.time_ada(time_token)
|
293 |
+
else:
|
294 |
+
time_token = time_token.unsqueeze(dim=1)
|
295 |
+
if self.cls_embed:
|
296 |
+
cls_token = cls_token.unsqueeze(dim=1)
|
297 |
+
time_token = torch.cat([time_token, cls_token], dim=1)
|
298 |
+
time_token = self.time_pe(time_token)
|
299 |
+
x = torch.cat((time_token, x), dim=1)
|
300 |
+
if x_mask is not None:
|
301 |
+
x_mask = torch.cat(
|
302 |
+
[torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
|
303 |
+
x_mask], dim=1)
|
304 |
+
time_token = None
|
305 |
+
|
306 |
+
skips = []
|
307 |
+
for blk in self.in_blocks:
|
308 |
+
x = blk(x=x, time_token=time_token, time_ada=time_ada,
|
309 |
+
skip=None, context=context_token,
|
310 |
+
x_mask=x_mask, context_mask=context_mask,
|
311 |
+
extras=self.extras)
|
312 |
+
skips.append(x)
|
313 |
+
|
314 |
+
controlnet_skips = []
|
315 |
+
for skip, controlnet_block in zip(skips, self.controlnet_zero_blocks):
|
316 |
+
controlnet_skips.append(controlnet_block(skip) * conditioning_scale)
|
317 |
+
|
318 |
+
return controlnet_skips
|
src/models/.ipynb_checkpoints/udit-checkpoint.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.utils.checkpoint
|
4 |
+
import math
|
5 |
+
from .utils.modules import PatchEmbed, TimestepEmbedder
|
6 |
+
from .utils.modules import PE_wrapper, RMSNorm
|
7 |
+
from .blocks import DiTBlock, JointDiTBlock, FinalBlock
|
8 |
+
|
9 |
+
|
10 |
+
class UDiT(nn.Module):
|
11 |
+
def __init__(self,
|
12 |
+
img_size=224, patch_size=16, in_chans=3,
|
13 |
+
input_type='2d', out_chans=None,
|
14 |
+
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
|
15 |
+
qkv_bias=False, qk_scale=None, qk_norm=None,
|
16 |
+
act_layer='gelu', norm_layer='layernorm',
|
17 |
+
context_norm=False,
|
18 |
+
use_checkpoint=False,
|
19 |
+
# time fusion ada or token
|
20 |
+
time_fusion='token',
|
21 |
+
ada_lora_rank=None, ada_lora_alpha=None,
|
22 |
+
cls_dim=None,
|
23 |
+
# max length is only used for concat
|
24 |
+
context_dim=768, context_fusion='concat',
|
25 |
+
context_max_length=128, context_pe_method='sinu',
|
26 |
+
pe_method='abs', rope_mode='none',
|
27 |
+
use_conv=True,
|
28 |
+
skip=True, skip_norm=True):
|
29 |
+
super().__init__()
|
30 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
31 |
+
|
32 |
+
# input
|
33 |
+
self.in_chans = in_chans
|
34 |
+
self.input_type = input_type
|
35 |
+
if self.input_type == '2d':
|
36 |
+
num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
37 |
+
elif self.input_type == '1d':
|
38 |
+
num_patches = img_size // patch_size
|
39 |
+
self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
|
40 |
+
embed_dim=embed_dim, input_type=input_type)
|
41 |
+
out_chans = in_chans if out_chans is None else out_chans
|
42 |
+
self.out_chans = out_chans
|
43 |
+
|
44 |
+
# position embedding
|
45 |
+
self.rope = rope_mode
|
46 |
+
self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
|
47 |
+
length=num_patches)
|
48 |
+
|
49 |
+
print(f'x position embedding: {pe_method}')
|
50 |
+
print(f'rope mode: {self.rope}')
|
51 |
+
|
52 |
+
# time embed
|
53 |
+
self.time_embed = TimestepEmbedder(embed_dim)
|
54 |
+
self.time_fusion = time_fusion
|
55 |
+
self.use_adanorm = False
|
56 |
+
|
57 |
+
# cls embed
|
58 |
+
if cls_dim is not None:
|
59 |
+
self.cls_embed = nn.Sequential(
|
60 |
+
nn.Linear(cls_dim, embed_dim, bias=True),
|
61 |
+
nn.SiLU(),
|
62 |
+
nn.Linear(embed_dim, embed_dim, bias=True),)
|
63 |
+
else:
|
64 |
+
self.cls_embed = None
|
65 |
+
|
66 |
+
# time fusion
|
67 |
+
if time_fusion == 'token':
|
68 |
+
# put token at the beginning of sequence
|
69 |
+
self.extras = 2 if self.cls_embed else 1
|
70 |
+
self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
|
71 |
+
elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
|
72 |
+
self.use_adanorm = True
|
73 |
+
# aviod repetitive silu for each adaln block
|
74 |
+
self.time_act = nn.SiLU()
|
75 |
+
self.extras = 0
|
76 |
+
self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
|
77 |
+
if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
|
78 |
+
# shared adaln
|
79 |
+
self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
|
80 |
+
else:
|
81 |
+
self.time_ada = None
|
82 |
+
else:
|
83 |
+
raise NotImplementedError
|
84 |
+
print(f'time fusion mode: {self.time_fusion}')
|
85 |
+
|
86 |
+
# context
|
87 |
+
# use a simple projection
|
88 |
+
self.use_context = False
|
89 |
+
self.context_cross = False
|
90 |
+
self.context_max_length = context_max_length
|
91 |
+
self.context_fusion = 'none'
|
92 |
+
if context_dim is not None:
|
93 |
+
self.use_context = True
|
94 |
+
self.context_embed = nn.Sequential(
|
95 |
+
nn.Linear(context_dim, embed_dim, bias=True),
|
96 |
+
nn.SiLU(),
|
97 |
+
nn.Linear(embed_dim, embed_dim, bias=True),)
|
98 |
+
self.context_fusion = context_fusion
|
99 |
+
if context_fusion == 'concat' or context_fusion == 'joint':
|
100 |
+
self.extras += context_max_length
|
101 |
+
self.context_pe = PE_wrapper(dim=embed_dim,
|
102 |
+
method=context_pe_method,
|
103 |
+
length=context_max_length)
|
104 |
+
# no cross attention layers
|
105 |
+
context_dim = None
|
106 |
+
elif context_fusion == 'cross':
|
107 |
+
self.context_pe = PE_wrapper(dim=embed_dim,
|
108 |
+
method=context_pe_method,
|
109 |
+
length=context_max_length)
|
110 |
+
self.context_cross = True
|
111 |
+
context_dim = embed_dim
|
112 |
+
else:
|
113 |
+
raise NotImplementedError
|
114 |
+
print(f'context fusion mode: {context_fusion}')
|
115 |
+
print(f'context position embedding: {context_pe_method}')
|
116 |
+
|
117 |
+
if self.context_fusion == 'joint':
|
118 |
+
Block = JointDiTBlock
|
119 |
+
self.use_skip = skip[0]
|
120 |
+
else:
|
121 |
+
Block = DiTBlock
|
122 |
+
self.use_skip = skip
|
123 |
+
|
124 |
+
# norm layers
|
125 |
+
if norm_layer == 'layernorm':
|
126 |
+
norm_layer = nn.LayerNorm
|
127 |
+
elif norm_layer == 'rmsnorm':
|
128 |
+
norm_layer = RMSNorm
|
129 |
+
else:
|
130 |
+
raise NotImplementedError
|
131 |
+
|
132 |
+
print(f'use long skip connection: {skip}')
|
133 |
+
self.in_blocks = nn.ModuleList([
|
134 |
+
Block(
|
135 |
+
dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
|
136 |
+
mlp_ratio=mlp_ratio,
|
137 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
|
138 |
+
act_layer=act_layer, norm_layer=norm_layer,
|
139 |
+
time_fusion=time_fusion,
|
140 |
+
ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
|
141 |
+
skip=False, skip_norm=False,
|
142 |
+
rope_mode=self.rope,
|
143 |
+
context_norm=context_norm,
|
144 |
+
use_checkpoint=use_checkpoint)
|
145 |
+
for _ in range(depth // 2)])
|
146 |
+
|
147 |
+
self.mid_block = Block(
|
148 |
+
dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
|
149 |
+
mlp_ratio=mlp_ratio,
|
150 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
|
151 |
+
act_layer=act_layer, norm_layer=norm_layer,
|
152 |
+
time_fusion=time_fusion,
|
153 |
+
ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
|
154 |
+
skip=False, skip_norm=False,
|
155 |
+
rope_mode=self.rope,
|
156 |
+
context_norm=context_norm,
|
157 |
+
use_checkpoint=use_checkpoint)
|
158 |
+
|
159 |
+
self.out_blocks = nn.ModuleList([
|
160 |
+
Block(
|
161 |
+
dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
|
162 |
+
mlp_ratio=mlp_ratio,
|
163 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
|
164 |
+
act_layer=act_layer, norm_layer=norm_layer,
|
165 |
+
time_fusion=time_fusion,
|
166 |
+
ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
|
167 |
+
skip=skip, skip_norm=skip_norm,
|
168 |
+
rope_mode=self.rope,
|
169 |
+
context_norm=context_norm,
|
170 |
+
use_checkpoint=use_checkpoint)
|
171 |
+
for _ in range(depth // 2)])
|
172 |
+
|
173 |
+
# FinalLayer block
|
174 |
+
self.use_conv = use_conv
|
175 |
+
self.final_block = FinalBlock(embed_dim=embed_dim,
|
176 |
+
patch_size=patch_size,
|
177 |
+
img_size=img_size,
|
178 |
+
in_chans=out_chans,
|
179 |
+
input_type=input_type,
|
180 |
+
norm_layer=norm_layer,
|
181 |
+
use_conv=use_conv,
|
182 |
+
use_adanorm=self.use_adanorm)
|
183 |
+
self.initialize_weights()
|
184 |
+
|
185 |
+
def _init_ada(self):
|
186 |
+
if self.time_fusion == 'ada':
|
187 |
+
nn.init.constant_(self.time_ada_final.weight, 0)
|
188 |
+
nn.init.constant_(self.time_ada_final.bias, 0)
|
189 |
+
for block in self.in_blocks:
|
190 |
+
nn.init.constant_(block.adaln.time_ada.weight, 0)
|
191 |
+
nn.init.constant_(block.adaln.time_ada.bias, 0)
|
192 |
+
nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
|
193 |
+
nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
|
194 |
+
for block in self.out_blocks:
|
195 |
+
nn.init.constant_(block.adaln.time_ada.weight, 0)
|
196 |
+
nn.init.constant_(block.adaln.time_ada.bias, 0)
|
197 |
+
elif self.time_fusion == 'ada_single':
|
198 |
+
nn.init.constant_(self.time_ada.weight, 0)
|
199 |
+
nn.init.constant_(self.time_ada.bias, 0)
|
200 |
+
nn.init.constant_(self.time_ada_final.weight, 0)
|
201 |
+
nn.init.constant_(self.time_ada_final.bias, 0)
|
202 |
+
elif self.time_fusion in ['ada_lora', 'ada_lora_bias']:
|
203 |
+
nn.init.constant_(self.time_ada.weight, 0)
|
204 |
+
nn.init.constant_(self.time_ada.bias, 0)
|
205 |
+
nn.init.constant_(self.time_ada_final.weight, 0)
|
206 |
+
nn.init.constant_(self.time_ada_final.bias, 0)
|
207 |
+
for block in self.in_blocks:
|
208 |
+
nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
|
209 |
+
a=math.sqrt(5))
|
210 |
+
nn.init.constant_(block.adaln.lora_b.weight, 0)
|
211 |
+
nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight,
|
212 |
+
a=math.sqrt(5))
|
213 |
+
nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
|
214 |
+
for block in self.out_blocks:
|
215 |
+
nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
|
216 |
+
a=math.sqrt(5))
|
217 |
+
nn.init.constant_(block.adaln.lora_b.weight, 0)
|
218 |
+
|
219 |
+
def initialize_weights(self):
|
220 |
+
# Basic init for all layers
|
221 |
+
def _basic_init(module):
|
222 |
+
if isinstance(module, nn.Linear):
|
223 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
224 |
+
if module.bias is not None:
|
225 |
+
nn.init.constant_(module.bias, 0)
|
226 |
+
self.apply(_basic_init)
|
227 |
+
|
228 |
+
# init patch Conv like Linear
|
229 |
+
w = self.patch_embed.proj.weight.data
|
230 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
231 |
+
nn.init.constant_(self.patch_embed.proj.bias, 0)
|
232 |
+
|
233 |
+
# Zero-out AdaLN
|
234 |
+
if self.use_adanorm:
|
235 |
+
self._init_ada()
|
236 |
+
|
237 |
+
# Zero-out Cross Attention
|
238 |
+
if self.context_cross:
|
239 |
+
for block in self.in_blocks:
|
240 |
+
nn.init.constant_(block.cross_attn.proj.weight, 0)
|
241 |
+
nn.init.constant_(block.cross_attn.proj.bias, 0)
|
242 |
+
nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
|
243 |
+
nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
|
244 |
+
for block in self.out_blocks:
|
245 |
+
nn.init.constant_(block.cross_attn.proj.weight, 0)
|
246 |
+
nn.init.constant_(block.cross_attn.proj.bias, 0)
|
247 |
+
|
248 |
+
# Zero-out cls embedding
|
249 |
+
if self.cls_embed:
|
250 |
+
if self.use_adanorm:
|
251 |
+
nn.init.constant_(self.cls_embed[-1].weight, 0)
|
252 |
+
nn.init.constant_(self.cls_embed[-1].bias, 0)
|
253 |
+
|
254 |
+
# Zero-out Output
|
255 |
+
# might not zero-out this when using v-prediction
|
256 |
+
# it could be good when using noise-prediction
|
257 |
+
# nn.init.constant_(self.final_block.linear.weight, 0)
|
258 |
+
# nn.init.constant_(self.final_block.linear.bias, 0)
|
259 |
+
# if self.use_conv:
|
260 |
+
# nn.init.constant_(self.final_block.final_layer.weight.data, 0)
|
261 |
+
# nn.init.constant_(self.final_block.final_layer.bias, 0)
|
262 |
+
|
263 |
+
# init out Conv
|
264 |
+
if self.use_conv:
|
265 |
+
nn.init.xavier_uniform_(self.final_block.final_layer.weight)
|
266 |
+
nn.init.constant_(self.final_block.final_layer.bias, 0)
|
267 |
+
|
268 |
+
def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
|
269 |
+
assert context.shape[-2] == self.context_max_length
|
270 |
+
# Check if either x_mask or context_mask is provided
|
271 |
+
B = x.shape[0]
|
272 |
+
# Create default masks if they are not provided
|
273 |
+
if x_mask is None:
|
274 |
+
x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
|
275 |
+
if context_mask is None:
|
276 |
+
context_mask = torch.ones(B, context.shape[-2],
|
277 |
+
device=context.device).bool()
|
278 |
+
# Concatenate the masks along the second dimension (dim=1)
|
279 |
+
x_mask = torch.cat([context_mask, x_mask], dim=1)
|
280 |
+
# Concatenate context and x along the second dimension (dim=1)
|
281 |
+
x = torch.cat((context, x), dim=1)
|
282 |
+
return x, x_mask
|
283 |
+
|
284 |
+
def forward(self, x, timesteps, context,
|
285 |
+
x_mask=None, context_mask=None,
|
286 |
+
cls_token=None, controlnet_skips=None,
|
287 |
+
):
|
288 |
+
# make it compatible with int time step during inference
|
289 |
+
if timesteps.dim() == 0:
|
290 |
+
timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
|
291 |
+
|
292 |
+
x = self.patch_embed(x)
|
293 |
+
x = self.x_pe(x)
|
294 |
+
|
295 |
+
B, L, D = x.shape
|
296 |
+
|
297 |
+
if self.use_context:
|
298 |
+
context_token = self.context_embed(context)
|
299 |
+
context_token = self.context_pe(context_token)
|
300 |
+
if self.context_fusion == 'concat' or self.context_fusion == 'joint':
|
301 |
+
x, x_mask = self._concat_x_context(x=x, context=context_token,
|
302 |
+
x_mask=x_mask,
|
303 |
+
context_mask=context_mask)
|
304 |
+
context_token, context_mask = None, None
|
305 |
+
else:
|
306 |
+
context_token, context_mask = None, None
|
307 |
+
|
308 |
+
time_token = self.time_embed(timesteps)
|
309 |
+
if self.cls_embed:
|
310 |
+
cls_token = self.cls_embed(cls_token)
|
311 |
+
time_ada = None
|
312 |
+
time_ada_final = None
|
313 |
+
if self.use_adanorm:
|
314 |
+
if self.cls_embed:
|
315 |
+
time_token = time_token + cls_token
|
316 |
+
time_token = self.time_act(time_token)
|
317 |
+
time_ada_final = self.time_ada_final(time_token)
|
318 |
+
if self.time_ada is not None:
|
319 |
+
time_ada = self.time_ada(time_token)
|
320 |
+
else:
|
321 |
+
time_token = time_token.unsqueeze(dim=1)
|
322 |
+
if self.cls_embed:
|
323 |
+
cls_token = cls_token.unsqueeze(dim=1)
|
324 |
+
time_token = torch.cat([time_token, cls_token], dim=1)
|
325 |
+
time_token = self.time_pe(time_token)
|
326 |
+
x = torch.cat((time_token, x), dim=1)
|
327 |
+
if x_mask is not None:
|
328 |
+
x_mask = torch.cat(
|
329 |
+
[torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
|
330 |
+
x_mask], dim=1)
|
331 |
+
time_token = None
|
332 |
+
|
333 |
+
skips = []
|
334 |
+
for blk in self.in_blocks:
|
335 |
+
x = blk(x=x, time_token=time_token, time_ada=time_ada,
|
336 |
+
skip=None, context=context_token,
|
337 |
+
x_mask=x_mask, context_mask=context_mask,
|
338 |
+
extras=self.extras)
|
339 |
+
if self.use_skip:
|
340 |
+
skips.append(x)
|
341 |
+
|
342 |
+
x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada,
|
343 |
+
skip=None, context=context_token,
|
344 |
+
x_mask=x_mask, context_mask=context_mask,
|
345 |
+
extras=self.extras)
|
346 |
+
for blk in self.out_blocks:
|
347 |
+
if self.use_skip:
|
348 |
+
skip = skips.pop()
|
349 |
+
if controlnet_skips:
|
350 |
+
# add to skip like u-net controlnet
|
351 |
+
skip = skip + controlnet_skips.pop()
|
352 |
+
else:
|
353 |
+
skip = None
|
354 |
+
if controlnet_skips:
|
355 |
+
# directly add to x
|
356 |
+
x = x + controlnet_skips.pop()
|
357 |
+
|
358 |
+
x = blk(x=x, time_token=time_token, time_ada=time_ada,
|
359 |
+
skip=skip, context=context_token,
|
360 |
+
x_mask=x_mask, context_mask=context_mask,
|
361 |
+
extras=self.extras)
|
362 |
+
|
363 |
+
x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
|
364 |
+
|
365 |
+
return x
|
src/models/__pycache__/attention.cpython-311.pyc
ADDED
Binary file (6.05 kB). View file
|
|
src/models/__pycache__/blocks.cpython-310.pyc
ADDED
Binary file (7.32 kB). View file
|
|
src/models/__pycache__/blocks.cpython-311.pyc
ADDED
Binary file (14.9 kB). View file
|
|
src/models/__pycache__/conditioners.cpython-310.pyc
ADDED
Binary file (5.63 kB). View file
|
|
src/models/__pycache__/conditioners.cpython-311.pyc
ADDED
Binary file (10.3 kB). View file
|
|
src/models/__pycache__/controlnet.cpython-311.pyc
ADDED
Binary file (15.2 kB). View file
|
|
src/models/__pycache__/modules.cpython-311.pyc
ADDED
Binary file (11.3 kB). View file
|
|
src/models/__pycache__/rotary.cpython-311.pyc
ADDED
Binary file (4.83 kB). View file
|
|
src/models/__pycache__/timm.cpython-311.pyc
ADDED
Binary file (6.45 kB). View file
|
|
src/models/__pycache__/udit.cpython-310.pyc
ADDED
Binary file (7.9 kB). View file
|
|
src/models/__pycache__/udit.cpython-311.pyc
ADDED
Binary file (18.5 kB). View file
|
|
src/models/blocks.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.utils.checkpoint import checkpoint
|
4 |
+
from .utils.attention import Attention, JointAttention
|
5 |
+
from .utils.modules import unpatchify, FeedForward
|
6 |
+
from .utils.modules import film_modulate
|
7 |
+
|
8 |
+
|
9 |
+
class AdaLN(nn.Module):
|
10 |
+
def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
|
11 |
+
super().__init__()
|
12 |
+
self.ada_mode = ada_mode
|
13 |
+
self.scale_shift_table = None
|
14 |
+
if ada_mode == 'ada':
|
15 |
+
# move nn.silu outside
|
16 |
+
self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
|
17 |
+
elif ada_mode == 'ada_single':
|
18 |
+
# adaln used in pixel-art alpha
|
19 |
+
self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
|
20 |
+
elif ada_mode in ['ada_lora', 'ada_lora_bias']:
|
21 |
+
self.lora_a = nn.Linear(dim, r * 6, bias=False)
|
22 |
+
self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
|
23 |
+
self.scaling = alpha / r
|
24 |
+
if ada_mode == 'ada_lora_bias':
|
25 |
+
# take bias out for consistency
|
26 |
+
self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
|
27 |
+
else:
|
28 |
+
raise NotImplementedError
|
29 |
+
|
30 |
+
def forward(self, time_token=None, time_ada=None):
|
31 |
+
if self.ada_mode == 'ada':
|
32 |
+
assert time_ada is None
|
33 |
+
B = time_token.shape[0]
|
34 |
+
time_ada = self.time_ada(time_token).reshape(B, 6, -1)
|
35 |
+
elif self.ada_mode == 'ada_single':
|
36 |
+
B = time_ada.shape[0]
|
37 |
+
time_ada = time_ada.reshape(B, 6, -1)
|
38 |
+
time_ada = self.scale_shift_table[None] + time_ada
|
39 |
+
elif self.ada_mode in ['ada_lora', 'ada_lora_bias']:
|
40 |
+
B = time_ada.shape[0]
|
41 |
+
time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
|
42 |
+
time_ada = time_ada + time_ada_lora
|
43 |
+
time_ada = time_ada.reshape(B, 6, -1)
|
44 |
+
if self.scale_shift_table is not None:
|
45 |
+
time_ada = self.scale_shift_table[None] + time_ada
|
46 |
+
else:
|
47 |
+
raise NotImplementedError
|
48 |
+
return time_ada
|
49 |
+
|
50 |
+
|
51 |
+
class DiTBlock(nn.Module):
|
52 |
+
"""
|
53 |
+
A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def __init__(self, dim, context_dim=None,
|
57 |
+
num_heads=8, mlp_ratio=4.,
|
58 |
+
qkv_bias=False, qk_scale=None, qk_norm=None,
|
59 |
+
act_layer='gelu', norm_layer=nn.LayerNorm,
|
60 |
+
time_fusion='none',
|
61 |
+
ada_lora_rank=None, ada_lora_alpha=None,
|
62 |
+
skip=False, skip_norm=False,
|
63 |
+
rope_mode='none',
|
64 |
+
context_norm=False,
|
65 |
+
use_checkpoint=False):
|
66 |
+
|
67 |
+
super().__init__()
|
68 |
+
self.norm1 = norm_layer(dim)
|
69 |
+
self.attn = Attention(dim=dim,
|
70 |
+
num_heads=num_heads,
|
71 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
72 |
+
qk_norm=qk_norm,
|
73 |
+
rope_mode=rope_mode)
|
74 |
+
|
75 |
+
if context_dim is not None:
|
76 |
+
self.use_context = True
|
77 |
+
self.cross_attn = Attention(dim=dim,
|
78 |
+
num_heads=num_heads,
|
79 |
+
context_dim=context_dim,
|
80 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
81 |
+
qk_norm=qk_norm,
|
82 |
+
rope_mode='none')
|
83 |
+
self.norm2 = norm_layer(dim)
|
84 |
+
if context_norm:
|
85 |
+
self.norm_context = norm_layer(context_dim)
|
86 |
+
else:
|
87 |
+
self.norm_context = nn.Identity()
|
88 |
+
else:
|
89 |
+
self.use_context = False
|
90 |
+
|
91 |
+
self.norm3 = norm_layer(dim)
|
92 |
+
self.mlp = FeedForward(dim=dim, mult=mlp_ratio,
|
93 |
+
activation_fn=act_layer, dropout=0)
|
94 |
+
|
95 |
+
self.use_adanorm = True if time_fusion != 'token' else False
|
96 |
+
if self.use_adanorm:
|
97 |
+
self.adaln = AdaLN(dim, ada_mode=time_fusion,
|
98 |
+
r=ada_lora_rank, alpha=ada_lora_alpha)
|
99 |
+
if skip:
|
100 |
+
self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
|
101 |
+
self.skip_linear = nn.Linear(2 * dim, dim)
|
102 |
+
else:
|
103 |
+
self.skip_linear = None
|
104 |
+
|
105 |
+
self.use_checkpoint = use_checkpoint
|
106 |
+
|
107 |
+
def forward(self, x, time_token=None, time_ada=None,
|
108 |
+
skip=None, context=None,
|
109 |
+
x_mask=None, context_mask=None, extras=None):
|
110 |
+
if self.use_checkpoint:
|
111 |
+
return checkpoint(self._forward, x,
|
112 |
+
time_token, time_ada, skip, context,
|
113 |
+
x_mask, context_mask, extras,
|
114 |
+
use_reentrant=False)
|
115 |
+
else:
|
116 |
+
return self._forward(x,
|
117 |
+
time_token, time_ada, skip, context,
|
118 |
+
x_mask, context_mask, extras)
|
119 |
+
|
120 |
+
def _forward(self, x, time_token=None, time_ada=None,
|
121 |
+
skip=None, context=None,
|
122 |
+
x_mask=None, context_mask=None, extras=None):
|
123 |
+
B, T, C = x.shape
|
124 |
+
if self.skip_linear is not None:
|
125 |
+
assert skip is not None
|
126 |
+
cat = torch.cat([x, skip], dim=-1)
|
127 |
+
cat = self.skip_norm(cat)
|
128 |
+
x = self.skip_linear(cat)
|
129 |
+
|
130 |
+
if self.use_adanorm:
|
131 |
+
time_ada = self.adaln(time_token, time_ada)
|
132 |
+
(shift_msa, scale_msa, gate_msa,
|
133 |
+
shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
|
134 |
+
|
135 |
+
# self attention
|
136 |
+
if self.use_adanorm:
|
137 |
+
x_norm = film_modulate(self.norm1(x), shift=shift_msa,
|
138 |
+
scale=scale_msa)
|
139 |
+
x = x + (1 - gate_msa) * self.attn(x_norm, context=None,
|
140 |
+
context_mask=x_mask,
|
141 |
+
extras=extras)
|
142 |
+
else:
|
143 |
+
x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask,
|
144 |
+
extras=extras)
|
145 |
+
|
146 |
+
# cross attention
|
147 |
+
if self.use_context:
|
148 |
+
assert context is not None
|
149 |
+
x = x + self.cross_attn(x=self.norm2(x),
|
150 |
+
context=self.norm_context(context),
|
151 |
+
context_mask=context_mask, extras=extras)
|
152 |
+
|
153 |
+
# mlp
|
154 |
+
if self.use_adanorm:
|
155 |
+
x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp)
|
156 |
+
x = x + (1 - gate_mlp) * self.mlp(x_norm)
|
157 |
+
else:
|
158 |
+
x = x + self.mlp(self.norm3(x))
|
159 |
+
|
160 |
+
return x
|
161 |
+
|
162 |
+
|
163 |
+
class JointDiTBlock(nn.Module):
|
164 |
+
"""
|
165 |
+
A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
|
166 |
+
"""
|
167 |
+
|
168 |
+
def __init__(self, dim, context_dim=None,
|
169 |
+
num_heads=8, mlp_ratio=4.,
|
170 |
+
qkv_bias=False, qk_scale=None, qk_norm=None,
|
171 |
+
act_layer='gelu', norm_layer=nn.LayerNorm,
|
172 |
+
time_fusion='none',
|
173 |
+
ada_lora_rank=None, ada_lora_alpha=None,
|
174 |
+
skip=(False, False),
|
175 |
+
rope_mode=False,
|
176 |
+
context_norm=False,
|
177 |
+
use_checkpoint=False,):
|
178 |
+
|
179 |
+
super().__init__()
|
180 |
+
# no cross attention
|
181 |
+
assert context_dim is None
|
182 |
+
self.attn_norm_x = norm_layer(dim)
|
183 |
+
self.attn_norm_c = norm_layer(dim)
|
184 |
+
self.attn = JointAttention(dim=dim,
|
185 |
+
num_heads=num_heads,
|
186 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
187 |
+
qk_norm=qk_norm,
|
188 |
+
rope_mode=rope_mode)
|
189 |
+
self.ffn_norm_x = norm_layer(dim)
|
190 |
+
self.ffn_norm_c = norm_layer(dim)
|
191 |
+
self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio,
|
192 |
+
activation_fn=act_layer, dropout=0)
|
193 |
+
self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio,
|
194 |
+
activation_fn=act_layer, dropout=0)
|
195 |
+
|
196 |
+
# Zero-out the shift table
|
197 |
+
self.use_adanorm = True if time_fusion != 'token' else False
|
198 |
+
if self.use_adanorm:
|
199 |
+
self.adaln = AdaLN(dim, ada_mode=time_fusion,
|
200 |
+
r=ada_lora_rank, alpha=ada_lora_alpha)
|
201 |
+
|
202 |
+
if skip is False:
|
203 |
+
skip_x, skip_c = False, False
|
204 |
+
else:
|
205 |
+
skip_x, skip_c = skip
|
206 |
+
|
207 |
+
self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None
|
208 |
+
self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None
|
209 |
+
|
210 |
+
self.use_checkpoint = use_checkpoint
|
211 |
+
|
212 |
+
def forward(self, x, time_token=None, time_ada=None,
|
213 |
+
skip=None, context=None,
|
214 |
+
x_mask=None, context_mask=None, extras=None):
|
215 |
+
if self.use_checkpoint:
|
216 |
+
return checkpoint(self._forward, x,
|
217 |
+
time_token, time_ada, skip,
|
218 |
+
context, x_mask, context_mask, extras,
|
219 |
+
use_reentrant=False)
|
220 |
+
else:
|
221 |
+
return self._forward(x,
|
222 |
+
time_token, time_ada, skip,
|
223 |
+
context, x_mask, context_mask, extras)
|
224 |
+
|
225 |
+
def _forward(self, x, time_token=None, time_ada=None,
|
226 |
+
skip=None, context=None,
|
227 |
+
x_mask=None, context_mask=None, extras=None):
|
228 |
+
|
229 |
+
assert context is None and context_mask is None
|
230 |
+
|
231 |
+
context, x = x[:, :extras, :], x[:, extras:, :]
|
232 |
+
context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:]
|
233 |
+
|
234 |
+
if skip is not None:
|
235 |
+
skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :]
|
236 |
+
|
237 |
+
B, T, C = x.shape
|
238 |
+
if self.skip_linear_x is not None:
|
239 |
+
x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1))
|
240 |
+
|
241 |
+
if self.skip_linear_c is not None:
|
242 |
+
context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1))
|
243 |
+
|
244 |
+
if self.use_adanorm:
|
245 |
+
time_ada = self.adaln(time_token, time_ada)
|
246 |
+
(shift_msa, scale_msa, gate_msa,
|
247 |
+
shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
|
248 |
+
|
249 |
+
# self attention
|
250 |
+
x_norm = self.attn_norm_x(x)
|
251 |
+
c_norm = self.attn_norm_c(context)
|
252 |
+
if self.use_adanorm:
|
253 |
+
x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa)
|
254 |
+
x_out, c_out = self.attn(x_norm, context=c_norm,
|
255 |
+
x_mask=x_mask, context_mask=context_mask,
|
256 |
+
extras=extras)
|
257 |
+
if self.use_adanorm:
|
258 |
+
x = x + (1 - gate_msa) * x_out
|
259 |
+
else:
|
260 |
+
x = x + x_out
|
261 |
+
context = context + c_out
|
262 |
+
|
263 |
+
# mlp
|
264 |
+
if self.use_adanorm:
|
265 |
+
x_norm = film_modulate(self.ffn_norm_x(x),
|
266 |
+
shift=shift_mlp, scale=scale_mlp)
|
267 |
+
x = x + (1 - gate_mlp) * self.mlp_x(x_norm)
|
268 |
+
else:
|
269 |
+
x = x + self.mlp_x(self.ffn_norm_x(x))
|
270 |
+
|
271 |
+
c_norm = self.ffn_norm_c(context)
|
272 |
+
context = context + self.mlp_c(c_norm)
|
273 |
+
|
274 |
+
return torch.cat((context, x), dim=1)
|
275 |
+
|
276 |
+
|
277 |
+
class FinalBlock(nn.Module):
|
278 |
+
def __init__(self, embed_dim, patch_size, in_chans,
|
279 |
+
img_size,
|
280 |
+
input_type='2d',
|
281 |
+
norm_layer=nn.LayerNorm,
|
282 |
+
use_conv=True,
|
283 |
+
use_adanorm=True):
|
284 |
+
super().__init__()
|
285 |
+
self.in_chans = in_chans
|
286 |
+
self.img_size = img_size
|
287 |
+
self.input_type = input_type
|
288 |
+
|
289 |
+
self.norm = norm_layer(embed_dim)
|
290 |
+
if use_adanorm:
|
291 |
+
self.use_adanorm = True
|
292 |
+
else:
|
293 |
+
self.use_adanorm = False
|
294 |
+
|
295 |
+
if input_type == '2d':
|
296 |
+
self.patch_dim = patch_size ** 2 * in_chans
|
297 |
+
self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
|
298 |
+
if use_conv:
|
299 |
+
self.final_layer = nn.Conv2d(self.in_chans, self.in_chans,
|
300 |
+
3, padding=1)
|
301 |
+
else:
|
302 |
+
self.final_layer = nn.Identity()
|
303 |
+
|
304 |
+
elif input_type == '1d':
|
305 |
+
self.patch_dim = patch_size * in_chans
|
306 |
+
self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
|
307 |
+
if use_conv:
|
308 |
+
self.final_layer = nn.Conv1d(self.in_chans, self.in_chans,
|
309 |
+
3, padding=1)
|
310 |
+
else:
|
311 |
+
self.final_layer = nn.Identity()
|
312 |
+
|
313 |
+
def forward(self, x, time_ada=None, extras=0):
|
314 |
+
B, T, C = x.shape
|
315 |
+
x = x[:, extras:, :]
|
316 |
+
# only handle generation target
|
317 |
+
if self.use_adanorm:
|
318 |
+
shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
|
319 |
+
x = film_modulate(self.norm(x), shift, scale)
|
320 |
+
else:
|
321 |
+
x = self.norm(x)
|
322 |
+
x = self.linear(x)
|
323 |
+
x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
|
324 |
+
x = self.final_layer(x)
|
325 |
+
return x
|