Spaces:
Sleeping
Sleeping
hugo flores garcia
commited on
Commit
•
41b9d24
0
Parent(s):
recovering from a gittastrophe
Browse files- .gitattributes +34 -0
- .gitignore +194 -0
- LICENSE +21 -0
- README.md +113 -0
- app.py +428 -0
- assets/example.wav +0 -0
- conf/c2f.yml +14 -0
- conf/interface.yml +10 -0
- conf/lora/lora.yml +22 -0
- conf/salad_bowl.yml +0 -0
- conf/vampnet.yml +49 -0
- hello.py +43 -0
- requirements.txt +10 -0
- scripts/exp/eval.py +110 -0
- scripts/exp/experiment.py +254 -0
- scripts/exp/export.py +22 -0
- scripts/exp/fine_tune.py +81 -0
- scripts/exp/train.py +686 -0
- scripts/utils/README.md +28 -0
- scripts/utils/gtzan_embeddings.py +264 -0
- scripts/utils/plots.py +43 -0
- scripts/utils/remove_quiet_files.py +29 -0
- scripts/utils/split.py +66 -0
- scripts/utils/split_long_audio_file.py +34 -0
- scripts/utils/stage.py +30 -0
- scripts/utils/visualize_embeddings.py +265 -0
- scripts/utils/xeno-canto-dl.py +234 -0
- setup.py +40 -0
- token_telephone/tt.py +616 -0
- token_telephone/ttutil.py +65 -0
- token_telephone/vamp_helper.py +172 -0
- vampnet/__init__.py +90 -0
- vampnet/beats.py +250 -0
- vampnet/interface.py +623 -0
- vampnet/mask.py +226 -0
- vampnet/modules/__init__.py +6 -0
- vampnet/modules/activations.py +55 -0
- vampnet/modules/layers.py +164 -0
- vampnet/modules/transformer.py +965 -0
- vampnet/scheduler.py +47 -0
- vampnet/util.py +46 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
25 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/env.sh
|
108 |
+
venv/
|
109 |
+
env.bak/
|
110 |
+
venv.bak/
|
111 |
+
|
112 |
+
# Spyder project settings
|
113 |
+
.spyderproject
|
114 |
+
.spyproject
|
115 |
+
|
116 |
+
# Rope project settings
|
117 |
+
.ropeproject
|
118 |
+
|
119 |
+
# mkdocs documentation
|
120 |
+
/site
|
121 |
+
|
122 |
+
# mypy
|
123 |
+
.mypy_cache/
|
124 |
+
.dmypy.json
|
125 |
+
dmypy.json
|
126 |
+
|
127 |
+
# Pyre type checker
|
128 |
+
.pyre/
|
129 |
+
|
130 |
+
# Files created by experiments
|
131 |
+
output/
|
132 |
+
snapshot/
|
133 |
+
*.m4a
|
134 |
+
notebooks/scratch.ipynb
|
135 |
+
notebooks/inspect.ipynb
|
136 |
+
notebooks/effects.ipynb
|
137 |
+
notebooks/*.ipynb
|
138 |
+
notebooks/*.gif
|
139 |
+
notebooks/*.wav
|
140 |
+
notebooks/*.mp4
|
141 |
+
*runs/
|
142 |
+
boards/
|
143 |
+
samples/
|
144 |
+
*.ipynb
|
145 |
+
|
146 |
+
results.json
|
147 |
+
metrics.csv
|
148 |
+
mprofile_*
|
149 |
+
mem.png
|
150 |
+
|
151 |
+
results/
|
152 |
+
mprofile*
|
153 |
+
*.png
|
154 |
+
# do not ignore the test wav file
|
155 |
+
!tests/audio/short_test_audio.wav
|
156 |
+
!tests/audio/output.wav
|
157 |
+
*/.DS_Store
|
158 |
+
.DS_Store
|
159 |
+
env.sh
|
160 |
+
_codebraid/
|
161 |
+
**/*.html
|
162 |
+
**/*.exec.md
|
163 |
+
flagged/
|
164 |
+
log.txt
|
165 |
+
ckpt/
|
166 |
+
.syncthing*
|
167 |
+
tests/assets/
|
168 |
+
archived/
|
169 |
+
|
170 |
+
scratch/
|
171 |
+
|
172 |
+
runs-archive
|
173 |
+
lyrebird-audiotools
|
174 |
+
lyrebird-audio-codec
|
175 |
+
samples-*/**
|
176 |
+
|
177 |
+
gradio-outputs/
|
178 |
+
samples*/
|
179 |
+
models-all/
|
180 |
+
models.zip
|
181 |
+
.git-old
|
182 |
+
|
183 |
+
|
184 |
+
|
185 |
+
gtzan.zip
|
186 |
+
.gtzan_emb_cache
|
187 |
+
|
188 |
+
|
189 |
+
data/
|
190 |
+
data
|
191 |
+
pyharp
|
192 |
+
|
193 |
+
models/vampnet/*
|
194 |
+
models/*
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Hugo Flores García and Prem Seetharaman
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: salad bowl (vampnet)
|
3 |
+
emoji: 🥗
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.37.2
|
8 |
+
python_version: 3.9.17
|
9 |
+
app_file: app.py
|
10 |
+
pinned: false
|
11 |
+
license: cc-by-nc-4.0
|
12 |
+
---
|
13 |
+
|
14 |
+
# VampNet
|
15 |
+
|
16 |
+
This repository contains recipes for training generative music models on top of the Descript Audio Codec.
|
17 |
+
|
18 |
+
# Setting up
|
19 |
+
|
20 |
+
**Requires Python 3.9**.
|
21 |
+
|
22 |
+
you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
|
23 |
+
|
24 |
+
(for example, using conda)
|
25 |
+
```bash
|
26 |
+
conda create -n vampnet python=3.9
|
27 |
+
conda activate vampnet
|
28 |
+
```
|
29 |
+
|
30 |
+
install VampNet
|
31 |
+
|
32 |
+
```bash
|
33 |
+
git clone https://github.com/hugofloresgarcia/vampnet.git
|
34 |
+
pip install -e ./vampnet
|
35 |
+
```
|
36 |
+
|
37 |
+
# Usage
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
## Launching the Gradio Interface
|
42 |
+
You can launch a gradio UI to play with vampnet.
|
43 |
+
|
44 |
+
```bash
|
45 |
+
python app.py --args.load conf/interface.yml --Interface.device cuda
|
46 |
+
```
|
47 |
+
|
48 |
+
# Training / Fine-tuning
|
49 |
+
|
50 |
+
## Training a model
|
51 |
+
|
52 |
+
To train a model, run the following script:
|
53 |
+
|
54 |
+
```bash
|
55 |
+
python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints
|
56 |
+
```
|
57 |
+
|
58 |
+
for multi-gpu training, use torchrun:
|
59 |
+
|
60 |
+
```bash
|
61 |
+
torchrun --nproc_per_node gpu scripts/exp/train.py --args.load conf/vampnet.yml --save_path path/to/ckpt
|
62 |
+
```
|
63 |
+
|
64 |
+
You can edit `conf/vampnet.yml` to change the dataset paths or any training hyperparameters.
|
65 |
+
|
66 |
+
For coarse2fine models, you can use `conf/c2f.yml` as a starting configuration.
|
67 |
+
|
68 |
+
See `python scripts/exp/train.py -h` for a list of options.
|
69 |
+
|
70 |
+
## Debugging training
|
71 |
+
|
72 |
+
To debug training, it's easier to debug with 1 gpu and 0 workers
|
73 |
+
|
74 |
+
```bash
|
75 |
+
CUDA_VISIBLE_DEVICES=0 python -m pdb scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints --num_workers 0
|
76 |
+
```
|
77 |
+
|
78 |
+
## Fine-tuning
|
79 |
+
To fine-tune a model, use the script in `scripts/exp/fine_tune.py` to generate 3 configuration files: `c2f.yml`, `coarse.yml`, and `interface.yml`.
|
80 |
+
The first two are used to fine-tune the coarse and fine models, respectively. The last one is used to launch the gradio interface.
|
81 |
+
|
82 |
+
```bash
|
83 |
+
python scripts/exp/fine_tune.py "/path/to/audio1.mp3 /path/to/audio2/ /path/to/audio3.wav" <fine_tune_name>
|
84 |
+
```
|
85 |
+
|
86 |
+
This will create a folder under `conf/<fine_tune_name>/` with the 3 configuration files.
|
87 |
+
|
88 |
+
The save_paths will be set to `runs/<fine_tune_name>/coarse` and `runs/<fine_tune_name>/c2f`.
|
89 |
+
|
90 |
+
launch the coarse job:
|
91 |
+
```bash
|
92 |
+
python scripts/exp/train.py --args.load conf/generated/<fine_tune_name>/coarse.yml
|
93 |
+
```
|
94 |
+
|
95 |
+
this will save the coarse model to `runs/<fine_tune_name>/coarse/ckpt/best/`.
|
96 |
+
|
97 |
+
launch the c2f job:
|
98 |
+
```bash
|
99 |
+
python scripts/exp/train.py --args.load conf/generated/<fine_tune_name>/c2f.yml
|
100 |
+
```
|
101 |
+
|
102 |
+
## A note on argbind
|
103 |
+
This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files.
|
104 |
+
Config files are stored in the `conf/` folder.
|
105 |
+
|
106 |
+
### Licensing for Pretrained Models:
|
107 |
+
The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
|
108 |
+
|
109 |
+
Download the pretrained models from [this link](https://zenodo.org/record/8136629). Then, extract the models to the `models/` folder.
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
|
app.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
from pathlib import Path
|
3 |
+
import yaml
|
4 |
+
import time
|
5 |
+
import uuid
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import audiotools as at
|
9 |
+
import argbind
|
10 |
+
import shutil
|
11 |
+
import torch
|
12 |
+
from datetime import datetime
|
13 |
+
|
14 |
+
import gradio as gr
|
15 |
+
from vampnet.interface import Interface, signal_concat
|
16 |
+
from vampnet import mask as pmask
|
17 |
+
|
18 |
+
|
19 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
+
|
21 |
+
interface = Interface.default()
|
22 |
+
|
23 |
+
# populate the model choices with any interface.yml files in the generated confs
|
24 |
+
MODEL_CHOICES = {
|
25 |
+
"default": {
|
26 |
+
"Interface.coarse_ckpt": str(interface.coarse_path),
|
27 |
+
"Interface.coarse2fine_ckpt": str(interface.c2f_path),
|
28 |
+
"Interface.codec_ckpt": str(interface.codec_path),
|
29 |
+
}
|
30 |
+
}
|
31 |
+
generated_confs = Path("conf/generated")
|
32 |
+
for conf_file in generated_confs.glob("*/interface.yml"):
|
33 |
+
with open(conf_file) as f:
|
34 |
+
_conf = yaml.safe_load(f)
|
35 |
+
|
36 |
+
# check if the coarse, c2f, and codec ckpts exist
|
37 |
+
# otherwise, dont' add this model choice
|
38 |
+
if not (
|
39 |
+
Path(_conf["Interface.coarse_ckpt"]).exists() and
|
40 |
+
Path(_conf["Interface.coarse2fine_ckpt"]).exists() and
|
41 |
+
Path(_conf["Interface.codec_ckpt"]).exists()
|
42 |
+
):
|
43 |
+
continue
|
44 |
+
|
45 |
+
MODEL_CHOICES[conf_file.parent.name] = _conf
|
46 |
+
|
47 |
+
|
48 |
+
def to_output(sig):
|
49 |
+
return sig.sample_rate, sig.cpu().detach().numpy()[0][0]
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
MAX_DURATION_S = 5
|
54 |
+
def load_audio(file):
|
55 |
+
print(file)
|
56 |
+
if isinstance(file, str):
|
57 |
+
filepath = file
|
58 |
+
elif isinstance(file, tuple):
|
59 |
+
# not a file
|
60 |
+
sr, samples = file
|
61 |
+
samples = samples / np.iinfo(samples.dtype).max
|
62 |
+
return sr, samples
|
63 |
+
else:
|
64 |
+
filepath = file.name
|
65 |
+
sig = at.AudioSignal.salient_excerpt(
|
66 |
+
filepath, duration=MAX_DURATION_S
|
67 |
+
)
|
68 |
+
sig = at.AudioSignal(filepath)
|
69 |
+
return to_output(sig)
|
70 |
+
|
71 |
+
|
72 |
+
def load_example_audio():
|
73 |
+
return load_audio("./assets/example.wav")
|
74 |
+
|
75 |
+
from torch_pitch_shift import pitch_shift, get_fast_shifts
|
76 |
+
def shift_pitch(signal, interval: int):
|
77 |
+
signal.samples = pitch_shift(
|
78 |
+
signal.samples,
|
79 |
+
shift=interval,
|
80 |
+
sample_rate=signal.sample_rate
|
81 |
+
)
|
82 |
+
return signal
|
83 |
+
|
84 |
+
|
85 |
+
@spaces.GPU
|
86 |
+
def _vamp(
|
87 |
+
seed, input_audio, model_choice,
|
88 |
+
pitch_shift_amt, periodic_p,
|
89 |
+
n_mask_codebooks, periodic_w, onset_mask_width,
|
90 |
+
dropout, sampletemp, typical_filtering,
|
91 |
+
typical_mass, typical_min_tokens, top_p,
|
92 |
+
sample_cutoff, stretch_factor, api=False
|
93 |
+
):
|
94 |
+
t0 = time.time()
|
95 |
+
interface.to("cuda" if torch.cuda.is_available() else "cpu")
|
96 |
+
print(f"using device {interface.device}")
|
97 |
+
_seed = seed if seed > 0 else None
|
98 |
+
if _seed is None:
|
99 |
+
_seed = int(torch.randint(0, 2**32, (1,)).item())
|
100 |
+
at.util.seed(_seed)
|
101 |
+
|
102 |
+
sr, input_audio = input_audio
|
103 |
+
input_audio = input_audio / np.iinfo(input_audio.dtype).max
|
104 |
+
|
105 |
+
sig = at.AudioSignal(input_audio, sr)
|
106 |
+
|
107 |
+
# reload the model if necessary
|
108 |
+
interface.reload(
|
109 |
+
coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
|
110 |
+
c2f_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
|
111 |
+
)
|
112 |
+
|
113 |
+
if pitch_shift_amt != 0:
|
114 |
+
sig = shift_pitch(sig, pitch_shift_amt)
|
115 |
+
|
116 |
+
build_mask_kwargs = dict(
|
117 |
+
rand_mask_intensity=1.0,
|
118 |
+
prefix_s=0.0,
|
119 |
+
suffix_s=0.0,
|
120 |
+
periodic_prompt=int(periodic_p),
|
121 |
+
periodic_prompt_width=periodic_w,
|
122 |
+
onset_mask_width=onset_mask_width,
|
123 |
+
_dropout=dropout,
|
124 |
+
upper_codebook_mask=int(n_mask_codebooks),
|
125 |
+
)
|
126 |
+
|
127 |
+
vamp_kwargs = dict(
|
128 |
+
temperature=sampletemp,
|
129 |
+
typical_filtering=typical_filtering,
|
130 |
+
typical_mass=typical_mass,
|
131 |
+
typical_min_tokens=typical_min_tokens,
|
132 |
+
top_p=None,
|
133 |
+
seed=_seed,
|
134 |
+
sample_cutoff=1.0,
|
135 |
+
)
|
136 |
+
|
137 |
+
# save the mask as a txt file
|
138 |
+
interface.set_chunk_size(10.0)
|
139 |
+
sig, mask, codes = interface.ez_vamp(
|
140 |
+
sig,
|
141 |
+
batch_size=1 if api else 1,
|
142 |
+
feedback_steps=1,
|
143 |
+
time_stretch_factor=stretch_factor,
|
144 |
+
build_mask_kwargs=build_mask_kwargs,
|
145 |
+
vamp_kwargs=vamp_kwargs,
|
146 |
+
return_mask=True,
|
147 |
+
)
|
148 |
+
print(f"vamp took {time.time() - t0} seconds")
|
149 |
+
|
150 |
+
|
151 |
+
return to_output(sig)
|
152 |
+
|
153 |
+
def vamp(data):
|
154 |
+
return _vamp(
|
155 |
+
seed=data[seed],
|
156 |
+
input_audio=data[input_audio],
|
157 |
+
model_choice=data[model_choice],
|
158 |
+
pitch_shift_amt=data[pitch_shift_amt],
|
159 |
+
periodic_p=data[periodic_p],
|
160 |
+
n_mask_codebooks=data[n_mask_codebooks],
|
161 |
+
periodic_w=data[periodic_w],
|
162 |
+
onset_mask_width=data[onset_mask_width],
|
163 |
+
dropout=data[dropout],
|
164 |
+
sampletemp=data[sampletemp],
|
165 |
+
typical_filtering=data[typical_filtering],
|
166 |
+
typical_mass=data[typical_mass],
|
167 |
+
typical_min_tokens=data[typical_min_tokens],
|
168 |
+
top_p=data[top_p],
|
169 |
+
sample_cutoff=data[sample_cutoff],
|
170 |
+
stretch_factor=data[stretch_factor],
|
171 |
+
api=False,
|
172 |
+
)
|
173 |
+
|
174 |
+
def api_vamp(data):
|
175 |
+
return _vamp(
|
176 |
+
seed=data[seed],
|
177 |
+
input_audio=data[input_audio],
|
178 |
+
model_choice=data[model_choice],
|
179 |
+
pitch_shift_amt=data[pitch_shift_amt],
|
180 |
+
periodic_p=data[periodic_p],
|
181 |
+
n_mask_codebooks=data[n_mask_codebooks],
|
182 |
+
periodic_w=data[periodic_w],
|
183 |
+
onset_mask_width=data[onset_mask_width],
|
184 |
+
dropout=data[dropout],
|
185 |
+
sampletemp=data[sampletemp],
|
186 |
+
typical_filtering=data[typical_filtering],
|
187 |
+
typical_mass=data[typical_mass],
|
188 |
+
typical_min_tokens=data[typical_min_tokens],
|
189 |
+
top_p=data[top_p],
|
190 |
+
sample_cutoff=data[sample_cutoff],
|
191 |
+
stretch_factor=data[stretch_factor],
|
192 |
+
api=True,
|
193 |
+
)
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
with gr.Blocks() as demo:
|
200 |
+
with gr.Row():
|
201 |
+
with gr.Column():
|
202 |
+
manual_audio_upload = gr.File(
|
203 |
+
label=f"upload some audio (will be randomly trimmed to max of 100s)",
|
204 |
+
file_types=["audio"]
|
205 |
+
)
|
206 |
+
load_example_audio_button = gr.Button("or load example audio")
|
207 |
+
|
208 |
+
input_audio = gr.Audio(
|
209 |
+
label="input audio",
|
210 |
+
interactive=False,
|
211 |
+
type="numpy",
|
212 |
+
)
|
213 |
+
|
214 |
+
audio_mask = gr.Audio(
|
215 |
+
label="audio mask (listen to this to hear the mask hints)",
|
216 |
+
interactive=False,
|
217 |
+
type="numpy",
|
218 |
+
)
|
219 |
+
|
220 |
+
# connect widgets
|
221 |
+
load_example_audio_button.click(
|
222 |
+
fn=load_example_audio,
|
223 |
+
inputs=[],
|
224 |
+
outputs=[ input_audio]
|
225 |
+
)
|
226 |
+
|
227 |
+
manual_audio_upload.change(
|
228 |
+
fn=load_audio,
|
229 |
+
inputs=[manual_audio_upload],
|
230 |
+
outputs=[ input_audio]
|
231 |
+
)
|
232 |
+
|
233 |
+
|
234 |
+
# mask settings
|
235 |
+
with gr.Column():
|
236 |
+
with gr.Accordion("manual controls", open=True):
|
237 |
+
periodic_p = gr.Slider(
|
238 |
+
label="periodic prompt",
|
239 |
+
minimum=0,
|
240 |
+
maximum=13,
|
241 |
+
step=1,
|
242 |
+
value=3,
|
243 |
+
)
|
244 |
+
|
245 |
+
onset_mask_width = gr.Slider(
|
246 |
+
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
247 |
+
minimum=0,
|
248 |
+
maximum=100,
|
249 |
+
step=1,
|
250 |
+
value=0, visible=False
|
251 |
+
)
|
252 |
+
|
253 |
+
n_mask_codebooks = gr.Slider(
|
254 |
+
label="compression prompt ",
|
255 |
+
value=3,
|
256 |
+
minimum=1,
|
257 |
+
maximum=14,
|
258 |
+
step=1,
|
259 |
+
)
|
260 |
+
|
261 |
+
maskimg = gr.Image(
|
262 |
+
label="mask image",
|
263 |
+
interactive=False,
|
264 |
+
type="filepath"
|
265 |
+
)
|
266 |
+
|
267 |
+
with gr.Accordion("extras ", open=False):
|
268 |
+
pitch_shift_amt = gr.Slider(
|
269 |
+
label="pitch shift amount (semitones)",
|
270 |
+
minimum=-12,
|
271 |
+
maximum=12,
|
272 |
+
step=1,
|
273 |
+
value=0,
|
274 |
+
)
|
275 |
+
|
276 |
+
stretch_factor = gr.Slider(
|
277 |
+
label="time stretch factor",
|
278 |
+
minimum=0,
|
279 |
+
maximum=8,
|
280 |
+
step=1,
|
281 |
+
value=1,
|
282 |
+
)
|
283 |
+
|
284 |
+
periodic_w = gr.Slider(
|
285 |
+
label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
|
286 |
+
minimum=1,
|
287 |
+
maximum=20,
|
288 |
+
step=1,
|
289 |
+
value=1,
|
290 |
+
)
|
291 |
+
|
292 |
+
|
293 |
+
with gr.Accordion("sampling settings", open=False):
|
294 |
+
sampletemp = gr.Slider(
|
295 |
+
label="sample temperature",
|
296 |
+
minimum=0.1,
|
297 |
+
maximum=10.0,
|
298 |
+
value=1.0,
|
299 |
+
step=0.001
|
300 |
+
)
|
301 |
+
|
302 |
+
top_p = gr.Slider(
|
303 |
+
label="top p (0.0 = off)",
|
304 |
+
minimum=0.0,
|
305 |
+
maximum=1.0,
|
306 |
+
value=0.0
|
307 |
+
)
|
308 |
+
typical_filtering = gr.Checkbox(
|
309 |
+
label="typical filtering ",
|
310 |
+
value=True
|
311 |
+
)
|
312 |
+
typical_mass = gr.Slider(
|
313 |
+
label="typical mass (should probably stay between 0.1 and 0.5)",
|
314 |
+
minimum=0.01,
|
315 |
+
maximum=0.99,
|
316 |
+
value=0.15
|
317 |
+
)
|
318 |
+
typical_min_tokens = gr.Slider(
|
319 |
+
label="typical min tokens (should probably stay between 1 and 256)",
|
320 |
+
minimum=1,
|
321 |
+
maximum=256,
|
322 |
+
step=1,
|
323 |
+
value=64
|
324 |
+
)
|
325 |
+
sample_cutoff = gr.Slider(
|
326 |
+
label="sample cutoff",
|
327 |
+
minimum=0.0,
|
328 |
+
maximum=0.9,
|
329 |
+
value=1.0,
|
330 |
+
step=0.01
|
331 |
+
)
|
332 |
+
|
333 |
+
|
334 |
+
dropout = gr.Slider(
|
335 |
+
label="mask dropout",
|
336 |
+
minimum=0.0,
|
337 |
+
maximum=1.0,
|
338 |
+
step=0.01,
|
339 |
+
value=0.0
|
340 |
+
)
|
341 |
+
|
342 |
+
|
343 |
+
seed = gr.Number(
|
344 |
+
label="seed (0 for random)",
|
345 |
+
value=0,
|
346 |
+
precision=0,
|
347 |
+
)
|
348 |
+
|
349 |
+
|
350 |
+
# mask settings
|
351 |
+
with gr.Column():
|
352 |
+
|
353 |
+
model_choice = gr.Dropdown(
|
354 |
+
label="model choice",
|
355 |
+
choices=list(MODEL_CHOICES.keys()),
|
356 |
+
value="default",
|
357 |
+
visible=True
|
358 |
+
)
|
359 |
+
|
360 |
+
|
361 |
+
vamp_button = gr.Button("generate (vamp)!!!")
|
362 |
+
|
363 |
+
|
364 |
+
audio_outs = []
|
365 |
+
use_as_input_btns = []
|
366 |
+
for i in range(1):
|
367 |
+
with gr.Column():
|
368 |
+
audio_outs.append(gr.Audio(
|
369 |
+
label=f"output audio {i+1}",
|
370 |
+
interactive=False,
|
371 |
+
type="numpy"
|
372 |
+
))
|
373 |
+
use_as_input_btns.append(
|
374 |
+
gr.Button(f"use as input (feedback)")
|
375 |
+
)
|
376 |
+
|
377 |
+
thank_you = gr.Markdown("")
|
378 |
+
|
379 |
+
# download all the outputs
|
380 |
+
# download = gr.File(type="filepath", label="download outputs")
|
381 |
+
|
382 |
+
|
383 |
+
_inputs = {
|
384 |
+
input_audio,
|
385 |
+
sampletemp,
|
386 |
+
top_p,
|
387 |
+
periodic_p, periodic_w,
|
388 |
+
dropout,
|
389 |
+
stretch_factor,
|
390 |
+
onset_mask_width,
|
391 |
+
typical_filtering,
|
392 |
+
typical_mass,
|
393 |
+
typical_min_tokens,
|
394 |
+
seed,
|
395 |
+
model_choice,
|
396 |
+
n_mask_codebooks,
|
397 |
+
pitch_shift_amt,
|
398 |
+
sample_cutoff,
|
399 |
+
}
|
400 |
+
|
401 |
+
# connect widgets
|
402 |
+
vamp_button.click(
|
403 |
+
fn=vamp,
|
404 |
+
inputs=_inputs,
|
405 |
+
outputs=[audio_outs[0]],
|
406 |
+
)
|
407 |
+
|
408 |
+
api_vamp_button = gr.Button("api vamp", visible=True)
|
409 |
+
api_vamp_button.click(
|
410 |
+
fn=api_vamp,
|
411 |
+
inputs=_inputs,
|
412 |
+
outputs=[audio_outs[0]],
|
413 |
+
api_name="vamp"
|
414 |
+
)
|
415 |
+
|
416 |
+
for i, btn in enumerate(use_as_input_btns):
|
417 |
+
btn.click(
|
418 |
+
fn=load_audio,
|
419 |
+
inputs=[audio_outs[i]],
|
420 |
+
outputs=[input_audio]
|
421 |
+
)
|
422 |
+
|
423 |
+
try:
|
424 |
+
demo.queue()
|
425 |
+
demo.launch(share=True)
|
426 |
+
except KeyboardInterrupt:
|
427 |
+
shutil.rmtree("gradio-outputs", ignore_errors=True)
|
428 |
+
raise
|
assets/example.wav
ADDED
Binary file (883 kB). View file
|
|
conf/c2f.yml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/vampnet.yml
|
3 |
+
|
4 |
+
VampNet.n_codebooks: 14
|
5 |
+
VampNet.n_conditioning_codebooks: 4
|
6 |
+
|
7 |
+
VampNet.embedding_dim: 1280
|
8 |
+
VampNet.n_layers: 16
|
9 |
+
VampNet.n_heads: 20
|
10 |
+
|
11 |
+
AudioDataset.duration: 3.0
|
12 |
+
|
13 |
+
|
14 |
+
AudioDataset.loudness_cutoff: -40.0
|
conf/interface.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Interface.coarse_ckpt: ./models/vampnet/coarse.pth
|
2 |
+
Interface.coarse2fine_ckpt: ./models/vampnet/c2f.pth
|
3 |
+
Interface.codec_ckpt: ./models/vampnet/codec.pth
|
4 |
+
Interface.coarse_chunk_size_s: 10
|
5 |
+
Interface.coarse2fine_chunk_size_s: 3
|
6 |
+
Interface.wavebeat_ckpt: ./models/wavebeat.pth
|
7 |
+
|
8 |
+
# AudioLoader.sources:
|
9 |
+
# - /media/CHONK/null
|
10 |
+
|
conf/lora/lora.yml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/vampnet.yml
|
3 |
+
|
4 |
+
fine_tune: True
|
5 |
+
|
6 |
+
train/AudioDataset.n_examples: 100000000
|
7 |
+
val/AudioDataset.n_examples: 500
|
8 |
+
|
9 |
+
|
10 |
+
NoamScheduler.warmup: 500
|
11 |
+
|
12 |
+
batch_size: 7
|
13 |
+
num_workers: 7
|
14 |
+
save_iters: [2000, 4000, 10000,20000, 40000, 100000]
|
15 |
+
sample_freq: 2000
|
16 |
+
val_freq: 1000
|
17 |
+
|
18 |
+
AdamW.lr: 0.0001
|
19 |
+
|
20 |
+
# let's us organize sound classes into folders and choose from those sound classes uniformly
|
21 |
+
AudioDataset.without_replacement: False
|
22 |
+
num_iters: 500000
|
conf/salad_bowl.yml
ADDED
File without changes
|
conf/vampnet.yml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
codec_ckpt: ./models/vampnet/codec.pth
|
3 |
+
save_path: ckpt
|
4 |
+
|
5 |
+
num_iters: 1000000000
|
6 |
+
save_iters: [10000, 50000, 100000, 300000, 500000]
|
7 |
+
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
8 |
+
sample_freq: 10000
|
9 |
+
val_freq: 1000
|
10 |
+
|
11 |
+
batch_size: 8
|
12 |
+
num_workers: 10
|
13 |
+
|
14 |
+
# Optimization
|
15 |
+
amp: false
|
16 |
+
|
17 |
+
CrossEntropyLoss.label_smoothing: 0.1
|
18 |
+
|
19 |
+
AdamW.lr: 0.001
|
20 |
+
|
21 |
+
NoamScheduler.factor: 2.0
|
22 |
+
NoamScheduler.warmup: 10000
|
23 |
+
|
24 |
+
VampNet.vocab_size: 1024
|
25 |
+
VampNet.n_codebooks: 4
|
26 |
+
VampNet.n_conditioning_codebooks: 0
|
27 |
+
VampNet.r_cond_dim: 0
|
28 |
+
VampNet.noise_mode: mask
|
29 |
+
VampNet.embedding_dim: 1280
|
30 |
+
VampNet.n_layers: 20
|
31 |
+
VampNet.n_heads: 20
|
32 |
+
VampNet.flash_attn: false
|
33 |
+
VampNet.dropout: 0.1
|
34 |
+
|
35 |
+
AudioLoader.relative_path: ""
|
36 |
+
AudioDataset.loudness_cutoff: -30.0
|
37 |
+
AudioDataset.without_replacement: true
|
38 |
+
AudioLoader.shuffle: true
|
39 |
+
|
40 |
+
AudioDataset.duration: 10.0
|
41 |
+
|
42 |
+
train/AudioDataset.n_examples: 10000000
|
43 |
+
train/AudioLoader.sources:
|
44 |
+
- /media/CHONK/hugo/spotdl/audio-train
|
45 |
+
|
46 |
+
val/AudioDataset.n_examples: 2000
|
47 |
+
val/AudioLoader.sources:
|
48 |
+
- /media/CHONK/hugo/spotdl/audio-val
|
49 |
+
|
hello.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import vampnet
|
3 |
+
import audiotools as at
|
4 |
+
|
5 |
+
# load the default vampnet model
|
6 |
+
interface = vampnet.interface.Interface.default()
|
7 |
+
|
8 |
+
# list available finetuned models
|
9 |
+
finetuned_model_choices = interface.available_models()
|
10 |
+
print(f"available finetuned models: {finetuned_model_choices}")
|
11 |
+
|
12 |
+
# pick a random finetuned model
|
13 |
+
model_choice = random.choice(finetuned_model_choices)
|
14 |
+
print(f"choosing model: {model_choice}")
|
15 |
+
|
16 |
+
# load a finetuned model
|
17 |
+
interface.load_finetuned(model_choice)
|
18 |
+
|
19 |
+
# load an example audio file
|
20 |
+
signal = at.AudioSignal("assets/example.wav")
|
21 |
+
|
22 |
+
# get the tokens for the audio
|
23 |
+
codes = interface.encode(signal)
|
24 |
+
|
25 |
+
# build a mask for the audio
|
26 |
+
mask = interface.build_mask(
|
27 |
+
codes, signal,
|
28 |
+
periodic_prompt=7,
|
29 |
+
upper_codebook_mask=3,
|
30 |
+
)
|
31 |
+
|
32 |
+
# generate the output tokens
|
33 |
+
output_tokens = interface.vamp(
|
34 |
+
codes, mask, return_mask=False,
|
35 |
+
temperature=1.0,
|
36 |
+
typical_filtering=True,
|
37 |
+
)
|
38 |
+
|
39 |
+
# convert them to a signal
|
40 |
+
output_signal = interface.decode(output_tokens)
|
41 |
+
|
42 |
+
# save the output signal
|
43 |
+
output_signal.write("scratch/output.wav")
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.1.0
|
2 |
+
argbind>=0.3.2
|
3 |
+
numpy==1.23
|
4 |
+
loralib
|
5 |
+
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
|
6 |
+
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
7 |
+
descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
|
8 |
+
-e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
|
9 |
+
torch_pitch_shift
|
10 |
+
gradio==4.37.2
|
scripts/exp/eval.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import os
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
from frechet_audio_distance import FrechetAudioDistance
|
6 |
+
import pandas
|
7 |
+
import argbind
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
import audiotools
|
12 |
+
from audiotools import AudioSignal
|
13 |
+
|
14 |
+
@argbind.bind(without_prefix=True)
|
15 |
+
def eval(
|
16 |
+
exp_dir: str = None,
|
17 |
+
baseline_key: str = "baseline",
|
18 |
+
audio_ext: str = ".wav",
|
19 |
+
):
|
20 |
+
assert exp_dir is not None
|
21 |
+
exp_dir = Path(exp_dir)
|
22 |
+
assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
|
23 |
+
|
24 |
+
# set up our metrics
|
25 |
+
# sisdr_loss = audiotools.metrics.distance.SISDRLoss()
|
26 |
+
# stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
|
27 |
+
mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
|
28 |
+
frechet = FrechetAudioDistance(
|
29 |
+
use_pca=False,
|
30 |
+
use_activation=False,
|
31 |
+
verbose=True,
|
32 |
+
audio_load_worker=4,
|
33 |
+
)
|
34 |
+
frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
+
|
36 |
+
# figure out what conditions we have
|
37 |
+
conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
|
38 |
+
|
39 |
+
assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
|
40 |
+
conditions.remove(baseline_key)
|
41 |
+
|
42 |
+
print(f"Found {len(conditions)} conditions in {exp_dir}")
|
43 |
+
print(f"conditions: {conditions}")
|
44 |
+
|
45 |
+
baseline_dir = exp_dir / baseline_key
|
46 |
+
baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
|
47 |
+
|
48 |
+
metrics = []
|
49 |
+
for condition in tqdm(conditions):
|
50 |
+
cond_dir = exp_dir / condition
|
51 |
+
cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
|
52 |
+
|
53 |
+
print(f"computing fad for {baseline_dir} and {cond_dir}")
|
54 |
+
frechet_score = frechet.score(baseline_dir, cond_dir)
|
55 |
+
|
56 |
+
# make sure we have the same number of files
|
57 |
+
num_files = min(len(baseline_files), len(cond_files))
|
58 |
+
baseline_files = baseline_files[:num_files]
|
59 |
+
cond_files = cond_files[:num_files]
|
60 |
+
assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
|
61 |
+
|
62 |
+
def process(baseline_file, cond_file):
|
63 |
+
# make sure the files match (same name)
|
64 |
+
assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
|
65 |
+
|
66 |
+
# load the files
|
67 |
+
baseline_sig = AudioSignal(str(baseline_file))
|
68 |
+
cond_sig = AudioSignal(str(cond_file))
|
69 |
+
|
70 |
+
cond_sig.resample(baseline_sig.sample_rate)
|
71 |
+
cond_sig.truncate_samples(baseline_sig.length)
|
72 |
+
|
73 |
+
# if our condition is inpainting, we need to trim the conditioning off
|
74 |
+
if "inpaint" in condition:
|
75 |
+
ctx_amt = float(condition.split("_")[-1])
|
76 |
+
ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
|
77 |
+
print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
|
78 |
+
cond_sig.trim(ctx_samples, ctx_samples)
|
79 |
+
baseline_sig.trim(ctx_samples, ctx_samples)
|
80 |
+
|
81 |
+
return {
|
82 |
+
# "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
|
83 |
+
# "stft": stft_loss(baseline_sig, cond_sig).item(),
|
84 |
+
"mel": mel_loss(baseline_sig, cond_sig).item(),
|
85 |
+
"frechet": frechet_score,
|
86 |
+
# "visqol": vsq,
|
87 |
+
"condition": condition,
|
88 |
+
"file": baseline_file.stem,
|
89 |
+
}
|
90 |
+
|
91 |
+
print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
|
92 |
+
metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))
|
93 |
+
|
94 |
+
metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
|
95 |
+
|
96 |
+
|
97 |
+
for mk in metric_keys:
|
98 |
+
stat = pandas.DataFrame(metrics)
|
99 |
+
stat = stat.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
|
100 |
+
stat.to_csv(exp_dir / f"stats-{mk}.csv")
|
101 |
+
|
102 |
+
df = pandas.DataFrame(metrics)
|
103 |
+
df.to_csv(exp_dir / "metrics-all.csv", index=False)
|
104 |
+
|
105 |
+
|
106 |
+
if __name__ == "__main__":
|
107 |
+
args = argbind.parse_args()
|
108 |
+
|
109 |
+
with argbind.scope(args):
|
110 |
+
eval()
|
scripts/exp/experiment.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import random
|
3 |
+
from typing import List
|
4 |
+
import tempfile
|
5 |
+
import subprocess
|
6 |
+
|
7 |
+
import argbind
|
8 |
+
from tqdm import tqdm
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from vampnet.interface import Interface
|
12 |
+
from vampnet import mask as pmask
|
13 |
+
import audiotools as at
|
14 |
+
|
15 |
+
Interface: Interface = argbind.bind(Interface)
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
def calculate_bitrate(
|
20 |
+
interface, num_codebooks,
|
21 |
+
downsample_factor
|
22 |
+
):
|
23 |
+
bit_width = 10
|
24 |
+
sr = interface.codec.sample_rate
|
25 |
+
hop = interface.codec.hop_size
|
26 |
+
rate = (sr / hop) * ((bit_width * num_codebooks) / downsample_factor)
|
27 |
+
return rate
|
28 |
+
|
29 |
+
def baseline(sig, interface):
|
30 |
+
return interface.preprocess(sig)
|
31 |
+
|
32 |
+
def reconstructed(sig, interface):
|
33 |
+
return interface.decode(
|
34 |
+
interface.encode(sig)
|
35 |
+
)
|
36 |
+
|
37 |
+
def coarse2fine(sig, interface):
|
38 |
+
z = interface.encode(sig)
|
39 |
+
z = z[:, :interface.c2f.n_conditioning_codebooks, :]
|
40 |
+
|
41 |
+
z = interface.coarse_to_fine(z)
|
42 |
+
return interface.decode(z)
|
43 |
+
|
44 |
+
class CoarseCond:
|
45 |
+
|
46 |
+
def __init__(self, num_conditioning_codebooks, downsample_factor):
|
47 |
+
self.num_conditioning_codebooks = num_conditioning_codebooks
|
48 |
+
self.downsample_factor = downsample_factor
|
49 |
+
|
50 |
+
def __call__(self, sig, interface):
|
51 |
+
z = interface.encode(sig)
|
52 |
+
mask = pmask.full_mask(z)
|
53 |
+
mask = pmask.codebook_unmask(mask, self.num_conditioning_codebooks)
|
54 |
+
mask = pmask.periodic_mask(mask, self.downsample_factor)
|
55 |
+
|
56 |
+
zv = interface.coarse_vamp(z, mask)
|
57 |
+
zv = interface.coarse_to_fine(zv)
|
58 |
+
return interface.decode(zv)
|
59 |
+
|
60 |
+
def opus(sig, interface, bitrate=128):
|
61 |
+
sig = interface.preprocess(sig)
|
62 |
+
|
63 |
+
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
64 |
+
sig.write(f.name)
|
65 |
+
|
66 |
+
opus_name = Path(f.name).with_suffix(".opus")
|
67 |
+
# convert to opus
|
68 |
+
cmd = [
|
69 |
+
"ffmpeg", "-y", "-i", f.name,
|
70 |
+
"-c:a", "libopus",
|
71 |
+
"-b:a", f"{bitrate}",
|
72 |
+
opus_name
|
73 |
+
]
|
74 |
+
subprocess.run(cmd, check=True)
|
75 |
+
|
76 |
+
# convert back to wav
|
77 |
+
output_name = Path(f"{f.name}-opus").with_suffix(".wav")
|
78 |
+
cmd = [
|
79 |
+
"ffmpeg", "-y", "-i", opus_name,
|
80 |
+
output_name
|
81 |
+
]
|
82 |
+
|
83 |
+
subprocess.run(cmd, check=True)
|
84 |
+
|
85 |
+
sig = at.AudioSignal(
|
86 |
+
output_name,
|
87 |
+
sample_rate=sig.sample_rate
|
88 |
+
)
|
89 |
+
return sig
|
90 |
+
|
91 |
+
def mask_ratio_1_step(ratio=1.0):
|
92 |
+
def wrapper(sig, interface):
|
93 |
+
z = interface.encode(sig)
|
94 |
+
mask = pmask.linear_random(z, ratio)
|
95 |
+
zv = interface.coarse_vamp(
|
96 |
+
z,
|
97 |
+
mask,
|
98 |
+
sampling_steps=1,
|
99 |
+
)
|
100 |
+
|
101 |
+
return interface.decode(zv)
|
102 |
+
return wrapper
|
103 |
+
|
104 |
+
def num_sampling_steps(num_steps=1):
|
105 |
+
def wrapper(sig, interface: Interface):
|
106 |
+
z = interface.encode(sig)
|
107 |
+
mask = pmask.periodic_mask(z, 16)
|
108 |
+
zv = interface.coarse_vamp(
|
109 |
+
z,
|
110 |
+
mask,
|
111 |
+
sampling_steps=num_steps,
|
112 |
+
)
|
113 |
+
|
114 |
+
zv = interface.coarse_to_fine(zv)
|
115 |
+
return interface.decode(zv)
|
116 |
+
return wrapper
|
117 |
+
|
118 |
+
def beat_mask(ctx_time):
|
119 |
+
def wrapper(sig, interface):
|
120 |
+
beat_mask = interface.make_beat_mask(
|
121 |
+
sig,
|
122 |
+
before_beat_s=ctx_time/2,
|
123 |
+
after_beat_s=ctx_time/2,
|
124 |
+
invert=True
|
125 |
+
)
|
126 |
+
|
127 |
+
z = interface.encode(sig)
|
128 |
+
|
129 |
+
zv = interface.coarse_vamp(
|
130 |
+
z, beat_mask
|
131 |
+
)
|
132 |
+
|
133 |
+
zv = interface.coarse_to_fine(zv)
|
134 |
+
return interface.decode(zv)
|
135 |
+
return wrapper
|
136 |
+
|
137 |
+
def inpaint(ctx_time):
|
138 |
+
def wrapper(sig, interface: Interface):
|
139 |
+
z = interface.encode(sig)
|
140 |
+
mask = pmask.inpaint(z, interface.s2t(ctx_time), interface.s2t(ctx_time))
|
141 |
+
|
142 |
+
zv = interface.coarse_vamp(z, mask)
|
143 |
+
zv = interface.coarse_to_fine(zv)
|
144 |
+
|
145 |
+
return interface.decode(zv)
|
146 |
+
return wrapper
|
147 |
+
|
148 |
+
def token_noise(noise_amt):
|
149 |
+
def wrapper(sig, interface: Interface):
|
150 |
+
z = interface.encode(sig)
|
151 |
+
mask = pmask.random(z, noise_amt)
|
152 |
+
z = torch.where(
|
153 |
+
mask,
|
154 |
+
torch.randint_like(z, 0, interface.coarse.vocab_size),
|
155 |
+
z
|
156 |
+
)
|
157 |
+
return interface.decode(z)
|
158 |
+
return wrapper
|
159 |
+
|
160 |
+
EXP_REGISTRY = {}
|
161 |
+
|
162 |
+
EXP_REGISTRY["gen-compression"] = {
|
163 |
+
"baseline": baseline,
|
164 |
+
"reconstructed": reconstructed,
|
165 |
+
"coarse2fine": coarse2fine,
|
166 |
+
**{
|
167 |
+
f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x)
|
168 |
+
for (n, x) in (
|
169 |
+
(1, 1), # 1 codebook, no downsampling
|
170 |
+
(4, 4), # 4 codebooks, downsampled 4x
|
171 |
+
(4, 16), # 4 codebooks, downsampled 16x
|
172 |
+
(4, 32), # 4 codebooks, downsampled 16x
|
173 |
+
)
|
174 |
+
},
|
175 |
+
**{
|
176 |
+
f"token_noise_{x}": mask_ratio_1_step(ratio=x)
|
177 |
+
for x in [0.25, 0.5, 0.75]
|
178 |
+
},
|
179 |
+
|
180 |
+
}
|
181 |
+
|
182 |
+
|
183 |
+
EXP_REGISTRY["sampling-steps"] = {
|
184 |
+
# "codec": reconstructed,
|
185 |
+
**{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
|
186 |
+
}
|
187 |
+
|
188 |
+
|
189 |
+
EXP_REGISTRY["musical-sampling"] = {
|
190 |
+
**{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
|
191 |
+
**{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
|
192 |
+
}
|
193 |
+
|
194 |
+
@argbind.bind(without_prefix=True)
|
195 |
+
def main(
|
196 |
+
sources=[
|
197 |
+
"/media/CHONK/hugo/spotdl/val",
|
198 |
+
],
|
199 |
+
output_dir: str = "./samples",
|
200 |
+
max_excerpts: int = 2000,
|
201 |
+
exp_type: str = "gen-compression",
|
202 |
+
seed: int = 0,
|
203 |
+
ext: str = [".mp3"],
|
204 |
+
):
|
205 |
+
at.util.seed(seed)
|
206 |
+
interface = Interface()
|
207 |
+
|
208 |
+
output_dir = Path(output_dir)
|
209 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
210 |
+
|
211 |
+
from audiotools.data.datasets import AudioLoader, AudioDataset
|
212 |
+
|
213 |
+
loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
|
214 |
+
dataset = AudioDataset(loader,
|
215 |
+
sample_rate=interface.codec.sample_rate,
|
216 |
+
duration=interface.coarse.chunk_size_s,
|
217 |
+
n_examples=max_excerpts,
|
218 |
+
without_replacement=True,
|
219 |
+
)
|
220 |
+
|
221 |
+
if exp_type in EXP_REGISTRY:
|
222 |
+
SAMPLE_CONDS = EXP_REGISTRY[exp_type]
|
223 |
+
else:
|
224 |
+
raise ValueError(f"Unknown exp_type {exp_type}")
|
225 |
+
|
226 |
+
|
227 |
+
indices = list(range(max_excerpts))
|
228 |
+
random.shuffle(indices)
|
229 |
+
for i in tqdm(indices):
|
230 |
+
# if all our files are already there, skip
|
231 |
+
done = []
|
232 |
+
for name in SAMPLE_CONDS:
|
233 |
+
o_dir = Path(output_dir) / name
|
234 |
+
done.append((o_dir / f"{i}.wav").exists())
|
235 |
+
if all(done):
|
236 |
+
continue
|
237 |
+
|
238 |
+
sig = dataset[i]["signal"]
|
239 |
+
results = {
|
240 |
+
name: cond(sig, interface).cpu()
|
241 |
+
for name, cond in SAMPLE_CONDS.items()
|
242 |
+
}
|
243 |
+
|
244 |
+
for name, sig in results.items():
|
245 |
+
o_dir = Path(output_dir) / name
|
246 |
+
o_dir.mkdir(exist_ok=True, parents=True)
|
247 |
+
|
248 |
+
sig.write(o_dir / f"{i}.wav")
|
249 |
+
|
250 |
+
if __name__ == "__main__":
|
251 |
+
args = argbind.parse_args()
|
252 |
+
|
253 |
+
with argbind.scope(args):
|
254 |
+
main()
|
scripts/exp/export.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
run_dir = Path("runs/sample-instrument")
|
4 |
+
name = run_dir.name
|
5 |
+
|
6 |
+
repo_dir = Path("models/vampnet")
|
7 |
+
|
8 |
+
|
9 |
+
for part in ("coarse", "c2f"):
|
10 |
+
outdir = repo_dir / "loras" / name
|
11 |
+
outdir.mkdir(parents=True, exist_ok=True)
|
12 |
+
outpath = outdir / f"{part}.pth"
|
13 |
+
path = run_dir / part / "latest" / "vampnet" / "weights.pth"
|
14 |
+
path.rename(outpath)
|
15 |
+
print(f"moved {path} to {outpath}")
|
16 |
+
|
17 |
+
# now, push to hub
|
18 |
+
from huggingface_hub import Repository
|
19 |
+
repo = Repository(repo_dir, git_user="hugofloresgarcia", git_email="huferflo@gmail.com")
|
20 |
+
repo.push_to_hub(
|
21 |
+
commit_message=f"add {name}"
|
22 |
+
)
|
scripts/exp/fine_tune.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argbind
|
2 |
+
from pathlib import Path
|
3 |
+
import yaml
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
"""example output: (yaml)
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
@argbind.bind(without_prefix=True, positional=True)
|
14 |
+
def fine_tune(audio_files_or_folders: List[str], name: str):
|
15 |
+
|
16 |
+
conf_dir = Path("conf")
|
17 |
+
assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"
|
18 |
+
|
19 |
+
conf_dir = conf_dir / "generated"
|
20 |
+
conf_dir.mkdir(exist_ok=True)
|
21 |
+
|
22 |
+
finetune_dir = conf_dir / name
|
23 |
+
finetune_dir.mkdir(exist_ok=True)
|
24 |
+
|
25 |
+
finetune_c2f_conf = {
|
26 |
+
"$include": ["conf/lora/lora.yml"],
|
27 |
+
"fine_tune": True,
|
28 |
+
"train/AudioLoader.sources": audio_files_or_folders,
|
29 |
+
"val/AudioLoader.sources": audio_files_or_folders,
|
30 |
+
"VampNet.n_codebooks": 14,
|
31 |
+
"VampNet.n_conditioning_codebooks": 4,
|
32 |
+
"VampNet.embedding_dim": 1280,
|
33 |
+
"VampNet.n_layers": 16,
|
34 |
+
"VampNet.n_heads": 20,
|
35 |
+
"AudioDataset.duration": 3.0,
|
36 |
+
"AudioDataset.loudness_cutoff": -40.0,
|
37 |
+
"save_path": f"./runs/{name}/c2f",
|
38 |
+
"fine_tune_checkpoint": "./models/vampnet/c2f.pth"
|
39 |
+
}
|
40 |
+
|
41 |
+
finetune_coarse_conf = {
|
42 |
+
"$include": ["conf/lora/lora.yml"],
|
43 |
+
"fine_tune": True,
|
44 |
+
"train/AudioLoader.sources": audio_files_or_folders,
|
45 |
+
"val/AudioLoader.sources": audio_files_or_folders,
|
46 |
+
"save_path": f"./runs/{name}/coarse",
|
47 |
+
"fine_tune_checkpoint": "./models/vampnet/coarse.pth"
|
48 |
+
}
|
49 |
+
|
50 |
+
interface_conf = {
|
51 |
+
"Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
|
52 |
+
|
53 |
+
"Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
|
54 |
+
"Interface.wavebeat_ckpt": "./models/wavebeat.pth",
|
55 |
+
|
56 |
+
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
57 |
+
"AudioLoader.sources": [audio_files_or_folders],
|
58 |
+
}
|
59 |
+
|
60 |
+
# save the confs
|
61 |
+
with open(finetune_dir / "c2f.yml", "w") as f:
|
62 |
+
yaml.dump(finetune_c2f_conf, f)
|
63 |
+
|
64 |
+
with open(finetune_dir / "coarse.yml", "w") as f:
|
65 |
+
yaml.dump(finetune_coarse_conf, f)
|
66 |
+
|
67 |
+
with open(finetune_dir / "interface.yml", "w") as f:
|
68 |
+
yaml.dump(interface_conf, f)
|
69 |
+
|
70 |
+
|
71 |
+
print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
args = argbind.parse_args()
|
75 |
+
|
76 |
+
with argbind.scope(args):
|
77 |
+
fine_tune()
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
scripts/exp/train.py
ADDED
@@ -0,0 +1,686 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import warnings
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Optional
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
import argbind
|
9 |
+
import audiotools as at
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from audiotools import AudioSignal
|
13 |
+
from audiotools.data import transforms as tfm
|
14 |
+
from einops import rearrange
|
15 |
+
from rich import pretty
|
16 |
+
from rich.traceback import install
|
17 |
+
from torch.utils.tensorboard import SummaryWriter
|
18 |
+
|
19 |
+
import vampnet
|
20 |
+
from vampnet.modules.transformer import VampNet
|
21 |
+
from vampnet.util import codebook_unflatten, codebook_flatten
|
22 |
+
from vampnet import mask as pmask
|
23 |
+
# from dac.model.dac import DAC
|
24 |
+
from lac.model.lac import LAC as DAC
|
25 |
+
|
26 |
+
from audiotools.ml.decorators import (
|
27 |
+
timer, Tracker, when
|
28 |
+
)
|
29 |
+
|
30 |
+
import loralib as lora
|
31 |
+
|
32 |
+
import torch._dynamo
|
33 |
+
torch._dynamo.config.verbose=True
|
34 |
+
|
35 |
+
|
36 |
+
# Enable cudnn autotuner to speed up training
|
37 |
+
# (can be altered by the funcs.seed function)
|
38 |
+
torch.backends.cudnn.benchmark = bool(int(os.getenv("CUDNN_BENCHMARK", 1)))
|
39 |
+
# Uncomment to trade memory for speed.
|
40 |
+
|
41 |
+
# Install to make things look nice
|
42 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
43 |
+
pretty.install()
|
44 |
+
install()
|
45 |
+
|
46 |
+
# optim
|
47 |
+
Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True)
|
48 |
+
CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss)
|
49 |
+
AdamW = argbind.bind(torch.optim.AdamW)
|
50 |
+
NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
|
51 |
+
|
52 |
+
# transforms
|
53 |
+
filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
|
54 |
+
"BaseTransform",
|
55 |
+
"Compose",
|
56 |
+
"Choose",
|
57 |
+
]
|
58 |
+
|
59 |
+
# model
|
60 |
+
VampNet = argbind.bind(VampNet)
|
61 |
+
|
62 |
+
|
63 |
+
# data
|
64 |
+
AudioLoader = argbind.bind(at.datasets.AudioLoader)
|
65 |
+
AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
|
66 |
+
|
67 |
+
IGNORE_INDEX = -100
|
68 |
+
|
69 |
+
|
70 |
+
@argbind.bind("train", "val", without_prefix=True)
|
71 |
+
def build_transform():
|
72 |
+
transform = tfm.Compose(
|
73 |
+
tfm.VolumeNorm(("const", -24)),
|
74 |
+
# tfm.PitchShift(),
|
75 |
+
tfm.RescaleAudio(),
|
76 |
+
)
|
77 |
+
return transform
|
78 |
+
|
79 |
+
|
80 |
+
@torch.no_grad()
|
81 |
+
def apply_transform(transform_fn, batch):
|
82 |
+
sig: AudioSignal = batch["signal"]
|
83 |
+
kwargs = batch["transform_args"]
|
84 |
+
|
85 |
+
sig: AudioSignal = transform_fn(sig.clone(), **kwargs)
|
86 |
+
return sig
|
87 |
+
|
88 |
+
|
89 |
+
def build_datasets(args, sample_rate: int):
|
90 |
+
with argbind.scope(args, "train"):
|
91 |
+
train_data = AudioDataset(
|
92 |
+
AudioLoader(), sample_rate, transform=build_transform()
|
93 |
+
)
|
94 |
+
with argbind.scope(args, "val"):
|
95 |
+
val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
|
96 |
+
return train_data, val_data
|
97 |
+
|
98 |
+
|
99 |
+
def rand_float(shape, low, high, rng):
|
100 |
+
return rng.draw(shape)[:, 0] * (high - low) + low
|
101 |
+
|
102 |
+
|
103 |
+
def flip_coin(shape, p, rng):
|
104 |
+
return rng.draw(shape)[:, 0] < p
|
105 |
+
|
106 |
+
|
107 |
+
def num_params_hook(o, p):
|
108 |
+
return o + f" {p/1e6:<.3f}M params."
|
109 |
+
|
110 |
+
|
111 |
+
def add_num_params_repr_hook(model):
|
112 |
+
import numpy as np
|
113 |
+
from functools import partial
|
114 |
+
|
115 |
+
for n, m in model.named_modules():
|
116 |
+
o = m.extra_repr()
|
117 |
+
p = sum([np.prod(p.size()) for p in m.parameters()])
|
118 |
+
|
119 |
+
setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
120 |
+
|
121 |
+
|
122 |
+
def accuracy(
|
123 |
+
preds: torch.Tensor,
|
124 |
+
target: torch.Tensor,
|
125 |
+
top_k: int = 1,
|
126 |
+
ignore_index: Optional[int] = None,
|
127 |
+
) -> torch.Tensor:
|
128 |
+
# Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
|
129 |
+
preds = rearrange(preds, "b p s -> (b s) p")
|
130 |
+
target = rearrange(target, "b s -> (b s)")
|
131 |
+
|
132 |
+
# return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
|
133 |
+
if ignore_index is not None:
|
134 |
+
# Create a mask for the ignored index
|
135 |
+
mask = target != ignore_index
|
136 |
+
# Apply the mask to the target and predictions
|
137 |
+
preds = preds[mask]
|
138 |
+
target = target[mask]
|
139 |
+
|
140 |
+
# Get the top-k predicted classes and their indices
|
141 |
+
_, pred_indices = torch.topk(preds, k=top_k, dim=-1)
|
142 |
+
|
143 |
+
# Determine if the true target is in the top-k predicted classes
|
144 |
+
correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
|
145 |
+
|
146 |
+
# Calculate the accuracy
|
147 |
+
accuracy = torch.mean(correct.float())
|
148 |
+
|
149 |
+
return accuracy
|
150 |
+
|
151 |
+
def _metrics(z_hat, r, target, flat_mask, output):
|
152 |
+
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
153 |
+
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
154 |
+
masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
155 |
+
|
156 |
+
assert target.shape[0] == r.shape[0]
|
157 |
+
# grab the indices of the r values that are in the range
|
158 |
+
r_idx = (r >= r_range[0]) & (r < r_range[1])
|
159 |
+
|
160 |
+
# grab the target and z_hat values that are in the range
|
161 |
+
r_unmasked_target = unmasked_target[r_idx]
|
162 |
+
r_masked_target = masked_target[r_idx]
|
163 |
+
r_z_hat = z_hat[r_idx]
|
164 |
+
|
165 |
+
for topk in (1, 25):
|
166 |
+
s, e = r_range
|
167 |
+
tag = f"accuracy-{s}-{e}/top{topk}"
|
168 |
+
|
169 |
+
output[f"{tag}/unmasked"] = accuracy(
|
170 |
+
preds=r_z_hat,
|
171 |
+
target=r_unmasked_target,
|
172 |
+
ignore_index=IGNORE_INDEX,
|
173 |
+
top_k=topk,
|
174 |
+
)
|
175 |
+
output[f"{tag}/masked"] = accuracy(
|
176 |
+
preds=r_z_hat,
|
177 |
+
target=r_masked_target,
|
178 |
+
ignore_index=IGNORE_INDEX,
|
179 |
+
top_k=topk,
|
180 |
+
)
|
181 |
+
|
182 |
+
|
183 |
+
@dataclass
|
184 |
+
class State:
|
185 |
+
model: VampNet
|
186 |
+
codec: DAC
|
187 |
+
|
188 |
+
optimizer: AdamW
|
189 |
+
scheduler: NoamScheduler
|
190 |
+
criterion: CrossEntropyLoss
|
191 |
+
grad_clip_val: float
|
192 |
+
|
193 |
+
rng: torch.quasirandom.SobolEngine
|
194 |
+
|
195 |
+
train_data: AudioDataset
|
196 |
+
val_data: AudioDataset
|
197 |
+
|
198 |
+
tracker: Tracker
|
199 |
+
|
200 |
+
|
201 |
+
@timer()
|
202 |
+
def train_loop(state: State, batch: dict, accel: Accelerator):
|
203 |
+
state.model.train()
|
204 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
205 |
+
signal = apply_transform(state.train_data.transform, batch)
|
206 |
+
|
207 |
+
output = {}
|
208 |
+
vn = accel.unwrap(state.model)
|
209 |
+
with accel.autocast():
|
210 |
+
with torch.inference_mode():
|
211 |
+
state.codec.to(accel.device)
|
212 |
+
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
213 |
+
z = z[:, : vn.n_codebooks, :]
|
214 |
+
|
215 |
+
n_batch = z.shape[0]
|
216 |
+
r = state.rng.draw(n_batch)[:, 0].to(accel.device)
|
217 |
+
|
218 |
+
mask = pmask.random(z, r)
|
219 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
220 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
221 |
+
|
222 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
223 |
+
|
224 |
+
dtype = torch.bfloat16 if accel.amp else None
|
225 |
+
with accel.autocast(dtype=dtype):
|
226 |
+
z_hat = state.model(z_mask_latent)
|
227 |
+
|
228 |
+
target = codebook_flatten(
|
229 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
230 |
+
)
|
231 |
+
|
232 |
+
flat_mask = codebook_flatten(
|
233 |
+
mask[:, vn.n_conditioning_codebooks :, :],
|
234 |
+
)
|
235 |
+
|
236 |
+
# replace target with ignore index for masked tokens
|
237 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
238 |
+
output["loss"] = state.criterion(z_hat, t_masked)
|
239 |
+
|
240 |
+
_metrics(
|
241 |
+
r=r,
|
242 |
+
z_hat=z_hat,
|
243 |
+
target=target,
|
244 |
+
flat_mask=flat_mask,
|
245 |
+
output=output,
|
246 |
+
)
|
247 |
+
|
248 |
+
|
249 |
+
accel.backward(output["loss"])
|
250 |
+
|
251 |
+
output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
|
252 |
+
output["other/batch_size"] = z.shape[0]
|
253 |
+
|
254 |
+
|
255 |
+
accel.scaler.unscale_(state.optimizer)
|
256 |
+
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
257 |
+
state.model.parameters(), state.grad_clip_val
|
258 |
+
)
|
259 |
+
|
260 |
+
accel.step(state.optimizer)
|
261 |
+
state.optimizer.zero_grad()
|
262 |
+
|
263 |
+
state.scheduler.step()
|
264 |
+
accel.update()
|
265 |
+
|
266 |
+
|
267 |
+
return {k: v for k, v in sorted(output.items())}
|
268 |
+
|
269 |
+
|
270 |
+
@timer()
|
271 |
+
@torch.no_grad()
|
272 |
+
def val_loop(state: State, batch: dict, accel: Accelerator):
|
273 |
+
state.model.eval()
|
274 |
+
state.codec.eval()
|
275 |
+
batch = at.util.prepare_batch(batch, accel.device)
|
276 |
+
signal = apply_transform(state.val_data.transform, batch)
|
277 |
+
|
278 |
+
vn = accel.unwrap(state.model)
|
279 |
+
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
280 |
+
z = z[:, : vn.n_codebooks, :]
|
281 |
+
|
282 |
+
n_batch = z.shape[0]
|
283 |
+
r = state.rng.draw(n_batch)[:, 0].to(accel.device)
|
284 |
+
|
285 |
+
mask = pmask.random(z, r)
|
286 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
287 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
288 |
+
|
289 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
290 |
+
|
291 |
+
z_hat = state.model(z_mask_latent)
|
292 |
+
|
293 |
+
target = codebook_flatten(
|
294 |
+
z[:, vn.n_conditioning_codebooks :, :],
|
295 |
+
)
|
296 |
+
|
297 |
+
flat_mask = codebook_flatten(
|
298 |
+
mask[:, vn.n_conditioning_codebooks :, :]
|
299 |
+
)
|
300 |
+
|
301 |
+
output = {}
|
302 |
+
# replace target with ignore index for masked tokens
|
303 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
304 |
+
output["loss"] = state.criterion(z_hat, t_masked)
|
305 |
+
|
306 |
+
_metrics(
|
307 |
+
r=r,
|
308 |
+
z_hat=z_hat,
|
309 |
+
target=target,
|
310 |
+
flat_mask=flat_mask,
|
311 |
+
output=output,
|
312 |
+
)
|
313 |
+
|
314 |
+
return output
|
315 |
+
|
316 |
+
|
317 |
+
def validate(state, val_dataloader, accel):
|
318 |
+
for batch in val_dataloader:
|
319 |
+
output = val_loop(state, batch, accel)
|
320 |
+
# Consolidate state dicts if using ZeroRedundancyOptimizer
|
321 |
+
if hasattr(state.optimizer, "consolidate_state_dict"):
|
322 |
+
state.optimizer.consolidate_state_dict()
|
323 |
+
return output
|
324 |
+
|
325 |
+
|
326 |
+
def checkpoint(state, save_iters, save_path, fine_tune):
|
327 |
+
if accel.local_rank != 0:
|
328 |
+
state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
|
329 |
+
return
|
330 |
+
|
331 |
+
metadata = {"logs": dict(state.tracker.history)}
|
332 |
+
|
333 |
+
tags = ["latest"]
|
334 |
+
state.tracker.print(f"Saving to {str(Path('.').absolute())}")
|
335 |
+
|
336 |
+
if state.tracker.step in save_iters:
|
337 |
+
tags.append(f"{state.tracker.step // 1000}k")
|
338 |
+
|
339 |
+
if state.tracker.is_best("val", "loss"):
|
340 |
+
state.tracker.print(f"Best model so far")
|
341 |
+
tags.append("best")
|
342 |
+
|
343 |
+
if fine_tune:
|
344 |
+
for tag in tags:
|
345 |
+
# save the lora model
|
346 |
+
(Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
|
347 |
+
torch.save(
|
348 |
+
lora.lora_state_dict(accel.unwrap(state.model)),
|
349 |
+
f"{save_path}/{tag}/lora.pth"
|
350 |
+
)
|
351 |
+
|
352 |
+
for tag in tags:
|
353 |
+
model_extra = {
|
354 |
+
"optimizer.pth": state.optimizer.state_dict(),
|
355 |
+
"scheduler.pth": state.scheduler.state_dict(),
|
356 |
+
"tracker.pth": state.tracker.state_dict(),
|
357 |
+
"metadata.pth": metadata,
|
358 |
+
}
|
359 |
+
|
360 |
+
accel.unwrap(state.model).metadata = metadata
|
361 |
+
accel.unwrap(state.model).save_to_folder(
|
362 |
+
f"{save_path}/{tag}", model_extra, package=False
|
363 |
+
)
|
364 |
+
|
365 |
+
|
366 |
+
def save_sampled(state, z, writer):
|
367 |
+
num_samples = z.shape[0]
|
368 |
+
|
369 |
+
for i in range(num_samples):
|
370 |
+
sampled = accel.unwrap(state.model).generate(
|
371 |
+
codec=state.codec,
|
372 |
+
time_steps=z.shape[-1],
|
373 |
+
start_tokens=z[i : i + 1],
|
374 |
+
)
|
375 |
+
sampled.cpu().write_audio_to_tb(
|
376 |
+
f"sampled/{i}",
|
377 |
+
writer,
|
378 |
+
step=state.tracker.step,
|
379 |
+
plot_fn=None,
|
380 |
+
)
|
381 |
+
|
382 |
+
|
383 |
+
def save_imputation(state, z, val_idx, writer):
|
384 |
+
n_prefix = int(z.shape[-1] * 0.25)
|
385 |
+
n_suffix = int(z.shape[-1] * 0.25)
|
386 |
+
|
387 |
+
vn = accel.unwrap(state.model)
|
388 |
+
|
389 |
+
mask = pmask.inpaint(z, n_prefix, n_suffix)
|
390 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
391 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
392 |
+
|
393 |
+
imputed_noisy = vn.decode(z_mask, state.codec)
|
394 |
+
imputed_true = vn.decode(z, state.codec)
|
395 |
+
|
396 |
+
imputed = []
|
397 |
+
for i in range(len(z)):
|
398 |
+
imputed.append(
|
399 |
+
vn.generate(
|
400 |
+
codec=state.codec,
|
401 |
+
time_steps=z.shape[-1],
|
402 |
+
start_tokens=z[i][None, ...],
|
403 |
+
mask=mask[i][None, ...],
|
404 |
+
)
|
405 |
+
)
|
406 |
+
imputed = AudioSignal.batch(imputed)
|
407 |
+
|
408 |
+
for i in range(len(val_idx)):
|
409 |
+
imputed_noisy[i].cpu().write_audio_to_tb(
|
410 |
+
f"inpainted_prompt/{i}",
|
411 |
+
writer,
|
412 |
+
step=state.tracker.step,
|
413 |
+
plot_fn=None,
|
414 |
+
)
|
415 |
+
imputed[i].cpu().write_audio_to_tb(
|
416 |
+
f"inpainted_middle/{i}",
|
417 |
+
writer,
|
418 |
+
step=state.tracker.step,
|
419 |
+
plot_fn=None,
|
420 |
+
)
|
421 |
+
imputed_true[i].cpu().write_audio_to_tb(
|
422 |
+
f"reconstructed/{i}",
|
423 |
+
writer,
|
424 |
+
step=state.tracker.step,
|
425 |
+
plot_fn=None,
|
426 |
+
)
|
427 |
+
|
428 |
+
|
429 |
+
@torch.no_grad()
|
430 |
+
def save_samples(state: State, val_idx: int, writer: SummaryWriter):
|
431 |
+
state.model.eval()
|
432 |
+
state.codec.eval()
|
433 |
+
vn = accel.unwrap(state.model)
|
434 |
+
|
435 |
+
batch = [state.val_data[i] for i in val_idx]
|
436 |
+
batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)
|
437 |
+
|
438 |
+
signal = apply_transform(state.val_data.transform, batch)
|
439 |
+
|
440 |
+
z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
441 |
+
z = z[:, : vn.n_codebooks, :]
|
442 |
+
|
443 |
+
r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
|
444 |
+
|
445 |
+
|
446 |
+
mask = pmask.random(z, r)
|
447 |
+
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
448 |
+
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
449 |
+
|
450 |
+
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
451 |
+
|
452 |
+
z_hat = state.model(z_mask_latent)
|
453 |
+
|
454 |
+
z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
|
455 |
+
z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
|
456 |
+
z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
|
457 |
+
|
458 |
+
generated = vn.decode(z_pred, state.codec)
|
459 |
+
reconstructed = vn.decode(z, state.codec)
|
460 |
+
masked = vn.decode(z_mask.squeeze(1), state.codec)
|
461 |
+
|
462 |
+
for i in range(generated.batch_size):
|
463 |
+
audio_dict = {
|
464 |
+
"original": signal[i],
|
465 |
+
"masked": masked[i],
|
466 |
+
"generated": generated[i],
|
467 |
+
"reconstructed": reconstructed[i],
|
468 |
+
}
|
469 |
+
for k, v in audio_dict.items():
|
470 |
+
v.cpu().write_audio_to_tb(
|
471 |
+
f"onestep/_{i}.r={r[i]:0.2f}/{k}",
|
472 |
+
writer,
|
473 |
+
step=state.tracker.step,
|
474 |
+
plot_fn=None,
|
475 |
+
)
|
476 |
+
|
477 |
+
save_sampled(state=state, z=z, writer=writer)
|
478 |
+
save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
|
479 |
+
|
480 |
+
|
481 |
+
|
482 |
+
@argbind.bind(without_prefix=True)
|
483 |
+
def load(
|
484 |
+
args,
|
485 |
+
accel: at.ml.Accelerator,
|
486 |
+
tracker: Tracker,
|
487 |
+
save_path: str,
|
488 |
+
resume: bool = False,
|
489 |
+
tag: str = "latest",
|
490 |
+
fine_tune_checkpoint: Optional[str] = None,
|
491 |
+
grad_clip_val: float = 5.0,
|
492 |
+
) -> State:
|
493 |
+
codec = DAC.load(args["codec_ckpt"], map_location="cpu")
|
494 |
+
codec.eval()
|
495 |
+
|
496 |
+
model, v_extra = None, {}
|
497 |
+
|
498 |
+
if args["fine_tune"]:
|
499 |
+
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
500 |
+
model = torch.compile(
|
501 |
+
VampNet.load(location=Path(fine_tune_checkpoint),
|
502 |
+
map_location="cpu",
|
503 |
+
)
|
504 |
+
)
|
505 |
+
|
506 |
+
if resume:
|
507 |
+
kwargs = {
|
508 |
+
"folder": f"{save_path}/{tag}",
|
509 |
+
"map_location": "cpu",
|
510 |
+
"package": False,
|
511 |
+
}
|
512 |
+
tracker.print(f"Loading checkpoint from {kwargs['folder']}")
|
513 |
+
if (Path(kwargs["folder"]) / "vampnet").exists():
|
514 |
+
model, v_extra = VampNet.load_from_folder(**kwargs)
|
515 |
+
else:
|
516 |
+
raise ValueError(
|
517 |
+
f"Could not find a VampNet checkpoint in {kwargs['folder']}"
|
518 |
+
)
|
519 |
+
|
520 |
+
|
521 |
+
|
522 |
+
|
523 |
+
model = torch.compile(VampNet()) if model is None else model
|
524 |
+
model = accel.prepare_model(model)
|
525 |
+
|
526 |
+
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
527 |
+
assert (
|
528 |
+
accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
|
529 |
+
)
|
530 |
+
|
531 |
+
|
532 |
+
if accel.world_size > 1:
|
533 |
+
from torch.distributed.optim import ZeroRedundancyOptimizer
|
534 |
+
optimizer = ZeroRedundancyOptimizer(model.parameters(), AdamW)
|
535 |
+
print(f"OPTIMIZER LR is {optimizer.param_groups[0]['lr']}")
|
536 |
+
else:
|
537 |
+
optimizer = AdamW(model.parameters())
|
538 |
+
|
539 |
+
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
540 |
+
scheduler.step()
|
541 |
+
|
542 |
+
if "optimizer.pth" in v_extra:
|
543 |
+
optimizer.load_state_dict(v_extra["optimizer.pth"])
|
544 |
+
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
545 |
+
if "tracker.pth" in v_extra:
|
546 |
+
tracker.load_state_dict(v_extra["tracker.pth"])
|
547 |
+
|
548 |
+
criterion = CrossEntropyLoss()
|
549 |
+
|
550 |
+
sample_rate = codec.sample_rate
|
551 |
+
|
552 |
+
# a better rng for sampling from our schedule
|
553 |
+
rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
|
554 |
+
|
555 |
+
# log a model summary w/ num params
|
556 |
+
if accel.local_rank == 0:
|
557 |
+
add_num_params_repr_hook(accel.unwrap(model))
|
558 |
+
with open(f"{save_path}/model.txt", "w") as f:
|
559 |
+
f.write(repr(accel.unwrap(model)))
|
560 |
+
|
561 |
+
# load the datasets
|
562 |
+
train_data, val_data = build_datasets(args, sample_rate)
|
563 |
+
|
564 |
+
return State(
|
565 |
+
tracker=tracker,
|
566 |
+
model=model,
|
567 |
+
codec=codec,
|
568 |
+
optimizer=optimizer,
|
569 |
+
scheduler=scheduler,
|
570 |
+
criterion=criterion,
|
571 |
+
rng=rng,
|
572 |
+
train_data=train_data,
|
573 |
+
val_data=val_data,
|
574 |
+
grad_clip_val=grad_clip_val,
|
575 |
+
)
|
576 |
+
|
577 |
+
|
578 |
+
@argbind.bind(without_prefix=True)
|
579 |
+
def train(
|
580 |
+
args,
|
581 |
+
accel: at.ml.Accelerator,
|
582 |
+
seed: int = 0,
|
583 |
+
codec_ckpt: str = None,
|
584 |
+
save_path: str = "ckpt",
|
585 |
+
num_iters: int = int(1000e6),
|
586 |
+
save_iters: list = [10000, 50000, 100000, 300000, 500000,],
|
587 |
+
sample_freq: int = 10000,
|
588 |
+
val_freq: int = 1000,
|
589 |
+
batch_size: int = 12,
|
590 |
+
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
591 |
+
num_workers: int = 10,
|
592 |
+
fine_tune: bool = False,
|
593 |
+
):
|
594 |
+
assert codec_ckpt is not None, "codec_ckpt is required"
|
595 |
+
|
596 |
+
seed = seed + accel.local_rank
|
597 |
+
at.util.seed(seed)
|
598 |
+
writer = None
|
599 |
+
|
600 |
+
if accel.local_rank == 0:
|
601 |
+
writer = SummaryWriter(log_dir=f"{save_path}/logs/")
|
602 |
+
argbind.dump_args(args, f"{save_path}/args.yml")
|
603 |
+
|
604 |
+
tracker = Tracker(
|
605 |
+
writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
|
606 |
+
)
|
607 |
+
|
608 |
+
# load the codec model
|
609 |
+
state: State = load(
|
610 |
+
args=args,
|
611 |
+
accel=accel,
|
612 |
+
tracker=tracker,
|
613 |
+
save_path=save_path)
|
614 |
+
print("initialized state.")
|
615 |
+
|
616 |
+
train_dataloader = accel.prepare_dataloader(
|
617 |
+
state.train_data,
|
618 |
+
start_idx=state.tracker.step * batch_size,
|
619 |
+
num_workers=num_workers,
|
620 |
+
batch_size=batch_size,
|
621 |
+
collate_fn=state.train_data.collate,
|
622 |
+
)
|
623 |
+
val_dataloader = accel.prepare_dataloader(
|
624 |
+
state.val_data,
|
625 |
+
start_idx=0,
|
626 |
+
num_workers=num_workers,
|
627 |
+
batch_size=batch_size,
|
628 |
+
collate_fn=state.val_data.collate,
|
629 |
+
persistent_workers=num_workers > 0,
|
630 |
+
)
|
631 |
+
print("initialized dataloader.")
|
632 |
+
|
633 |
+
|
634 |
+
|
635 |
+
if fine_tune:
|
636 |
+
lora.mark_only_lora_as_trainable(state.model)
|
637 |
+
print("marked only lora as trainable.")
|
638 |
+
|
639 |
+
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
640 |
+
# and only run when specific conditions are met.
|
641 |
+
global train_loop, val_loop, validate, save_samples, checkpoint
|
642 |
+
|
643 |
+
train_loop = tracker.log("train", "value", history=False)(
|
644 |
+
tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
|
645 |
+
)
|
646 |
+
val_loop = tracker.track("val", len(val_dataloader))(val_loop)
|
647 |
+
validate = tracker.log("val", "mean")(validate)
|
648 |
+
|
649 |
+
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
650 |
+
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
651 |
+
|
652 |
+
print("starting training loop.")
|
653 |
+
with tracker.live:
|
654 |
+
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
655 |
+
train_loop(state, batch, accel)
|
656 |
+
|
657 |
+
last_iter = (
|
658 |
+
tracker.step == num_iters - 1 if num_iters is not None else False
|
659 |
+
)
|
660 |
+
|
661 |
+
if tracker.step % sample_freq == 0 or last_iter:
|
662 |
+
save_samples(state, val_idx, writer)
|
663 |
+
|
664 |
+
if tracker.step % val_freq == 0 or last_iter:
|
665 |
+
validate(state, val_dataloader, accel)
|
666 |
+
checkpoint(
|
667 |
+
state=state,
|
668 |
+
save_iters=save_iters,
|
669 |
+
save_path=save_path,
|
670 |
+
fine_tune=fine_tune)
|
671 |
+
|
672 |
+
# Reset validation progress bar, print summary since last validation.
|
673 |
+
tracker.done("val", f"Iteration {tracker.step}")
|
674 |
+
|
675 |
+
if last_iter:
|
676 |
+
break
|
677 |
+
|
678 |
+
|
679 |
+
if __name__ == "__main__":
|
680 |
+
args = argbind.parse_args()
|
681 |
+
args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
|
682 |
+
with argbind.scope(args):
|
683 |
+
with Accelerator() as accel:
|
684 |
+
if accel.local_rank != 0:
|
685 |
+
sys.tracebacklimit = 0
|
686 |
+
train(args, accel)
|
scripts/utils/README.md
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Scripts
|
2 |
+
|
3 |
+
## process_zip.py
|
4 |
+
|
5 |
+
Some requirements that may not be installed in the docker image:
|
6 |
+
* argbind
|
7 |
+
* wav2wav (pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git or `pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git@<branchname>`)
|
8 |
+
|
9 |
+
### zip folder structure
|
10 |
+
|
11 |
+
The zip folder should have the following internal structure:
|
12 |
+
|
13 |
+
```
|
14 |
+
base_folder/
|
15 |
+
test_case_1/
|
16 |
+
before.wav
|
17 |
+
test_case_2/
|
18 |
+
before.wav
|
19 |
+
...
|
20 |
+
test_case_n/
|
21 |
+
before.wav
|
22 |
+
```
|
23 |
+
|
24 |
+
Note: There can be issues with the output zip if the input zip folder structure is too deep or too shallow. IF you want/need to use a zip file with a different folder structure, adjust this:
|
25 |
+
https://github.com/descriptinc/lyrebird-wav2wav/blob/136c923ce19df03876a515ca0ed83854710cfa30/scripts/utils/process_zip.py#L28
|
26 |
+
|
27 |
+
### Execution
|
28 |
+
`python process_zip.py <path/to/zip> -tag <string>`
|
scripts/utils/gtzan_embeddings.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
TODO: train a linear probe
|
3 |
+
usage:
|
4 |
+
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
|
5 |
+
"""
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
import audiotools as at
|
10 |
+
from audiotools import AudioSignal
|
11 |
+
import argbind
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
import zipfile
|
15 |
+
import json
|
16 |
+
|
17 |
+
from vampnet.interface import Interface
|
18 |
+
import tqdm
|
19 |
+
|
20 |
+
# bind the Interface to argbind
|
21 |
+
Interface = argbind.bind(Interface)
|
22 |
+
|
23 |
+
DEBUG = False
|
24 |
+
|
25 |
+
def smart_plotly_export(fig, save_path):
|
26 |
+
img_format = save_path.split('.')[-1]
|
27 |
+
if img_format == 'html':
|
28 |
+
fig.write_html(save_path)
|
29 |
+
elif img_format == 'bytes':
|
30 |
+
return fig.to_image(format='png')
|
31 |
+
#TODO: come back and make this prettier
|
32 |
+
elif img_format == 'numpy':
|
33 |
+
import io
|
34 |
+
from PIL import Image
|
35 |
+
|
36 |
+
def plotly_fig2array(fig):
|
37 |
+
#convert Plotly fig to an array
|
38 |
+
fig_bytes = fig.to_image(format="png", width=1200, height=700)
|
39 |
+
buf = io.BytesIO(fig_bytes)
|
40 |
+
img = Image.open(buf)
|
41 |
+
return np.asarray(img)
|
42 |
+
|
43 |
+
return plotly_fig2array(fig)
|
44 |
+
elif img_format == 'jpeg' or 'png' or 'webp':
|
45 |
+
fig.write_image(save_path)
|
46 |
+
else:
|
47 |
+
raise ValueError("invalid image format")
|
48 |
+
|
49 |
+
def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
|
50 |
+
"""
|
51 |
+
dimensionality reduction for visualization!
|
52 |
+
saves an html plotly figure to save_path
|
53 |
+
parameters:
|
54 |
+
emb (np.ndarray): the samples to be reduces with shape (samples, features)
|
55 |
+
labels (list): list of labels for embedding
|
56 |
+
save_path (str): path where u wanna save ur figure
|
57 |
+
method (str): umap, tsne, or pca
|
58 |
+
title (str): title for ur figure
|
59 |
+
returns:
|
60 |
+
proj (np.ndarray): projection vector with shape (samples, dimensions)
|
61 |
+
"""
|
62 |
+
import pandas as pd
|
63 |
+
import plotly.express as px
|
64 |
+
if method == 'umap':
|
65 |
+
from umap import UMAP
|
66 |
+
reducer = umap.UMAP(n_components=n_components)
|
67 |
+
elif method == 'tsne':
|
68 |
+
from sklearn.manifold import TSNE
|
69 |
+
reducer = TSNE(n_components=n_components)
|
70 |
+
elif method == 'pca':
|
71 |
+
from sklearn.decomposition import PCA
|
72 |
+
reducer = PCA(n_components=n_components)
|
73 |
+
else:
|
74 |
+
raise ValueError
|
75 |
+
|
76 |
+
proj = reducer.fit_transform(emb)
|
77 |
+
|
78 |
+
if n_components == 2:
|
79 |
+
df = pd.DataFrame(dict(
|
80 |
+
x=proj[:, 0],
|
81 |
+
y=proj[:, 1],
|
82 |
+
instrument=labels
|
83 |
+
))
|
84 |
+
fig = px.scatter(df, x='x', y='y', color='instrument',
|
85 |
+
title=title+f"_{method}")
|
86 |
+
|
87 |
+
elif n_components == 3:
|
88 |
+
df = pd.DataFrame(dict(
|
89 |
+
x=proj[:, 0],
|
90 |
+
y=proj[:, 1],
|
91 |
+
z=proj[:, 2],
|
92 |
+
instrument=labels
|
93 |
+
))
|
94 |
+
fig = px.scatter_3d(df, x='x', y='y', z='z',
|
95 |
+
color='instrument',
|
96 |
+
title=title)
|
97 |
+
else:
|
98 |
+
raise ValueError("cant plot more than 3 components")
|
99 |
+
|
100 |
+
fig.update_traces(marker=dict(size=6,
|
101 |
+
line=dict(width=1,
|
102 |
+
color='DarkSlateGrey')),
|
103 |
+
selector=dict(mode='markers'))
|
104 |
+
|
105 |
+
return smart_plotly_export(fig, save_path)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
# per JukeMIR, we want the emebddings from the middle layer?
|
110 |
+
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
|
111 |
+
with torch.inference_mode():
|
112 |
+
# preprocess the signal
|
113 |
+
sig = interface.preprocess(sig)
|
114 |
+
|
115 |
+
# get the coarse vampnet model
|
116 |
+
vampnet = interface.coarse
|
117 |
+
|
118 |
+
# get the tokens
|
119 |
+
z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
|
120 |
+
z_latents = vampnet.embedding.from_codes(z, interface.codec)
|
121 |
+
|
122 |
+
# do a forward pass through the model, get the embeddings
|
123 |
+
_z, embeddings = vampnet(z_latents, return_activations=True)
|
124 |
+
# print(f"got embeddings with shape {embeddings.shape}")
|
125 |
+
# [layer, batch, time, n_dims]
|
126 |
+
# [20, 1, 600ish, 768]
|
127 |
+
|
128 |
+
|
129 |
+
# squeeze batch dim (1 bc layer should be dim 0)
|
130 |
+
assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
|
131 |
+
embeddings = embeddings.squeeze(1)
|
132 |
+
|
133 |
+
num_layers = embeddings.shape[0]
|
134 |
+
assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
|
135 |
+
|
136 |
+
# do meanpooling over the time dimension
|
137 |
+
embeddings = embeddings.mean(dim=-2)
|
138 |
+
# [20, 768]
|
139 |
+
|
140 |
+
# return the embeddings
|
141 |
+
return embeddings
|
142 |
+
|
143 |
+
from dataclasses import dataclass, fields
|
144 |
+
@dataclass
|
145 |
+
class Embedding:
|
146 |
+
genre: str
|
147 |
+
filename: str
|
148 |
+
embedding: np.ndarray
|
149 |
+
|
150 |
+
def save(self, path):
|
151 |
+
"""Save the Embedding object to a given path as a zip file."""
|
152 |
+
with zipfile.ZipFile(path, 'w') as archive:
|
153 |
+
|
154 |
+
# Save numpy array
|
155 |
+
with archive.open('embedding.npy', 'w') as f:
|
156 |
+
np.save(f, self.embedding)
|
157 |
+
|
158 |
+
# Save non-numpy data as json
|
159 |
+
non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
|
160 |
+
with archive.open('data.json', 'w') as f:
|
161 |
+
f.write(json.dumps(non_numpy_data).encode('utf-8'))
|
162 |
+
|
163 |
+
@classmethod
|
164 |
+
def load(cls, path):
|
165 |
+
"""Load the Embedding object from a given zip path."""
|
166 |
+
with zipfile.ZipFile(path, 'r') as archive:
|
167 |
+
|
168 |
+
# Load numpy array
|
169 |
+
with archive.open('embedding.npy') as f:
|
170 |
+
embedding = np.load(f)
|
171 |
+
|
172 |
+
# Load non-numpy data from json
|
173 |
+
with archive.open('data.json') as f:
|
174 |
+
data = json.loads(f.read().decode('utf-8'))
|
175 |
+
|
176 |
+
return cls(embedding=embedding, **data)
|
177 |
+
|
178 |
+
|
179 |
+
@argbind.bind(without_prefix=True)
|
180 |
+
def main(
|
181 |
+
path_to_gtzan: str = None,
|
182 |
+
cache_dir: str = "./.gtzan_emb_cache",
|
183 |
+
output_dir: str = "./gtzan_vampnet_embeddings",
|
184 |
+
layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
|
185 |
+
):
|
186 |
+
path_to_gtzan = Path(path_to_gtzan)
|
187 |
+
assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"
|
188 |
+
|
189 |
+
cache_dir = Path(cache_dir)
|
190 |
+
output_dir = Path(output_dir)
|
191 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
192 |
+
|
193 |
+
# load our interface
|
194 |
+
# argbind will automatically load the default config,
|
195 |
+
interface = Interface()
|
196 |
+
|
197 |
+
# gtzan should have a folder for each genre, so let's get the list of genres
|
198 |
+
genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
|
199 |
+
print(f"Found {len(genres)} genres")
|
200 |
+
print(f"genres: {genres}")
|
201 |
+
|
202 |
+
# collect audio files, genres, and embeddings
|
203 |
+
data = []
|
204 |
+
for genre in genres:
|
205 |
+
audio_files = list(at.util.find_audio(path_to_gtzan / genre))
|
206 |
+
print(f"Found {len(audio_files)} audio files for genre {genre}")
|
207 |
+
|
208 |
+
for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
|
209 |
+
# check if we have a cached embedding for this file
|
210 |
+
cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
|
211 |
+
if cached_path.exists():
|
212 |
+
# if so, load it
|
213 |
+
if DEBUG:
|
214 |
+
print(f"loading cached embedding for {cached_path.stem}")
|
215 |
+
embedding = Embedding.load(cached_path)
|
216 |
+
else:
|
217 |
+
try:
|
218 |
+
sig = AudioSignal(audio_file)
|
219 |
+
except Exception as e:
|
220 |
+
print(f"failed to load {audio_file.name} with error {e}")
|
221 |
+
print(f"skipping {audio_file.name}")
|
222 |
+
continue
|
223 |
+
|
224 |
+
# gets the embedding
|
225 |
+
emb = vampnet_embed(sig, interface).cpu().numpy()
|
226 |
+
|
227 |
+
# create an embedding we can save/load
|
228 |
+
embedding = Embedding(
|
229 |
+
genre=genre,
|
230 |
+
filename=audio_file.name,
|
231 |
+
embedding=emb
|
232 |
+
)
|
233 |
+
|
234 |
+
# cache the embeddings
|
235 |
+
cached_path.parent.mkdir(exist_ok=True, parents=True)
|
236 |
+
embedding.save(cached_path)
|
237 |
+
data.append(embedding)
|
238 |
+
|
239 |
+
# now, let's do a dim reduction on the embeddings
|
240 |
+
# and visualize them.
|
241 |
+
|
242 |
+
# collect a list of embeddings and labels
|
243 |
+
embeddings = [d.embedding for d in data]
|
244 |
+
labels = [d.genre for d in data]
|
245 |
+
|
246 |
+
# convert the embeddings to a numpy array
|
247 |
+
embeddings = np.stack(embeddings)
|
248 |
+
|
249 |
+
# do dimensionality reduction for each layer we're given
|
250 |
+
for layer in tqdm.tqdm(layers, desc="dim reduction"):
|
251 |
+
dim_reduce(
|
252 |
+
embeddings[:, layer, :], labels,
|
253 |
+
save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
|
254 |
+
n_components=2, method='tsne',
|
255 |
+
title=f'vampnet-gtzan-layer={layer}'
|
256 |
+
)
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
if __name__ == "__main__":
|
262 |
+
args = argbind.parse_args()
|
263 |
+
with argbind.scope(args):
|
264 |
+
main()
|
scripts/utils/plots.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import seaborn as sns
|
3 |
+
from pandas.api.types import CategoricalDtype
|
4 |
+
|
5 |
+
def plot_metrics(metrics, condition_to_latex, title, color_palette):
|
6 |
+
# Add a new column to your dataframe with the latex representation
|
7 |
+
metrics['condition_latex'] = metrics['condition'].map(condition_to_latex)
|
8 |
+
|
9 |
+
# Order condition_latex as per the condition_to_latex dictionary
|
10 |
+
cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
|
11 |
+
metrics['condition_latex'] = metrics['condition_latex'].astype(cat_type)
|
12 |
+
|
13 |
+
# Compute mean and std for each condition for each metric
|
14 |
+
grouped = metrics.groupby('condition_latex')[['mel', 'frechet']].agg(['mean', 'std'])
|
15 |
+
|
16 |
+
fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
|
17 |
+
|
18 |
+
# Set the main title for the figure
|
19 |
+
fig.suptitle(title, fontsize=16)
|
20 |
+
|
21 |
+
# Get color for each bar in the plot
|
22 |
+
bar_colors = [color_palette[condition] for condition in grouped.index]
|
23 |
+
|
24 |
+
# Plot mel
|
25 |
+
sns.boxplot(x='condition_latex', y='mel', data=metrics, ax=axs[0], palette=color_palette, showfliers=False)
|
26 |
+
axs[0].set_ylabel('Mel Spectrogram Loss \u2190')
|
27 |
+
axs[0].set_xlabel('') # Remove x-axis label
|
28 |
+
axs[0].set_xticklabels(grouped.index, rotation=0, ha='center')
|
29 |
+
|
30 |
+
# Plot frechet
|
31 |
+
axs[1].bar(grouped.index, grouped['frechet']['mean'], yerr=grouped['frechet']['std'], color=bar_colors)
|
32 |
+
axs[1].set_ylabel('FAD \u2190')
|
33 |
+
axs[1].set_xlabel('') # Remove x-axis label
|
34 |
+
axs[1].set_xticklabels(grouped.index, rotation=0, ha='center')
|
35 |
+
|
36 |
+
# Adjust the space between plots
|
37 |
+
plt.subplots_adjust(hspace=0.1)
|
38 |
+
|
39 |
+
# Remove any unnecessary space around the plot
|
40 |
+
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
41 |
+
|
42 |
+
# Reduce the space between suptitle and the plot
|
43 |
+
plt.subplots_adjust(top=0.92)
|
scripts/utils/remove_quiet_files.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# removes files with loudness below 24db
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
import shutil
|
5 |
+
import audiotools as at
|
6 |
+
import argbind
|
7 |
+
|
8 |
+
@argbind.bind(without_prefix=True)
|
9 |
+
def remove_quiet_files(
|
10 |
+
src_dir: Path = None,
|
11 |
+
dest_dir: Path = None,
|
12 |
+
min_loudness: float = -30,
|
13 |
+
):
|
14 |
+
# copy src to dest
|
15 |
+
dest_dir.mkdir(parents=True, exist_ok=True)
|
16 |
+
shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
|
17 |
+
|
18 |
+
audio_files = at.util.find_audio(dest_dir)
|
19 |
+
for audio_file in audio_files:
|
20 |
+
sig = at.AudioSignal(audio_file)
|
21 |
+
if sig.loudness() < min_loudness:
|
22 |
+
audio_file.unlink()
|
23 |
+
print(f"removed {audio_file}")
|
24 |
+
|
25 |
+
if __name__ == "__main__":
|
26 |
+
args = argbind.parse_args()
|
27 |
+
|
28 |
+
with argbind.scope(args):
|
29 |
+
remove_quiet_files()
|
scripts/utils/split.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import random
|
3 |
+
import shutil
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
|
7 |
+
import argbind
|
8 |
+
from tqdm import tqdm
|
9 |
+
from tqdm.contrib.concurrent import thread_map
|
10 |
+
|
11 |
+
from audiotools.core import util
|
12 |
+
|
13 |
+
|
14 |
+
@argbind.bind(without_prefix=True)
|
15 |
+
def train_test_split(
|
16 |
+
audio_folder: str = ".",
|
17 |
+
test_size: float = 0.2,
|
18 |
+
seed: int = 42,
|
19 |
+
):
|
20 |
+
print(f"finding audio")
|
21 |
+
|
22 |
+
audio_folder = Path(audio_folder)
|
23 |
+
audio_files = util.find_audio(audio_folder)
|
24 |
+
print(f"found {len(audio_files)} audio files")
|
25 |
+
|
26 |
+
# split according to test_size
|
27 |
+
n_test = int(len(audio_files) * test_size)
|
28 |
+
n_train = len(audio_files) - n_test
|
29 |
+
|
30 |
+
# shuffle
|
31 |
+
random.seed(seed)
|
32 |
+
random.shuffle(audio_files)
|
33 |
+
|
34 |
+
train_files = audio_files[:n_train]
|
35 |
+
test_files = audio_files[n_train:]
|
36 |
+
|
37 |
+
|
38 |
+
print(f"Train files: {len(train_files)}")
|
39 |
+
print(f"Test files: {len(test_files)}")
|
40 |
+
continue_ = input("Continue [yn]? ") or "n"
|
41 |
+
|
42 |
+
if continue_ != "y":
|
43 |
+
return
|
44 |
+
|
45 |
+
for split, files in (
|
46 |
+
("train", train_files), ("test", test_files)
|
47 |
+
):
|
48 |
+
for file in tqdm(files):
|
49 |
+
out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
|
50 |
+
out_file.parent.mkdir(exist_ok=True, parents=True)
|
51 |
+
try:
|
52 |
+
os.symlink(file, out_file)
|
53 |
+
except FileExistsError:
|
54 |
+
print(f"File {out_file} already exists, skipping")
|
55 |
+
|
56 |
+
# save split as json
|
57 |
+
with open(Path(audio_folder) / f"{split}.json", "w") as f:
|
58 |
+
json.dump([str(f) for f in files], f)
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
args = argbind.parse_args()
|
64 |
+
|
65 |
+
with argbind.scope(args):
|
66 |
+
train_test_split()
|
scripts/utils/split_long_audio_file.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import argbind
|
3 |
+
|
4 |
+
import audiotools as at
|
5 |
+
import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
@argbind.bind(without_prefix=True)
|
9 |
+
def split_long_audio_file(
|
10 |
+
file: str = None,
|
11 |
+
max_chunk_size_s: int = 60*10
|
12 |
+
):
|
13 |
+
file = Path(file)
|
14 |
+
output_dir = file.parent / file.stem
|
15 |
+
output_dir.mkdir()
|
16 |
+
|
17 |
+
sig = at.AudioSignal(file)
|
18 |
+
|
19 |
+
# split into chunks
|
20 |
+
for i, sig in tqdm.tqdm(enumerate(sig.windows(
|
21 |
+
window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
|
22 |
+
preprocess=True))
|
23 |
+
):
|
24 |
+
sig.write(output_dir / f"{i}.wav")
|
25 |
+
|
26 |
+
print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
|
27 |
+
|
28 |
+
return output_dir
|
29 |
+
|
30 |
+
if __name__ == "__main__":
|
31 |
+
args = argbind.parse_args()
|
32 |
+
|
33 |
+
with argbind.scope(args):
|
34 |
+
split_long_audio_file()
|
scripts/utils/stage.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import argbind
|
6 |
+
import rich
|
7 |
+
from audiotools.ml import Experiment
|
8 |
+
|
9 |
+
|
10 |
+
@argbind.bind(without_prefix=True)
|
11 |
+
def run(
|
12 |
+
run_dir: str = os.getenv("PATH_TO_RUNS", "runs"),
|
13 |
+
name: str = None,
|
14 |
+
recent: bool = False,
|
15 |
+
):
|
16 |
+
if recent:
|
17 |
+
paths = sorted(Path(run_dir).iterdir(), key=os.path.getmtime)
|
18 |
+
paths = [p.name for p in paths if p.is_dir()]
|
19 |
+
if paths:
|
20 |
+
name = paths[-1]
|
21 |
+
|
22 |
+
with Experiment(run_dir, name) as exp:
|
23 |
+
exp.snapshot()
|
24 |
+
rich.print(f"Created a snapshot of {exp.parent_directory} at {exp.exp_dir}")
|
25 |
+
|
26 |
+
|
27 |
+
if __name__ == "__main__":
|
28 |
+
args = argbind.parse_args()
|
29 |
+
with argbind.scope(args):
|
30 |
+
run()
|
scripts/utils/visualize_embeddings.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
TODO: train a linear probe
|
3 |
+
usage:
|
4 |
+
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
|
5 |
+
"""
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
import audiotools as at
|
10 |
+
from audiotools import AudioSignal
|
11 |
+
import argbind
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
import zipfile
|
15 |
+
import json
|
16 |
+
|
17 |
+
from vampnet.interface import Interface
|
18 |
+
import tqdm
|
19 |
+
|
20 |
+
# bind the Interface to argbind
|
21 |
+
Interface = argbind.bind(Interface)
|
22 |
+
|
23 |
+
DEBUG = False
|
24 |
+
|
25 |
+
|
26 |
+
def smart_plotly_export(fig, save_path: Path):
|
27 |
+
img_format = save_path.suffix[1:]
|
28 |
+
if img_format == "html":
|
29 |
+
fig.write_html(save_path)
|
30 |
+
elif img_format == 'bytes':
|
31 |
+
return fig.to_image(format='png')
|
32 |
+
#TODO: come back and make this prettier
|
33 |
+
elif img_format == 'numpy':
|
34 |
+
import io
|
35 |
+
from PIL import Image
|
36 |
+
|
37 |
+
def plotly_fig2array(fig):
|
38 |
+
#convert Plotly fig to an array
|
39 |
+
fig_bytes = fig.to_image(format="png", width=1200, height=700)
|
40 |
+
buf = io.BytesIO(fig_bytes)
|
41 |
+
img = Image.open(buf)
|
42 |
+
return np.asarray(img)
|
43 |
+
|
44 |
+
return plotly_fig2array(fig)
|
45 |
+
elif img_format == 'jpeg' or 'png' or 'webp':
|
46 |
+
fig.write_image(save_path)
|
47 |
+
else:
|
48 |
+
raise ValueError("invalid image format")
|
49 |
+
|
50 |
+
|
51 |
+
def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="tsne"):
|
52 |
+
"""
|
53 |
+
dimensionality reduction for visualization!
|
54 |
+
saves an html plotly figure to save_path
|
55 |
+
parameters:
|
56 |
+
annotated_embeddings (list): the annotated enmbeddings to be reduced; embeddings have shape (samples, features)
|
57 |
+
labels (list): list of labels for embedding
|
58 |
+
save_path (str): path where u wanna save ur figure
|
59 |
+
method (str): umap, tsne, or pca
|
60 |
+
title (str): title for ur figure
|
61 |
+
returns:
|
62 |
+
proj (np.ndarray): projection vector with shape (samples, dimensions)
|
63 |
+
"""
|
64 |
+
import pandas as pd
|
65 |
+
import plotly.express as px
|
66 |
+
|
67 |
+
fig_name = f"vampnet-embeddings-layer={layer}"
|
68 |
+
fig_title = f"{fig_name}_{method}"
|
69 |
+
save_path = (output_dir / fig_name).with_suffix(".html")
|
70 |
+
|
71 |
+
if method == "umap":
|
72 |
+
from umap import UMAP
|
73 |
+
reducer = umap.UMAP(n_components=n_components)
|
74 |
+
elif method == "tsne":
|
75 |
+
from sklearn.manifold import TSNE
|
76 |
+
|
77 |
+
reducer = TSNE(n_components=n_components)
|
78 |
+
elif method == "pca":
|
79 |
+
from sklearn.decomposition import PCA
|
80 |
+
|
81 |
+
reducer = PCA(n_components=n_components)
|
82 |
+
else:
|
83 |
+
raise ValueError(f"invalid method: {method}")
|
84 |
+
|
85 |
+
labels = [emb.label for emb in annotated_embeddings]
|
86 |
+
names = [emb.filename for emb in annotated_embeddings]
|
87 |
+
embs = [emb.embedding for emb in annotated_embeddings]
|
88 |
+
embs_at_layer = np.stack(embs)[:, layer, :]
|
89 |
+
projs = reducer.fit_transform(embs_at_layer)
|
90 |
+
|
91 |
+
df = pd.DataFrame(
|
92 |
+
{
|
93 |
+
"label": labels,
|
94 |
+
"name": names,
|
95 |
+
"x": projs[:, 0],
|
96 |
+
"y": projs[:, 1],
|
97 |
+
}
|
98 |
+
)
|
99 |
+
if n_components == 2:
|
100 |
+
fig = px.scatter(
|
101 |
+
df, x="x", y="y", color="label", hover_name="name", title=fig_title,
|
102 |
+
)
|
103 |
+
|
104 |
+
elif n_components == 3:
|
105 |
+
df['z'] = projs[:, 2]
|
106 |
+
fig = px.scatter_3d(
|
107 |
+
df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
raise ValueError(f"can't plot {n_components} components")
|
111 |
+
|
112 |
+
fig.update_traces(
|
113 |
+
marker=dict(size=6, line=dict(width=1, color="DarkSlateGrey")),
|
114 |
+
selector=dict(mode="markers"),
|
115 |
+
)
|
116 |
+
|
117 |
+
return smart_plotly_export(fig, save_path)
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
# per JukeMIR, we want the emebddings from the middle layer?
|
122 |
+
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
|
123 |
+
with torch.inference_mode():
|
124 |
+
# preprocess the signal
|
125 |
+
sig = interface.preprocess(sig)
|
126 |
+
|
127 |
+
# get the coarse vampnet model
|
128 |
+
vampnet = interface.coarse
|
129 |
+
|
130 |
+
# get the tokens
|
131 |
+
z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
|
132 |
+
z_latents = vampnet.embedding.from_codes(z, interface.codec)
|
133 |
+
|
134 |
+
# do a forward pass through the model, get the embeddings
|
135 |
+
_z, embeddings = vampnet(z_latents, return_activations=True)
|
136 |
+
# print(f"got embeddings with shape {embeddings.shape}")
|
137 |
+
# [layer, batch, time, n_dims]
|
138 |
+
# [20, 1, 600ish, 768]
|
139 |
+
|
140 |
+
|
141 |
+
# squeeze batch dim (1 bc layer should be dim 0)
|
142 |
+
assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
|
143 |
+
embeddings = embeddings.squeeze(1)
|
144 |
+
|
145 |
+
num_layers = embeddings.shape[0]
|
146 |
+
assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
|
147 |
+
|
148 |
+
# do meanpooling over the time dimension
|
149 |
+
embeddings = embeddings.mean(dim=-2)
|
150 |
+
# [20, 768]
|
151 |
+
|
152 |
+
# return the embeddings
|
153 |
+
return embeddings
|
154 |
+
|
155 |
+
from dataclasses import dataclass, fields
|
156 |
+
@dataclass
|
157 |
+
class AnnotatedEmbedding:
|
158 |
+
label: str
|
159 |
+
filename: str
|
160 |
+
embedding: np.ndarray
|
161 |
+
|
162 |
+
def save(self, path):
|
163 |
+
"""Save the Embedding object to a given path as a zip file."""
|
164 |
+
with zipfile.ZipFile(path, 'w') as archive:
|
165 |
+
|
166 |
+
# Save numpy array
|
167 |
+
with archive.open('embedding.npy', 'w') as f:
|
168 |
+
np.save(f, self.embedding)
|
169 |
+
|
170 |
+
# Save non-numpy data as json
|
171 |
+
non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
|
172 |
+
with archive.open('data.json', 'w') as f:
|
173 |
+
f.write(json.dumps(non_numpy_data).encode('utf-8'))
|
174 |
+
|
175 |
+
@classmethod
|
176 |
+
def load(cls, path):
|
177 |
+
"""Load the Embedding object from a given zip path."""
|
178 |
+
with zipfile.ZipFile(path, 'r') as archive:
|
179 |
+
|
180 |
+
# Load numpy array
|
181 |
+
with archive.open('embedding.npy') as f:
|
182 |
+
embedding = np.load(f)
|
183 |
+
|
184 |
+
# Load non-numpy data from json
|
185 |
+
with archive.open('data.json') as f:
|
186 |
+
data = json.loads(f.read().decode('utf-8'))
|
187 |
+
|
188 |
+
return cls(embedding=embedding, **data)
|
189 |
+
|
190 |
+
|
191 |
+
@argbind.bind(without_prefix=True)
|
192 |
+
def main(
|
193 |
+
path_to_audio: str = None,
|
194 |
+
cache_dir: str = "./.emb_cache",
|
195 |
+
output_dir: str = "./vampnet_embeddings",
|
196 |
+
layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
|
197 |
+
method: str = "tsne",
|
198 |
+
n_components: int = 2,
|
199 |
+
):
|
200 |
+
path_to_audio = Path(path_to_audio)
|
201 |
+
assert path_to_audio.exists(), f"{path_to_audio} does not exist"
|
202 |
+
|
203 |
+
cache_dir = Path(cache_dir)
|
204 |
+
output_dir = Path(output_dir)
|
205 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
206 |
+
|
207 |
+
# load our interface
|
208 |
+
# argbind will automatically load the default config,
|
209 |
+
interface = Interface()
|
210 |
+
|
211 |
+
# we expect path_to_audio to consist of a folder for each label, so let's get the list of labels
|
212 |
+
labels = [Path(x).name for x in path_to_audio.iterdir() if x.is_dir()]
|
213 |
+
print(f"Found {len(labels)} labels")
|
214 |
+
print(f"labels: {labels}")
|
215 |
+
|
216 |
+
# collect audio files, labels, and embeddings
|
217 |
+
annotated_embeddings = []
|
218 |
+
for label in labels:
|
219 |
+
audio_files = list(at.util.find_audio(path_to_audio / label))
|
220 |
+
print(f"Found {len(audio_files)} audio files for label {label}")
|
221 |
+
|
222 |
+
for audio_file in tqdm.tqdm(audio_files, desc=f"embedding label {label}"):
|
223 |
+
# check if we have a cached embedding for this file
|
224 |
+
cached_path = cache_dir / f"{label}_{audio_file.stem}.emb"
|
225 |
+
if cached_path.exists():
|
226 |
+
# if so, load it
|
227 |
+
if DEBUG:
|
228 |
+
print(f"loading cached embedding for {cached_path.stem}")
|
229 |
+
embedding = AnnotatedEmbedding.load(cached_path)
|
230 |
+
else:
|
231 |
+
try:
|
232 |
+
sig = AudioSignal(audio_file)
|
233 |
+
except Exception as e:
|
234 |
+
print(f"failed to load {audio_file.name} with error {e}")
|
235 |
+
print(f"skipping {audio_file.name}")
|
236 |
+
continue
|
237 |
+
|
238 |
+
# gets the embedding
|
239 |
+
emb = vampnet_embed(sig, interface).cpu().numpy()
|
240 |
+
|
241 |
+
# create an embedding we can save/load
|
242 |
+
embedding = AnnotatedEmbedding(
|
243 |
+
label=label, filename=audio_file.name, embedding=emb
|
244 |
+
)
|
245 |
+
|
246 |
+
# cache the embeddings
|
247 |
+
cached_path.parent.mkdir(exist_ok=True, parents=True)
|
248 |
+
embedding.save(cached_path)
|
249 |
+
annotated_embeddings.append(embedding)
|
250 |
+
|
251 |
+
# now, let's do a dim reduction on the embeddings and visualize them.
|
252 |
+
for layer in tqdm.tqdm(layers, desc="dim reduction"):
|
253 |
+
dim_reduce(
|
254 |
+
annotated_embeddings,
|
255 |
+
layer,
|
256 |
+
output_dir=output_dir,
|
257 |
+
n_components=n_components,
|
258 |
+
method=method,
|
259 |
+
)
|
260 |
+
|
261 |
+
|
262 |
+
if __name__ == "__main__":
|
263 |
+
args = argbind.parse_args()
|
264 |
+
with argbind.scope(args):
|
265 |
+
main()
|
scripts/utils/xeno-canto-dl.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from xenopy import Query
|
2 |
+
|
3 |
+
|
4 |
+
SPECIES = [
|
5 |
+
"American Robin",
|
6 |
+
"Northern Cardinal",
|
7 |
+
"Mourning Dove",
|
8 |
+
"American Crow",
|
9 |
+
"Baltimore Oriole",
|
10 |
+
"Blue Jay",
|
11 |
+
"Eastern Bluebird",
|
12 |
+
"House Finch",
|
13 |
+
"American Goldfinch",
|
14 |
+
"House Sparrow",
|
15 |
+
"Song Sparrow",
|
16 |
+
"Tufted Titmouse",
|
17 |
+
"White-breasted Nuthatch",
|
18 |
+
"European Starling",
|
19 |
+
"American Redstart",
|
20 |
+
"Red-winged Blackbird",
|
21 |
+
"Brown-headed Cowbird",
|
22 |
+
"Common Grackle",
|
23 |
+
"Boat-tailed Grackle",
|
24 |
+
"Common Yellowthroat",
|
25 |
+
"Northern Mockingbird",
|
26 |
+
"Carolina Wren",
|
27 |
+
"Eastern Meadowlark",
|
28 |
+
"Chipping Sparrow",
|
29 |
+
"Tree Swallow",
|
30 |
+
"Barn Swallow",
|
31 |
+
"Cliff Swallow",
|
32 |
+
"Pine Siskin",
|
33 |
+
"Indigo Bunting",
|
34 |
+
"Eastern Towhee",
|
35 |
+
"Carolina Chickadee",
|
36 |
+
"Great Crested Flycatcher",
|
37 |
+
"Eastern Wood-Pewee",
|
38 |
+
"Ovenbird",
|
39 |
+
"Northern Flicker",
|
40 |
+
"Red-eyed Vireo",
|
41 |
+
"American Woodcock",
|
42 |
+
"Eastern Phoebe",
|
43 |
+
"Downy Woodpecker",
|
44 |
+
"Scarlet Tanager",
|
45 |
+
"Yellow Warbler",
|
46 |
+
"White-eyed Vireo",
|
47 |
+
"Common Loon",
|
48 |
+
"White-throated Sparrow",
|
49 |
+
"Yellow-throated Vireo",
|
50 |
+
"Great Blue Heron",
|
51 |
+
"Belted Kingfisher",
|
52 |
+
"Pied-billed Grebe",
|
53 |
+
"Wild Turkey",
|
54 |
+
"Wood Thrush",
|
55 |
+
"Rose-breasted Grosbeak",
|
56 |
+
"Field Sparrow",
|
57 |
+
"Hooded Warbler",
|
58 |
+
"Northern Parula",
|
59 |
+
"Chestnut-sided Warbler",
|
60 |
+
"Blue-winged Warbler",
|
61 |
+
"Red-bellied Woodpecker",
|
62 |
+
"Yellow-billed Cuckoo",
|
63 |
+
"Gray Catbird",
|
64 |
+
"Northern Saw-whet Owl",
|
65 |
+
"Osprey",
|
66 |
+
"Common Nighthawk",
|
67 |
+
"Broad-winged Hawk",
|
68 |
+
"Black-throated Green Warbler",
|
69 |
+
"Great Horned Owl",
|
70 |
+
"Common Raven",
|
71 |
+
"Barred Owl",
|
72 |
+
"Canada Warbler",
|
73 |
+
"Magnolia Warbler",
|
74 |
+
"Black-and-white Warbler",
|
75 |
+
"Eastern Kingbird",
|
76 |
+
"Swainson's Thrush",
|
77 |
+
"Worm-eating Warbler",
|
78 |
+
"Prairie Warbler",
|
79 |
+
"Baltimore Oriole",
|
80 |
+
"Black-throated Blue Warbler",
|
81 |
+
"Louisiana Waterthrush",
|
82 |
+
"Blackburnian Warbler",
|
83 |
+
"Black-capped Chickadee",
|
84 |
+
"Cerulean Warbler",
|
85 |
+
"Red-shouldered Hawk",
|
86 |
+
"Cooper's Hawk",
|
87 |
+
"Yellow-throated Warbler",
|
88 |
+
"Blue-headed Vireo",
|
89 |
+
"Blackpoll Warbler",
|
90 |
+
"Ruffed Grouse",
|
91 |
+
"Kentucky Warbler",
|
92 |
+
"Hermit Thrush",
|
93 |
+
"Cedar Waxwing",
|
94 |
+
"Eastern Screech-Owl",
|
95 |
+
"Northern Goshawk",
|
96 |
+
"Green Heron",
|
97 |
+
"Red-tailed Hawk",
|
98 |
+
"Black Vulture",
|
99 |
+
"Hairy Woodpecker",
|
100 |
+
"Golden-crowned Kinglet",
|
101 |
+
"Ruby-crowned Kinglet",
|
102 |
+
"Bicknell's Thrush",
|
103 |
+
"Blue-gray Gnatcatcher",
|
104 |
+
"Veery",
|
105 |
+
"Pileated Woodpecker",
|
106 |
+
"Purple Finch",
|
107 |
+
"White-crowned Sparrow",
|
108 |
+
"Snow Bunting",
|
109 |
+
"Pine Grosbeak",
|
110 |
+
"American Tree Sparrow",
|
111 |
+
"Dark-eyed Junco",
|
112 |
+
"Snowy Owl",
|
113 |
+
"White-winged Crossbill",
|
114 |
+
"Red Crossbill",
|
115 |
+
"Common Redpoll",
|
116 |
+
"Northern Shrike",
|
117 |
+
"Northern Harrier",
|
118 |
+
"Rough-legged Hawk",
|
119 |
+
"Long-eared Owl",
|
120 |
+
"Evening Grosbeak",
|
121 |
+
"Northern Pintail",
|
122 |
+
"American Black Duck",
|
123 |
+
"Mallard",
|
124 |
+
"Canvasback",
|
125 |
+
"Redhead",
|
126 |
+
"Ring-necked Duck",
|
127 |
+
"Greater Scaup",
|
128 |
+
"Lesser Scaup",
|
129 |
+
"Bufflehead",
|
130 |
+
"Common Goldeneye",
|
131 |
+
"Hooded Merganser",
|
132 |
+
"Common Merganser",
|
133 |
+
"Red-breasted Merganser",
|
134 |
+
"Ruddy Duck",
|
135 |
+
"Wood Duck",
|
136 |
+
"Gadwall",
|
137 |
+
"American Wigeon",
|
138 |
+
"Northern Shoveler",
|
139 |
+
"Green-winged Teal",
|
140 |
+
"Blue-winged Teal",
|
141 |
+
"Cinnamon Teal",
|
142 |
+
"Ringed Teal",
|
143 |
+
"Cape Teal",
|
144 |
+
"Northern Fulmar",
|
145 |
+
"Yellow-billed Loon",
|
146 |
+
"Red-throated Loon",
|
147 |
+
"Arctic Loon",
|
148 |
+
"Pacific Loon",
|
149 |
+
"Horned Grebe",
|
150 |
+
"Red-necked Grebe",
|
151 |
+
"Eared Grebe",
|
152 |
+
"Western Grebe",
|
153 |
+
"Clark's Grebe",
|
154 |
+
"Double-crested Cormorant",
|
155 |
+
"Pelagic Cormorant",
|
156 |
+
"Great Cormorant",
|
157 |
+
"American White Pelican",
|
158 |
+
"Brown Pelican",
|
159 |
+
"Brandt's Cormorant",
|
160 |
+
"Least Bittern",
|
161 |
+
"Great Egret",
|
162 |
+
"Snowy Egret",
|
163 |
+
"Little Blue Heron",
|
164 |
+
"Tricolored Heron",
|
165 |
+
"Reddish Egret",
|
166 |
+
"Black-crowned Night-Heron",
|
167 |
+
"Yellow-crowned Night-Heron",
|
168 |
+
"White Ibis",
|
169 |
+
"Glossy Ibis",
|
170 |
+
"Roseate Spoonbill",
|
171 |
+
"Wood Stork",
|
172 |
+
"Black-bellied Whistling-Duck",
|
173 |
+
"Fulvous Whistling-Duck",
|
174 |
+
"Greater White-fronted Goose",
|
175 |
+
"Snow Goose",
|
176 |
+
"Ross's Goose",
|
177 |
+
"Canada Goose",
|
178 |
+
"Brant",
|
179 |
+
"Mute Swan",
|
180 |
+
"Tundra Swan",
|
181 |
+
"Whooper Swan",
|
182 |
+
"Sandhill Crane",
|
183 |
+
"Black-necked Stilt",
|
184 |
+
"American Avocet",
|
185 |
+
"Northern Jacana",
|
186 |
+
"Greater Yellowlegs",
|
187 |
+
"Lesser Yellowlegs",
|
188 |
+
"Willet",
|
189 |
+
"Spotted Sandpiper",
|
190 |
+
"Upland Sandpiper",
|
191 |
+
"Whimbrel",
|
192 |
+
"Long-billed Curlew",
|
193 |
+
"Marbled Godwit",
|
194 |
+
"Ruddy Turnstone",
|
195 |
+
"Red Knot",
|
196 |
+
"Sanderling",
|
197 |
+
"Semipalmated Sandpiper",
|
198 |
+
"Western Sandpiper",
|
199 |
+
"Least Sandpiper",
|
200 |
+
"White-rumped Sandpiper",
|
201 |
+
"Baird's Sandpiper",
|
202 |
+
"Pectoral Sandpiper",
|
203 |
+
"Dunlin",
|
204 |
+
"Buff-breasted Sandpiper",
|
205 |
+
"Short-billed Dowitcher",
|
206 |
+
"Long-billed Dowitcher",
|
207 |
+
"Common Snipe",
|
208 |
+
"American Woodcock",
|
209 |
+
"Wilson's Phalarope",
|
210 |
+
"Red-necked Phalarope",
|
211 |
+
"Red Phalarope"
|
212 |
+
]
|
213 |
+
|
214 |
+
from pathlib import Path
|
215 |
+
|
216 |
+
def remove_spaces(s):
|
217 |
+
return s.replace(" ", "")
|
218 |
+
|
219 |
+
for species in SPECIES:
|
220 |
+
if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
|
221 |
+
continue
|
222 |
+
try:
|
223 |
+
q = Query(
|
224 |
+
name=species, q="A", length="10-30",
|
225 |
+
)
|
226 |
+
|
227 |
+
# retrieve metadata
|
228 |
+
metafiles = q.retrieve_meta(verbose=True)
|
229 |
+
# retrieve recordings
|
230 |
+
q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
|
231 |
+
|
232 |
+
except:
|
233 |
+
print("Failed to download " + species)
|
234 |
+
continue
|
setup.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import find_packages
|
2 |
+
from setuptools import setup
|
3 |
+
|
4 |
+
with open("README.md") as f:
|
5 |
+
long_description = f.read()
|
6 |
+
|
7 |
+
setup(
|
8 |
+
name="vampnet",
|
9 |
+
version="0.0.1",
|
10 |
+
classifiers=[
|
11 |
+
"Intended Audience :: Developers",
|
12 |
+
"Natural Language :: English",
|
13 |
+
"Programming Language :: Python :: 3.7",
|
14 |
+
"Topic :: Artistic Software",
|
15 |
+
"Topic :: Multimedia",
|
16 |
+
"Topic :: Multimedia :: Sound/Audio",
|
17 |
+
"Topic :: Multimedia :: Sound/Audio :: Editors",
|
18 |
+
"Topic :: Software Development :: Libraries",
|
19 |
+
],
|
20 |
+
description="Generative Music Modeling.",
|
21 |
+
long_description=long_description,
|
22 |
+
long_description_content_type="text/markdown",
|
23 |
+
author="Hugo Flores García, Prem Seetharaman",
|
24 |
+
author_email="hfgacrcia@descript.com",
|
25 |
+
url="https://github.com/hugofloresgarcia/vampnet",
|
26 |
+
license="MIT",
|
27 |
+
packages=find_packages(),
|
28 |
+
install_requires=[
|
29 |
+
"torch",
|
30 |
+
"argbind>=0.3.2",
|
31 |
+
"numpy==1.23",
|
32 |
+
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
|
33 |
+
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
34 |
+
"descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git",
|
35 |
+
"gradio",
|
36 |
+
"loralib",
|
37 |
+
"torch_pitch_shift",
|
38 |
+
"plotly",
|
39 |
+
],
|
40 |
+
)
|
token_telephone/tt.py
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ttutil import hsv_to_rgb, dbg, log, set_debug, pow2db, db2pow
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
from threading import Thread
|
8 |
+
import gc
|
9 |
+
gc.disable()
|
10 |
+
|
11 |
+
import sounddevice as sd
|
12 |
+
|
13 |
+
from blessed import Terminal
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from einops import rearrange
|
18 |
+
|
19 |
+
PROFILE = False
|
20 |
+
DEBUG = False
|
21 |
+
DEBUG_NO_VAMPNET = False
|
22 |
+
set_debug(DEBUG)
|
23 |
+
# if DEBUG:
|
24 |
+
# import gc
|
25 |
+
# # log when gc start and stops
|
26 |
+
# gc.set_debug(gc.DEBUG_STATS)
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class LoadState:
|
30 |
+
t0: float = None
|
31 |
+
loaded: bool = False
|
32 |
+
|
33 |
+
load_state = LoadState()
|
34 |
+
|
35 |
+
def on_random_color():
|
36 |
+
def random_rgb_bg():
|
37 |
+
return np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255)
|
38 |
+
return term.on_color_rgb(*random_rgb_bg())
|
39 |
+
|
40 |
+
# draw the intro screen before slow imports
|
41 |
+
def color_tokenize_txt(text: str):
|
42 |
+
# apply a random bg color to each letter
|
43 |
+
return "".join(on_random_color()(letter) for letter in text)
|
44 |
+
|
45 |
+
def color_tokenize_words(text: str):
|
46 |
+
return " ".join(on_random_color()(word) for word in text.split(" "))
|
47 |
+
|
48 |
+
def draw_intro_screen():
|
49 |
+
global load_state
|
50 |
+
load_state.t0 = time.time()
|
51 |
+
avg_time = 20 # average loading time
|
52 |
+
|
53 |
+
while not load_state.loaded:
|
54 |
+
print(term.clear)
|
55 |
+
print(term.move_xy(0, 1) + term.center(color_tokenize_words("hugo flores garcía")))
|
56 |
+
print(term.move_xy(0, 3) + term.center(color_tokenize_words("and")))
|
57 |
+
print(term.move_xy(0, 5) + term.center(color_tokenize_words("stephan moore")))
|
58 |
+
print(term.move_xy(0, 7) + term.center(color_tokenize_words("present")))
|
59 |
+
print(term.move_xy(0, 9) + term.center(term.bold(color_tokenize_txt("token telephone"))))
|
60 |
+
|
61 |
+
# print(term.move_xy(0, 10) + term.center(color_tokenize_txt("loading ")), end="")
|
62 |
+
# make a little loading bar
|
63 |
+
elapsed = time.time() - load_state.t0
|
64 |
+
num_dots = int((elapsed / avg_time) * 20)
|
65 |
+
num_spaces = 20 - num_dots
|
66 |
+
print(term.move_xy(0, 12) + term.center(color_tokenize_words("loading")))
|
67 |
+
print(term.move_xy(0, 13) + term.center(color_tokenize_txt(f"[{'.' * num_dots}") + f"{' ' * num_spaces}]"))
|
68 |
+
time.sleep(0.3)
|
69 |
+
|
70 |
+
log(f"loading took {time.time() - load_state.t0} seconds")
|
71 |
+
return
|
72 |
+
|
73 |
+
# the program
|
74 |
+
term = Terminal()
|
75 |
+
|
76 |
+
# draw the intro screen on a background thread
|
77 |
+
Thread(target=draw_intro_screen).start()
|
78 |
+
|
79 |
+
# disable garbage collection
|
80 |
+
from audiotools import AudioSignal
|
81 |
+
from vamp_helper import load_interface, ez_variation
|
82 |
+
|
83 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
84 |
+
# ~~~~~~ configs! ~~~~~~~~
|
85 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
86 |
+
|
87 |
+
MAX_LOUDNESS = -20
|
88 |
+
MIN_LOUDNESS = -40
|
89 |
+
COLS = 40
|
90 |
+
ROWS = 13
|
91 |
+
|
92 |
+
device = 'Scarlett 4i4 4th Gen'
|
93 |
+
sample_rate = 48000
|
94 |
+
num_channels = 4
|
95 |
+
blocksize = 16384
|
96 |
+
|
97 |
+
|
98 |
+
# TODO:
|
99 |
+
# still some quirks to work around recording time:
|
100 |
+
# do we wanna stop recording and wait a full cycle before letting people record again?
|
101 |
+
# how do we wanna balance the volume of a new input vs what's currently gonig on?
|
102 |
+
# should people have to take turns in between new loops?
|
103 |
+
# otherwise, we're doing great i think
|
104 |
+
# we also need to add a crossfade. This means maybe cutting off the last 0.1 seconds of the loop, and the beginning 0.1
|
105 |
+
# and use that to crossfade.
|
106 |
+
|
107 |
+
# TODO: do I wanna train a diff model to swap every 2hrs or something?
|
108 |
+
# how lond does model swapping take? how can I make it faster?
|
109 |
+
|
110 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
111 |
+
# ~~~~~~ looper ~~~~~~~~
|
112 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
113 |
+
|
114 |
+
@dataclass
|
115 |
+
class State:
|
116 |
+
# looper state
|
117 |
+
feedback: float = 0.25
|
118 |
+
duration: float = 5.0
|
119 |
+
record_channel: int = 0
|
120 |
+
|
121 |
+
loopbuf: np.ndarray = None # the main loop buffer. the token telephone audio is here
|
122 |
+
looper_in: np.ndarray = None # a buffer that stores the audio that's being recorded
|
123 |
+
|
124 |
+
buf_in: np.ndarray = None # the input block with audio samples in the audio callbac
|
125 |
+
lookback_buf: np.ndarray = None # stores some lookback audio for when the threshold is passed, to propery capture transients
|
126 |
+
|
127 |
+
recording: bool = False
|
128 |
+
playing: bool = False
|
129 |
+
|
130 |
+
# ramps
|
131 |
+
record_ramp_in: bool = False
|
132 |
+
record_ramp_out: bool = False
|
133 |
+
|
134 |
+
# n_record_layers: int = 2 # number of times we'll record over before clearing
|
135 |
+
# cur_rec_layer: int = 0
|
136 |
+
recording_locked: bool = False
|
137 |
+
|
138 |
+
rec_time: float = 0
|
139 |
+
cur_hold_time: float = None
|
140 |
+
pos: int = 0
|
141 |
+
rms_db: float = float("-inf")
|
142 |
+
|
143 |
+
trig_threshold_db = -25 # a more sane default is -20
|
144 |
+
hold_seconds = 1.0
|
145 |
+
rel_threshold_db = -40 # a more sane default is -30
|
146 |
+
|
147 |
+
status: str = field(default=None)
|
148 |
+
|
149 |
+
# token telephone configs
|
150 |
+
z_buf: torch.Tensor = None
|
151 |
+
input_ready = False
|
152 |
+
input_channel = 0
|
153 |
+
token_telephone_processing: bool = False
|
154 |
+
num_telephone_chans = 4
|
155 |
+
tt_cur_ch = 0
|
156 |
+
|
157 |
+
def __post_init__(self):
|
158 |
+
self.loopbuf = np.zeros((num_channels, int(self.duration * sample_rate)))
|
159 |
+
self.looper_in = np.zeros((1, int(self.duration * sample_rate)))
|
160 |
+
|
161 |
+
# hold 200ms of lookback to account for rising attacks.
|
162 |
+
num_lookback_samples = max(int(sample_rate * 0.2), int(blocksize))
|
163 |
+
log(f"num_lookback_samples {num_lookback_samples} ({num_lookback_samples / sample_rate} seconds)")
|
164 |
+
self.lookback_buf = np.zeros((1, num_lookback_samples))
|
165 |
+
|
166 |
+
self.buf_in = np.zeros((num_channels, blocksize))
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
def check_if_record(st: State, ain: np.ndarray, on_release_callback=None):
|
171 |
+
# get our rms value
|
172 |
+
rms = pow2db(np.sqrt(np.mean(ain**2)))
|
173 |
+
st.rms_db = rms
|
174 |
+
|
175 |
+
# determine if we should ater the looper state
|
176 |
+
# if we werent recording and we cross the trigger threshold
|
177 |
+
# start recording
|
178 |
+
# if not st.recording and rms > st.trig_threshold_db and not st.recording_locked:
|
179 |
+
if not st.recording and rms > st.trig_threshold_db and not st.recording_locked:
|
180 |
+
st.recording = True
|
181 |
+
st.record_ramp_in = True
|
182 |
+
|
183 |
+
# if we were recording and we cross the release threshold
|
184 |
+
# begin the hold period
|
185 |
+
if (st.recording and rms < st.rel_threshold_db) or st.rec_time > (st.duration-st.hold_seconds):
|
186 |
+
# if we dont have a hold time, set it
|
187 |
+
if st.cur_hold_time is None:
|
188 |
+
st.cur_hold_time = time.time()
|
189 |
+
|
190 |
+
# release if we have a hold time and we've held for the required time,
|
191 |
+
if (time.time() - st.cur_hold_time) > st.hold_seconds:
|
192 |
+
st.record_ramp_out = True
|
193 |
+
st.rec_time = 0
|
194 |
+
if on_release_callback is not None:
|
195 |
+
st.input_ready = True
|
196 |
+
on_release_callback(st)
|
197 |
+
st.cur_hold_time = None
|
198 |
+
else:
|
199 |
+
pass
|
200 |
+
else:
|
201 |
+
st.cur_hold_time = None
|
202 |
+
|
203 |
+
|
204 |
+
def launch_token_telephone(st: State):
|
205 |
+
if interface is None:
|
206 |
+
log("no interface loaded, can't do token telephone!")
|
207 |
+
time.sleep(10)
|
208 |
+
return
|
209 |
+
|
210 |
+
# if we're already processing, do nothing
|
211 |
+
if st.token_telephone_processing:
|
212 |
+
return
|
213 |
+
else:
|
214 |
+
log("starting token telephone!")
|
215 |
+
Thread(target=do_token_telephone, args=(st,)).start()
|
216 |
+
|
217 |
+
|
218 |
+
def do_token_telephone(st: State,):
|
219 |
+
st.token_telephone_processing = True
|
220 |
+
while True:
|
221 |
+
lrc = st.record_channel
|
222 |
+
t0 = time.time()
|
223 |
+
cur_ch = st.tt_cur_ch
|
224 |
+
|
225 |
+
# if there was input ready, start back from the top.
|
226 |
+
if st.input_ready:
|
227 |
+
log(f"there was input ready, processing!")
|
228 |
+
# NOTE: hugo, trying something new here. what happens if
|
229 |
+
# we don't reset the channel when input is ready,
|
230 |
+
# and instead let it come in anywhere in the cycle?
|
231 |
+
# st.tt_cur_ch = 0 # uncomment to go back to reality
|
232 |
+
|
233 |
+
# clear the lrc, reset for next record.
|
234 |
+
st.input_ready = False
|
235 |
+
|
236 |
+
# reocrd the channel that we'll be processing in and lock recording
|
237 |
+
st.input_channel = cur_ch
|
238 |
+
st.recording_locked = True
|
239 |
+
|
240 |
+
# first, let's preprocess looper in
|
241 |
+
sig_looper_in = AudioSignal(
|
242 |
+
torch.from_numpy(st.looper_in).unsqueeze(0),
|
243 |
+
sample_rate=sample_rate
|
244 |
+
)
|
245 |
+
sig_loopbuf_curch = AudioSignal(
|
246 |
+
torch.from_numpy(st.loopbuf[cur_ch:cur_ch+1]).unsqueeze(0),
|
247 |
+
sample_rate=sample_rate
|
248 |
+
)
|
249 |
+
# make sure looperin matches the midpoint in loudness
|
250 |
+
ldns_mid = max(sig_loopbuf_curch.loudness(), sig_looper_in.loudness())
|
251 |
+
sig_looper_in = sig_looper_in.normalize(ldns_mid)
|
252 |
+
st.looper_in = sig_looper_in.samples.cpu().numpy().squeeze(0)
|
253 |
+
|
254 |
+
st.loopbuf[cur_ch:cur_ch + 1] = (
|
255 |
+
st.looper_in + st.loopbuf[cur_ch:cur_ch+1] * st.feedback
|
256 |
+
)
|
257 |
+
# also lower the volumes of the other channels
|
258 |
+
for i in range(4):
|
259 |
+
if i != cur_ch:
|
260 |
+
st.loopbuf[i:i+1] = st.loopbuf[i:i+1] * 0.5 # -3dB
|
261 |
+
|
262 |
+
st.looper_in = np.zeros_like(st.looper_in)
|
263 |
+
|
264 |
+
loop_input = st.loopbuf[cur_ch:cur_ch+1]
|
265 |
+
|
266 |
+
# ~~~ VAMPNET STUFF ~~~~
|
267 |
+
sig = AudioSignal(
|
268 |
+
torch.from_numpy(loop_input).unsqueeze(0),
|
269 |
+
sample_rate=sample_rate
|
270 |
+
)
|
271 |
+
input_loudness = sig.loudness()
|
272 |
+
log(f"INPUT loudness {input_loudness}")
|
273 |
+
if input_loudness > MAX_LOUDNESS:
|
274 |
+
log(f"input loudness {input_loudness} is over {MAX_LOUDNESS}!")
|
275 |
+
sig = sig.normalize(MAX_LOUDNESS)
|
276 |
+
elif input_loudness < MIN_LOUDNESS:
|
277 |
+
log(f"input loudness {input_loudness} is under {MIN_LOUDNESS}!")
|
278 |
+
sig = sig.normalize(MIN_LOUDNESS)
|
279 |
+
|
280 |
+
sig = ez_variation(interface, sig)
|
281 |
+
sig = sig.resample(sample_rate)
|
282 |
+
|
283 |
+
# notify if we've gone over the loudness
|
284 |
+
sig = sig.normalize(input_loudness)
|
285 |
+
outloudness = sig.loudness()
|
286 |
+
if outloudness > MAX_LOUDNESS:
|
287 |
+
log(f"out loudness {sig.loudness()} is over {MAX_LOUDNESS}!")
|
288 |
+
sig = sig.normalize(MAX_LOUDNESS)
|
289 |
+
elif outloudness < MIN_LOUDNESS:
|
290 |
+
log(f"out loudness {sig.loudness()} is under {MIN_LOUDNESS}!")
|
291 |
+
sig = sig.normalize(MIN_LOUDNESS)
|
292 |
+
|
293 |
+
# put it back in the loopbuf
|
294 |
+
# write to the next channel
|
295 |
+
# (TODO: instead of trimming to loopbuf.shape[1], maybe we can just have the loopbuf be the right size from init time.)
|
296 |
+
cur_ch = (cur_ch + 1) % st.num_telephone_chans
|
297 |
+
st.tt_cur_ch = cur_ch
|
298 |
+
if False: # HUGO: is there a time where we want feedback?
|
299 |
+
st.loopbuf[cur_ch:cur_ch+1] = (
|
300 |
+
sig.samples.cpu().numpy().squeeze(0)[:, :st.loopbuf.shape[1]]
|
301 |
+
+ st.feedback * st.loopbuf[cur_ch:cur_ch+1]
|
302 |
+
)
|
303 |
+
else:
|
304 |
+
st.loopbuf[cur_ch:cur_ch+1] = (
|
305 |
+
sig.samples.cpu().numpy().squeeze(0)[:, :st.loopbuf.shape[1]]
|
306 |
+
)
|
307 |
+
|
308 |
+
log(f"output loudness {sig.loudness()}")
|
309 |
+
log(f"telephone loop took {time.time() - t0} seconds... next channel {cur_ch}\n\n")
|
310 |
+
|
311 |
+
# if we've made it back to the input channel, we can unlock the recording
|
312 |
+
log(f"cur_ch {cur_ch} input_channel {st.input_channel}")
|
313 |
+
if cur_ch == st.input_channel:
|
314 |
+
st.recording_locked = False
|
315 |
+
log(f"recording unlocked!")
|
316 |
+
|
317 |
+
|
318 |
+
# unlock the recording if we've successfully written to all channels
|
319 |
+
# if st.recording_locked and cur_ch == 0:
|
320 |
+
# st.recording_locked = False
|
321 |
+
# log(f"recording locked {st.recording_locked}")
|
322 |
+
|
323 |
+
st.token_telephone_processing = False
|
324 |
+
return
|
325 |
+
|
326 |
+
# TODO: since we're using this really high threshold
|
327 |
+
# we always need to record about 100ms in advance, to catch the beginning of the attacks.
|
328 |
+
|
329 |
+
def looper_process_block(st, block: np.ndarray):
|
330 |
+
lrc = st.record_channel
|
331 |
+
|
332 |
+
# treat the lookback buffer as a circular buffer
|
333 |
+
st.lookback_buf = np.roll(st.lookback_buf, block.shape[1], axis=1)
|
334 |
+
st.lookback_buf[:, -block.shape[1]:] = block[lrc:lrc+1, :]
|
335 |
+
|
336 |
+
|
337 |
+
# check if we need to record.
|
338 |
+
if st.recording:
|
339 |
+
start_i = (st.pos + block.shape[1]) - st.lookback_buf.shape[1]
|
340 |
+
end_i = st.pos + st.lookback_buf.shape[1]
|
341 |
+
|
342 |
+
indices = np.take(
|
343 |
+
np.arange(st.loopbuf.shape[1]),
|
344 |
+
np.arange(start_i, end_i),
|
345 |
+
mode="wrap"
|
346 |
+
)
|
347 |
+
_audio_in = st.lookback_buf[:, :]
|
348 |
+
# ramp in if we need to
|
349 |
+
if st.record_ramp_in:
|
350 |
+
_audio_in = _audio_in * np.linspace(0, 1, _audio_in.shape[1])
|
351 |
+
st.record_ramp_in=False
|
352 |
+
|
353 |
+
if st.record_ramp_out:
|
354 |
+
_audio_in = _audio_in * np.linspace(1, 0, _audio_in.shape[1])
|
355 |
+
st.record_ramp_out=False
|
356 |
+
st.recording = False
|
357 |
+
|
358 |
+
st.looper_in[:, indices] = (
|
359 |
+
0.9 * st.looper_in[:, indices] + _audio_in
|
360 |
+
)
|
361 |
+
|
362 |
+
# incremement the recording time
|
363 |
+
st.rec_time += st.lookback_buf.shape[1] / sample_rate
|
364 |
+
|
365 |
+
# check if we need to play
|
366 |
+
crossfade_samples = int(0.1 * sample_rate)
|
367 |
+
if st.playing:
|
368 |
+
play_pos = (st.pos + block.shape[1]) % st.loopbuf.shape[1] # read one buffer ahead
|
369 |
+
indices = np.arange(play_pos, play_pos + block.shape[1])
|
370 |
+
block = st.loopbuf.take(indices, axis=1, mode="wrap")[:, :] # this doesn't have any crossfading. # TODO: this is still not working!
|
371 |
+
|
372 |
+
# if we've recorded more than the loop size
|
373 |
+
if st.rec_time > st.duration and st.recording:
|
374 |
+
# play the loop
|
375 |
+
play_pos = st.pos + block.shape[1] # read one buffer ahead
|
376 |
+
indices = np.arange(play_pos, play_pos + block.shape[1])
|
377 |
+
|
378 |
+
block[lrc:lrc] = st.looper_in.take(indices, axis=1, mode="wrap")[:, :]
|
379 |
+
|
380 |
+
# advance looper state
|
381 |
+
st.pos = (st.pos + block.shape[1]) % st.loopbuf.shape[1]
|
382 |
+
|
383 |
+
return block
|
384 |
+
|
385 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
386 |
+
# ~~~~~~ drawing ~~~~~~~~
|
387 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
388 |
+
|
389 |
+
def draw_rms_bar(st, x, y, width, height):
|
390 |
+
rms_min = -50
|
391 |
+
rms_max = -10
|
392 |
+
rms = st.rms_db
|
393 |
+
rms = max(rms, rms_min)
|
394 |
+
threshold = st.trig_threshold_db
|
395 |
+
rel_threshold = st.rel_threshold_db
|
396 |
+
|
397 |
+
rms_block = int((rms - rms_min) / (rms_max - rms_min) * height)
|
398 |
+
threshold_block = (threshold - rms_min) / (rms_max - rms_min) * height
|
399 |
+
rel_threshold_block = (rel_threshold - rms_min) / (rms_max - rms_min) * height
|
400 |
+
|
401 |
+
# draw the rms curve
|
402 |
+
for i in range(rms_block, height+4):
|
403 |
+
with term.location(x+4, y+height-i):
|
404 |
+
print(term.clear_bol)
|
405 |
+
for i in range(rms_block):
|
406 |
+
rms_val = i * (rms_max - rms_min) / height + rms_min
|
407 |
+
with term.location(x, y+height-2-i):
|
408 |
+
if i < threshold_block:
|
409 |
+
print(" " + term.on_green(f"*"))
|
410 |
+
else:
|
411 |
+
print(" " + term.on_red(f"*"))
|
412 |
+
|
413 |
+
# at the very bottom of the bar, draw the rms value
|
414 |
+
with term.location(x, y+height-1):
|
415 |
+
print(f"{rms:.1f}dB")
|
416 |
+
# print(f" rms")
|
417 |
+
|
418 |
+
|
419 |
+
def draw_looper(st):
|
420 |
+
x = 0
|
421 |
+
y = 0
|
422 |
+
width = COLS
|
423 |
+
height = ROWS
|
424 |
+
|
425 |
+
tt_refresh_every = 0.3
|
426 |
+
if not hasattr(draw_looper, "last_draw"):
|
427 |
+
draw_looper.last_draw = 0
|
428 |
+
should_draw = True
|
429 |
+
else:
|
430 |
+
should_draw = (time.time() - draw_looper.last_draw) > tt_refresh_every
|
431 |
+
if should_draw:
|
432 |
+
draw_looper.last_draw = time.time()
|
433 |
+
|
434 |
+
|
435 |
+
draw_rms_bar(st, x, y, width - 10, height)
|
436 |
+
|
437 |
+
if should_draw:
|
438 |
+
with term.location(width // 2-4, 1):
|
439 |
+
for i, letter in enumerate("token telephone"):
|
440 |
+
print(on_random_color()(letter), end="")
|
441 |
+
|
442 |
+
# with term.location(ROWS-2, COLS // 2):
|
443 |
+
# print(f"status {st.status}!!!")
|
444 |
+
|
445 |
+
|
446 |
+
# if we're recording, draw a red unlderlined "rec" sign on the bottom right
|
447 |
+
# with term.location(width-8, height-1):
|
448 |
+
# if st.recording:
|
449 |
+
# print(term.on_red("rec"))
|
450 |
+
# else:
|
451 |
+
# print(term.on_gray50("rec"))
|
452 |
+
|
453 |
+
# # if we're playing draw a green underline "play" sign on the bottom right
|
454 |
+
# with term.location(width-4, height-1):
|
455 |
+
# if st.playing:
|
456 |
+
# print(term.on_green("play"))
|
457 |
+
# else:
|
458 |
+
# print(term.on_gray50("play"))
|
459 |
+
|
460 |
+
|
461 |
+
# draw the timeline at the bottom using ---
|
462 |
+
with term.location(6, height):
|
463 |
+
timeline = ["-"] * (width - 12)
|
464 |
+
playhead = int((st.pos / st.loopbuf.shape[1]) * (width - 12))
|
465 |
+
timeline[playhead] = "v"
|
466 |
+
print("|"+"".join(timeline) + "|")
|
467 |
+
|
468 |
+
|
469 |
+
# draw the main message at the very center:
|
470 |
+
msg_loc = (width // 2, height // 2+1)
|
471 |
+
_x, _y = msg_loc
|
472 |
+
if not st.recording:
|
473 |
+
if not st.recording_locked:
|
474 |
+
print(term.move_xy(0, _y-1) + term.center("make a sound", width=width+5))
|
475 |
+
print(term.move_xy(0, _y+0) + term.center("to", width=width+5))
|
476 |
+
print(term.move_xy(0, _y+1) + term.center("record", width=width+5))
|
477 |
+
else:
|
478 |
+
# how many seconds left until we can record again?
|
479 |
+
# how many more chs do we need to go through before we can record again?
|
480 |
+
if st.tt_cur_ch < st.input_channel:
|
481 |
+
chs_remaining = st.input_channel - st.tt_cur_ch
|
482 |
+
else:
|
483 |
+
chs_remaining = 4-st.tt_cur_ch + st.input_channel
|
484 |
+
locked_time_remaining = chs_remaining * st.duration + st.duration - (st.pos / sample_rate)
|
485 |
+
print(term.move_xy(0, _y-1) + term.center("please wait", width=width+5))
|
486 |
+
print(term.move_xy(0, _y+0) + term.center(term.on_green(f"{locked_time_remaining:.1f}s"), width=width+5))
|
487 |
+
print(term.move_xy(0, _y+1) + term.center("for your turn :)", width=width+5))
|
488 |
+
else:
|
489 |
+
print(term.move_xy(0, _y-1) + term.center(term.on_red("recording"), width=width+5))
|
490 |
+
print(term.move_xy(0, _y+0) + term.center(f"{(st.duration) - st.rec_time:.1f}s left", width=width+5))
|
491 |
+
print(term.move_xy(0, _y+1) + term.center("", width=width+5))
|
492 |
+
|
493 |
+
|
494 |
+
# we'll draw channel 0 (1) on the bottom right corner
|
495 |
+
# channel 1 (2) on the top right corner
|
496 |
+
# channel 2 (3) on the top left corner
|
497 |
+
# channel 3 (4) on the bottom left corner
|
498 |
+
my = 3 # margin
|
499 |
+
mx = 10
|
500 |
+
locations = {
|
501 |
+
1: (width - mx, height - my),
|
502 |
+
2: (width - mx, 1+my),
|
503 |
+
3: (mx, 1+my),
|
504 |
+
4: (mx, height - my),
|
505 |
+
}
|
506 |
+
for i in range(1, 5):
|
507 |
+
if should_draw:
|
508 |
+
if st.tt_cur_ch == i - 1 and st.token_telephone_processing:
|
509 |
+
x, y = locations[i]
|
510 |
+
on_random_colors = lambda n: "".join(on_random_color()(" ") for _ in range(n))
|
511 |
+
print(term.move_xy(x, y-1) + on_random_colors(5))
|
512 |
+
print(term.move_xy(x, y) + on_random_color()(" ") + f" {i} " + on_random_color()(" "))
|
513 |
+
print(term.move_xy(x, y+1) + on_random_colors(5))
|
514 |
+
else:
|
515 |
+
# same thing, but a gray instead of random colors
|
516 |
+
x, y = locations[i]
|
517 |
+
on_gray_colors = lambda n: "".join(term.on_gray50(" ") for _ in range(n))
|
518 |
+
print(term.move_xy(x, y-1) + on_gray_colors(5))
|
519 |
+
print(term.move_xy(x, y) + term.on_gray50(" ") + f" {i} " + term.on_gray50(" "))
|
520 |
+
print(term.move_xy(x, y+1) + on_gray_colors(5))
|
521 |
+
|
522 |
+
|
523 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
524 |
+
# ~~~~~~ live audio ~~~~~~~~
|
525 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
526 |
+
def audio_init():
|
527 |
+
sd.default.samplerate = sample_rate
|
528 |
+
sd.default.device = device
|
529 |
+
|
530 |
+
# ~~~~~~ the main audio callback ~~~~~~~~~
|
531 |
+
def callback(st, indata, outdata, frames, _time, status):
|
532 |
+
t0 = time.time()
|
533 |
+
lrc = st.record_channel
|
534 |
+
|
535 |
+
if status:
|
536 |
+
log(f"status is {status}")
|
537 |
+
st.status = status
|
538 |
+
|
539 |
+
# log dtype, status, frames, time, max min
|
540 |
+
# log(f"indata {indata.dtype} max {indata.max()} min {indata.min()} {status} {frames} {_time}")
|
541 |
+
|
542 |
+
|
543 |
+
ain = rearrange(indata, 't n -> n t', n=num_channels)
|
544 |
+
|
545 |
+
# convert audio to from int32 to float32
|
546 |
+
ain = ain.astype(np.float32) / np.iinfo(np.int16).max
|
547 |
+
buf_in = ain
|
548 |
+
|
549 |
+
# if it's all zeros, we're not recording
|
550 |
+
# so we can just pass it through
|
551 |
+
if np.all(buf_in == 0):
|
552 |
+
st.status = st.status + "no input"
|
553 |
+
return
|
554 |
+
|
555 |
+
st.buf_in = buf_in
|
556 |
+
check_if_record(
|
557 |
+
st, buf_in,
|
558 |
+
on_release_callback=launch_token_telephone
|
559 |
+
)
|
560 |
+
buf_in = looper_process_block(st, buf_in)
|
561 |
+
|
562 |
+
# pass our st.loopbuf to the output
|
563 |
+
ain = buf_in
|
564 |
+
|
565 |
+
# convert back to int32
|
566 |
+
ain = (ain * np.iinfo(np.int16).max).astype(np.int16)
|
567 |
+
|
568 |
+
outdata[:] = rearrange(ain, 'n t -> t n')
|
569 |
+
|
570 |
+
# log(f"outdata {outdata.dtype} max {outdata.max()} min {outdata.min()} --- took {time.time() - t0} seconds")
|
571 |
+
|
572 |
+
|
573 |
+
|
574 |
+
if DEBUG_NO_VAMPNET:
|
575 |
+
interface=None
|
576 |
+
else:
|
577 |
+
interface = load_interface(model_choice="opera")
|
578 |
+
|
579 |
+
load_state.loaded = True
|
580 |
+
|
581 |
+
def main():
|
582 |
+
if PROFILE:
|
583 |
+
import yappi
|
584 |
+
yappi.start()
|
585 |
+
|
586 |
+
try:
|
587 |
+
audio_init()
|
588 |
+
st = State()
|
589 |
+
st.playing = True
|
590 |
+
|
591 |
+
from functools import partial
|
592 |
+
cb = partial(callback, st)
|
593 |
+
|
594 |
+
with term.fullscreen(), term.cbreak():
|
595 |
+
with sd.Stream(channels=num_channels, callback=cb, blocksize=blocksize, prime_output_buffers_using_stream_callback=True, dtype=np.int16):
|
596 |
+
while True:
|
597 |
+
with term.hidden_cursor():
|
598 |
+
if DEBUG:
|
599 |
+
time.sleep(100)
|
600 |
+
else:
|
601 |
+
draw_looper(st)
|
602 |
+
|
603 |
+
except KeyboardInterrupt:
|
604 |
+
print(term.clear)
|
605 |
+
if PROFILE:
|
606 |
+
yappi.stop()
|
607 |
+
|
608 |
+
# retrieve thread stats by their thread id (given by yappi)
|
609 |
+
threads = yappi.get_thread_stats()
|
610 |
+
for thread in threads:
|
611 |
+
print(
|
612 |
+
"Function stats for (%s) (%d)" % (thread.name, thread.id)
|
613 |
+
) # it is the Thread.__class__.__name__
|
614 |
+
yappi.get_func_stats(ctx_id=thread.id).print_all()
|
615 |
+
|
616 |
+
main()
|
token_telephone/ttutil.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
ROOT = Path(__file__).parent
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from queue import Queue
|
7 |
+
|
8 |
+
# make a log file!!
|
9 |
+
logfile= ROOT / "log.txt"
|
10 |
+
if logfile.exists():
|
11 |
+
logfile.unlink()
|
12 |
+
logging.basicConfig(filename=logfile, level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S", format="%(asctime)s | %(levelname)s | %(message)s")
|
13 |
+
|
14 |
+
|
15 |
+
def hsv_to_rgb(h, s, v):
|
16 |
+
# from https://en.wikipedia.org/wiki/HSL_and_HSV#From_HSV
|
17 |
+
c = v * s
|
18 |
+
h_ = h / 60
|
19 |
+
x = c * (1 - abs(h_ % 2 - 1))
|
20 |
+
m = v - c
|
21 |
+
|
22 |
+
if h_ < 1:
|
23 |
+
r, g, b = c, x, 0
|
24 |
+
elif h_ < 2:
|
25 |
+
r, g, b = x, c, 0
|
26 |
+
elif h_ < 3:
|
27 |
+
r, g, b = 0, c, x
|
28 |
+
elif h_ < 4:
|
29 |
+
r, g, b = 0, x, c
|
30 |
+
elif h_ < 5:
|
31 |
+
r, g, b = x, 0, c
|
32 |
+
else:
|
33 |
+
r, g, b = c, 0, x
|
34 |
+
|
35 |
+
return r + m, g + m, b + m
|
36 |
+
|
37 |
+
|
38 |
+
def dbg(*args):
|
39 |
+
print(" ".join(map(str, args)))
|
40 |
+
|
41 |
+
|
42 |
+
# we'll want to log on a separate thread
|
43 |
+
# so that we can log without blocking the main thread
|
44 |
+
|
45 |
+
# make a queue for logging
|
46 |
+
log_queue = Queue()
|
47 |
+
|
48 |
+
# log to a file instead of the console
|
49 |
+
def log(msg):
|
50 |
+
# log_queue.put(msg)
|
51 |
+
logging.info(msg)
|
52 |
+
pass
|
53 |
+
|
54 |
+
def set_debug(debug):
|
55 |
+
if debug:
|
56 |
+
# print log to console
|
57 |
+
logging.getLogger().addHandler(logging.StreamHandler())
|
58 |
+
|
59 |
+
|
60 |
+
def pow2db(x):
|
61 |
+
return 10 * np.log10(x + 1e-6)
|
62 |
+
|
63 |
+
|
64 |
+
def db2pow(x):
|
65 |
+
return 10 ** (x / 10)
|
token_telephone/vamp_helper.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import time
|
3 |
+
import os
|
4 |
+
from contextlib import contextmanager
|
5 |
+
import random
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import audiotools as at
|
9 |
+
from audiotools import AudioSignal
|
10 |
+
import argbind
|
11 |
+
import shutil
|
12 |
+
import torch
|
13 |
+
import yaml
|
14 |
+
|
15 |
+
|
16 |
+
from vampnet.interface import Interface, signal_concat
|
17 |
+
from vampnet import mask as pmask
|
18 |
+
|
19 |
+
from ttutil import log
|
20 |
+
|
21 |
+
# TODO: incorporate discord bot (if mem allows)
|
22 |
+
# in a separate thread, send audio samples for listening
|
23 |
+
# and send back the results
|
24 |
+
# as well as the params for sampling
|
25 |
+
# also a command that lets you clear the current signal
|
26 |
+
# if you want to start over
|
27 |
+
|
28 |
+
|
29 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
30 |
+
|
31 |
+
VAMPNET_DIR = Path(".").resolve()
|
32 |
+
|
33 |
+
@contextmanager
|
34 |
+
def chdir(path):
|
35 |
+
old_dir = os.getcwd()
|
36 |
+
os.chdir(path)
|
37 |
+
try:
|
38 |
+
yield
|
39 |
+
finally:
|
40 |
+
os.chdir(old_dir)
|
41 |
+
|
42 |
+
def load_interface(model_choice="default") -> Interface:
|
43 |
+
with chdir(VAMPNET_DIR):
|
44 |
+
|
45 |
+
|
46 |
+
# populate the model choices with any interface.yml files in the generated confs
|
47 |
+
MODEL_CHOICES = {
|
48 |
+
"default": {
|
49 |
+
"Interface.coarse_ckpt": "models/vampnet/coarse.pth",
|
50 |
+
"Interface.coarse2fine_ckpt": "models/vampnet/c2f.pth",
|
51 |
+
"Interface.codec_ckpt": "models/vampnet/codec.pth",
|
52 |
+
}
|
53 |
+
}
|
54 |
+
generated_confs = Path("conf/generated")
|
55 |
+
for conf_file in generated_confs.glob("*/interface.yml"):
|
56 |
+
with open(conf_file) as f:
|
57 |
+
_conf = yaml.safe_load(f)
|
58 |
+
|
59 |
+
# check if the coarse, c2f, and codec ckpts exist
|
60 |
+
# otherwise, dont' add this model choice
|
61 |
+
if not (
|
62 |
+
Path(_conf["Interface.coarse_ckpt"]).exists() and
|
63 |
+
Path(_conf["Interface.coarse2fine_ckpt"]).exists() and
|
64 |
+
Path(_conf["Interface.codec_ckpt"]).exists()
|
65 |
+
):
|
66 |
+
continue
|
67 |
+
|
68 |
+
MODEL_CHOICES[conf_file.parent.name] = _conf
|
69 |
+
|
70 |
+
interface = Interface(
|
71 |
+
device=device,
|
72 |
+
coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
|
73 |
+
coarse2fine_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
|
74 |
+
codec_ckpt=MODEL_CHOICES[model_choice]["Interface.codec_ckpt"],
|
75 |
+
)
|
76 |
+
|
77 |
+
interface.model_choices = MODEL_CHOICES
|
78 |
+
interface.to("cuda" if torch.cuda.is_available() else "cpu")
|
79 |
+
return interface
|
80 |
+
|
81 |
+
def load_model(interface: Interface, model_choice: str):
|
82 |
+
interface.reload(
|
83 |
+
interface.model_choices[model_choice]["Interface.coarse_ckpt"],
|
84 |
+
interface.model_choices[model_choice]["Interface.coarse2fine_ckpt"],
|
85 |
+
)
|
86 |
+
|
87 |
+
def ez_variation(
|
88 |
+
interface,
|
89 |
+
sig: AudioSignal,
|
90 |
+
seed: int = None,
|
91 |
+
model_choice: str = None,
|
92 |
+
):
|
93 |
+
t0 = time.time()
|
94 |
+
|
95 |
+
if seed is None:
|
96 |
+
seed = int(torch.randint(0, 2**32, (1,)).item())
|
97 |
+
at.util.seed(seed)
|
98 |
+
|
99 |
+
# reload the model if necessary
|
100 |
+
if model_choice is not None:
|
101 |
+
load_model(interface, model_choice)
|
102 |
+
|
103 |
+
# SAMPLING MASK PARAMS, hard code for now, we'll prob want a more preset-ey thing for the actual thin
|
104 |
+
# we probably honestly just want to oscillate between the same 4 presets
|
105 |
+
# in a predictable order such that they have a predictable outcome
|
106 |
+
periodic_p = random.choice([3])
|
107 |
+
n_mask_codebooks = 3
|
108 |
+
sampletemp = random.choice([1.0,])
|
109 |
+
dropout = random.choice([0.0, 0.0])
|
110 |
+
|
111 |
+
top_p = None # NOTE: top p may be the culprit behind the collapse into single pitches.
|
112 |
+
|
113 |
+
# parameters for the build_mask function
|
114 |
+
build_mask_kwargs = dict(
|
115 |
+
rand_mask_intensity=1.0,
|
116 |
+
prefix_s=0.0,
|
117 |
+
suffix_s=0.0,
|
118 |
+
periodic_prompt=int(periodic_p),
|
119 |
+
periodic_prompt2=int(periodic_p),
|
120 |
+
periodic_prompt_width=1,
|
121 |
+
_dropout=dropout,
|
122 |
+
upper_codebook_mask=int(n_mask_codebooks),
|
123 |
+
upper_codebook_mask_2=int(n_mask_codebooks),
|
124 |
+
)
|
125 |
+
|
126 |
+
# parameters for the vamp function
|
127 |
+
vamp_kwargs = dict(
|
128 |
+
temperature=sampletemp,
|
129 |
+
typical_filtering=True,
|
130 |
+
typical_mass=0.15,
|
131 |
+
typical_min_tokens=64,
|
132 |
+
top_p=top_p,
|
133 |
+
seed=seed,
|
134 |
+
sample_cutoff=1.0,
|
135 |
+
)
|
136 |
+
|
137 |
+
# save the mask as a txt file
|
138 |
+
interface.set_chunk_size(10.0)
|
139 |
+
sig, mask, codes = interface.ez_vamp(
|
140 |
+
sig,
|
141 |
+
batch_size=1,
|
142 |
+
feedback_steps=1,
|
143 |
+
time_stretch_factor=1,
|
144 |
+
build_mask_kwargs=build_mask_kwargs,
|
145 |
+
vamp_kwargs=vamp_kwargs,
|
146 |
+
return_mask=True,
|
147 |
+
)
|
148 |
+
|
149 |
+
log(f"vamp took {time.time() - t0} seconds")
|
150 |
+
return sig
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
def main():
|
155 |
+
import tqdm
|
156 |
+
|
157 |
+
interface = load_interface()
|
158 |
+
sig = AudioSignal.excerpt("assets/example.wav", duration=7.0)
|
159 |
+
sig = interface.preprocess(sig)
|
160 |
+
sig.write('ttout/in.wav')
|
161 |
+
insig = sig.clone()
|
162 |
+
|
163 |
+
fdbk_every = 4
|
164 |
+
fdbk = 0.5
|
165 |
+
|
166 |
+
for i in tqdm.tqdm(range(1000)):
|
167 |
+
sig = ez_variation(interface, sig, model_choice="orchestral")
|
168 |
+
sig.write(f'ttout/out{i}.wav')
|
169 |
+
|
170 |
+
|
171 |
+
if __name__ == "__main__":
|
172 |
+
main()
|
vampnet/__init__.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from . import modules
|
3 |
+
from pathlib import Path
|
4 |
+
from . import scheduler
|
5 |
+
from .interface import Interface
|
6 |
+
from .modules.transformer import VampNet
|
7 |
+
|
8 |
+
|
9 |
+
__version__ = "0.0.1"
|
10 |
+
|
11 |
+
ROOT = Path(__file__).parent.parent
|
12 |
+
MODELS_DIR = ROOT / "models" / "vampnet"
|
13 |
+
|
14 |
+
from huggingface_hub import hf_hub_download, HfFileSystem
|
15 |
+
DEFAULT_HF_MODEL_REPO = "hugggof/vampnet"
|
16 |
+
FS = HfFileSystem()
|
17 |
+
|
18 |
+
def download_codec():
|
19 |
+
# from dac.model.dac import DAC
|
20 |
+
from lac.model.lac import LAC as DAC
|
21 |
+
repo_id = DEFAULT_HF_MODEL_REPO
|
22 |
+
filename = "codec.pth"
|
23 |
+
codec_path = hf_hub_download(
|
24 |
+
repo_id=repo_id,
|
25 |
+
filename=filename,
|
26 |
+
subfolder=None,
|
27 |
+
local_dir=MODELS_DIR
|
28 |
+
)
|
29 |
+
return codec_path
|
30 |
+
|
31 |
+
|
32 |
+
def download_default():
|
33 |
+
filenames = ["coarse.pth", "c2f.pth"]
|
34 |
+
repo_id = DEFAULT_HF_MODEL_REPO
|
35 |
+
paths = []
|
36 |
+
for filename in filenames:
|
37 |
+
path = f"{MODELS_DIR}/{filename}"
|
38 |
+
if not Path(path).exists():
|
39 |
+
path = hf_hub_download(
|
40 |
+
repo_id=repo_id,
|
41 |
+
filename=filename,
|
42 |
+
subfolder=None,
|
43 |
+
local_dir=MODELS_DIR,
|
44 |
+
local_dir_use_symlinks=False,
|
45 |
+
local_files_only=False
|
46 |
+
)
|
47 |
+
paths.append(path)
|
48 |
+
|
49 |
+
# load the models
|
50 |
+
return paths[0], paths[1]
|
51 |
+
|
52 |
+
|
53 |
+
def download_finetuned(name):
|
54 |
+
repo_id = f"{DEFAULT_HF_MODEL_REPO}"
|
55 |
+
filenames = ["coarse.pth", "c2f.pth"]
|
56 |
+
paths = []
|
57 |
+
for filename in filenames:
|
58 |
+
path = f"{MODELS_DIR}/{name}/loras/{filename}"
|
59 |
+
if not Path(path).exists():
|
60 |
+
path = hf_hub_download(
|
61 |
+
repo_id=repo_id,
|
62 |
+
filename=filename,
|
63 |
+
subfolder=f"loras/{name}",
|
64 |
+
local_dir=MODELS_DIR,
|
65 |
+
local_dir_use_symlinks=False,
|
66 |
+
local_files_only=False
|
67 |
+
)
|
68 |
+
paths.append(path)
|
69 |
+
|
70 |
+
# load the models
|
71 |
+
return paths[0], paths[1]
|
72 |
+
|
73 |
+
def list_finetuned():
|
74 |
+
diritems = FS.listdir(f"{DEFAULT_HF_MODEL_REPO}/loras")
|
75 |
+
# iterate through all the names
|
76 |
+
valid_diritems = []
|
77 |
+
for item in diritems:
|
78 |
+
model_file_items = FS.listdir(item["name"])
|
79 |
+
item_names = [item["name"].split("/")[-1] for item in model_file_items]
|
80 |
+
# check that theres a "c2f.pth" and "coarse.pth" in the items
|
81 |
+
c2f_exists = "c2f.pth" in item_names
|
82 |
+
coarse_exists = "coarse.pth" in item_names
|
83 |
+
if c2f_exists and coarse_exists:
|
84 |
+
valid_diritems.append(item)
|
85 |
+
|
86 |
+
# get the names of the valid items
|
87 |
+
names = [item["name"].split("/")[-1] for item in valid_diritems]
|
88 |
+
return names
|
89 |
+
|
90 |
+
|
vampnet/beats.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import warnings
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any
|
7 |
+
from typing import List
|
8 |
+
from typing import Tuple
|
9 |
+
from typing import Union
|
10 |
+
|
11 |
+
import librosa
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
from audiotools import AudioSignal
|
15 |
+
|
16 |
+
|
17 |
+
logging.basicConfig(level=logging.INFO)
|
18 |
+
|
19 |
+
###################
|
20 |
+
# beat sync utils #
|
21 |
+
###################
|
22 |
+
|
23 |
+
AGGREGATOR_REGISTRY = {
|
24 |
+
"mean": np.mean,
|
25 |
+
"median": np.median,
|
26 |
+
"max": np.max,
|
27 |
+
"min": np.min,
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
def list_aggregators() -> list:
|
32 |
+
return list(AGGREGATOR_REGISTRY.keys())
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class TimeSegment:
|
37 |
+
start: float
|
38 |
+
end: float
|
39 |
+
|
40 |
+
@property
|
41 |
+
def duration(self):
|
42 |
+
return self.end - self.start
|
43 |
+
|
44 |
+
def __str__(self) -> str:
|
45 |
+
return f"{self.start} - {self.end}"
|
46 |
+
|
47 |
+
def find_overlapping_segment(
|
48 |
+
self, segments: List["TimeSegment"]
|
49 |
+
) -> Union["TimeSegment", None]:
|
50 |
+
"""Find the first segment that overlaps with this segment, or None if no segment overlaps"""
|
51 |
+
for s in segments:
|
52 |
+
if s.start <= self.start and s.end >= self.end:
|
53 |
+
return s
|
54 |
+
return None
|
55 |
+
|
56 |
+
|
57 |
+
def mkdir(path: Union[Path, str]) -> Path:
|
58 |
+
p = Path(path)
|
59 |
+
p.mkdir(parents=True, exist_ok=True)
|
60 |
+
return p
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
###################
|
65 |
+
# beat data #
|
66 |
+
###################
|
67 |
+
@dataclass
|
68 |
+
class BeatSegment(TimeSegment):
|
69 |
+
downbeat: bool = False # if there's a downbeat on the start_time
|
70 |
+
|
71 |
+
|
72 |
+
class Beats:
|
73 |
+
def __init__(self, beat_times, downbeat_times):
|
74 |
+
if isinstance(beat_times, np.ndarray):
|
75 |
+
beat_times = beat_times.tolist()
|
76 |
+
if isinstance(downbeat_times, np.ndarray):
|
77 |
+
downbeat_times = downbeat_times.tolist()
|
78 |
+
self._beat_times = beat_times
|
79 |
+
self._downbeat_times = downbeat_times
|
80 |
+
self._use_downbeats = False
|
81 |
+
|
82 |
+
def use_downbeats(self, use_downbeats: bool = True):
|
83 |
+
"""use downbeats instead of beats when calling beat_times"""
|
84 |
+
self._use_downbeats = use_downbeats
|
85 |
+
|
86 |
+
def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]:
|
87 |
+
"""
|
88 |
+
segments a song into time segments corresponding to beats.
|
89 |
+
the first segment starts at 0 and ends at the first beat time.
|
90 |
+
the last segment starts at the last beat time and ends at the end of the song.
|
91 |
+
"""
|
92 |
+
beat_times = self._beat_times.copy()
|
93 |
+
downbeat_times = self._downbeat_times
|
94 |
+
beat_times.insert(0, 0)
|
95 |
+
beat_times.append(signal.signal_duration)
|
96 |
+
|
97 |
+
downbeat_ids = np.intersect1d(beat_times, downbeat_times, return_indices=True)[
|
98 |
+
1
|
99 |
+
]
|
100 |
+
is_downbeat = [
|
101 |
+
True if i in downbeat_ids else False for i in range(len(beat_times))
|
102 |
+
]
|
103 |
+
segments = [
|
104 |
+
BeatSegment(start_time, end_time, downbeat)
|
105 |
+
for start_time, end_time, downbeat in zip(
|
106 |
+
beat_times[:-1], beat_times[1:], is_downbeat
|
107 |
+
)
|
108 |
+
]
|
109 |
+
return segments
|
110 |
+
|
111 |
+
def get_beats(self) -> np.ndarray:
|
112 |
+
"""returns an array of beat times, in seconds
|
113 |
+
if downbeats is True, returns an array of downbeat times, in seconds
|
114 |
+
"""
|
115 |
+
return np.array(
|
116 |
+
self._downbeat_times if self._use_downbeats else self._beat_times
|
117 |
+
)
|
118 |
+
|
119 |
+
@property
|
120 |
+
def beat_times(self) -> np.ndarray:
|
121 |
+
"""return beat times"""
|
122 |
+
return np.array(self._beat_times)
|
123 |
+
|
124 |
+
@property
|
125 |
+
def downbeat_times(self) -> np.ndarray:
|
126 |
+
"""return downbeat times"""
|
127 |
+
return np.array(self._downbeat_times)
|
128 |
+
|
129 |
+
def beat_times_to_feature_frames(
|
130 |
+
self, signal: AudioSignal, features: np.ndarray
|
131 |
+
) -> np.ndarray:
|
132 |
+
"""convert beat times to frames, given an array of time-varying features"""
|
133 |
+
beat_times = self.get_beats()
|
134 |
+
beat_frames = (
|
135 |
+
beat_times * signal.sample_rate / signal.signal_length * features.shape[-1]
|
136 |
+
).astype(np.int64)
|
137 |
+
return beat_frames
|
138 |
+
|
139 |
+
def sync_features(
|
140 |
+
self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median"
|
141 |
+
) -> np.ndarray:
|
142 |
+
"""sync features to beats"""
|
143 |
+
if aggregate not in AGGREGATOR_REGISTRY:
|
144 |
+
raise ValueError(f"unknown aggregation method {aggregate}")
|
145 |
+
|
146 |
+
return librosa.util.sync(
|
147 |
+
features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate]
|
148 |
+
)
|
149 |
+
|
150 |
+
def to_json(self) -> dict:
|
151 |
+
"""return beats and downbeats as json"""
|
152 |
+
return {
|
153 |
+
"beats": self._beat_times,
|
154 |
+
"downbeats": self._downbeat_times,
|
155 |
+
"use_downbeats": self._use_downbeats,
|
156 |
+
}
|
157 |
+
|
158 |
+
@classmethod
|
159 |
+
def from_dict(cls, data: dict):
|
160 |
+
"""load beats and downbeats from json"""
|
161 |
+
inst = cls(data["beats"], data["downbeats"])
|
162 |
+
inst.use_downbeats(data["use_downbeats"])
|
163 |
+
return inst
|
164 |
+
|
165 |
+
def save(self, output_dir: Path):
|
166 |
+
"""save beats and downbeats to json"""
|
167 |
+
mkdir(output_dir)
|
168 |
+
with open(output_dir / "beats.json", "w") as f:
|
169 |
+
json.dump(self.to_json(), f)
|
170 |
+
|
171 |
+
@classmethod
|
172 |
+
def load(cls, input_dir: Path):
|
173 |
+
"""load beats and downbeats from json"""
|
174 |
+
beats_file = Path(input_dir) / "beats.json"
|
175 |
+
with open(beats_file, "r") as f:
|
176 |
+
data = json.load(f)
|
177 |
+
return cls.from_dict(data)
|
178 |
+
|
179 |
+
|
180 |
+
###################
|
181 |
+
# beat tracking #
|
182 |
+
###################
|
183 |
+
|
184 |
+
|
185 |
+
class BeatTracker:
|
186 |
+
def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
|
187 |
+
"""extract beats from an audio signal"""
|
188 |
+
raise NotImplementedError
|
189 |
+
|
190 |
+
def __call__(self, signal: AudioSignal) -> Beats:
|
191 |
+
"""extract beats from an audio signal
|
192 |
+
NOTE: if the first beat (and/or downbeat) is detected within the first 100ms of the audio,
|
193 |
+
it is discarded. This is to avoid empty bins with no beat synced features in the first beat.
|
194 |
+
Args:
|
195 |
+
signal (AudioSignal): signal to beat track
|
196 |
+
Returns:
|
197 |
+
Tuple[np.ndarray, np.ndarray]: beats and downbeats
|
198 |
+
"""
|
199 |
+
beats, downbeats = self.extract_beats(signal)
|
200 |
+
return Beats(beats, downbeats)
|
201 |
+
|
202 |
+
|
203 |
+
class WaveBeat(BeatTracker):
|
204 |
+
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
205 |
+
from wavebeat.dstcn import dsTCNModel
|
206 |
+
|
207 |
+
model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
|
208 |
+
model.eval()
|
209 |
+
|
210 |
+
self.device = device
|
211 |
+
self.model = model
|
212 |
+
|
213 |
+
def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
|
214 |
+
"""returns beat and downbeat times, in seconds"""
|
215 |
+
# extract beats
|
216 |
+
beats, downbeats = self.model.predict_beats_from_array(
|
217 |
+
audio=signal.audio_data.squeeze(0),
|
218 |
+
sr=signal.sample_rate,
|
219 |
+
use_gpu=self.device != "cpu",
|
220 |
+
)
|
221 |
+
|
222 |
+
return beats, downbeats
|
223 |
+
|
224 |
+
|
225 |
+
class MadmomBeats(BeatTracker):
|
226 |
+
def __init__(self):
|
227 |
+
raise NotImplementedError
|
228 |
+
|
229 |
+
def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
|
230 |
+
"""returns beat and downbeat times, in seconds"""
|
231 |
+
pass
|
232 |
+
|
233 |
+
|
234 |
+
BEAT_TRACKER_REGISTRY = {
|
235 |
+
"wavebeat": WaveBeat,
|
236 |
+
"madmom": MadmomBeats,
|
237 |
+
}
|
238 |
+
|
239 |
+
|
240 |
+
def list_beat_trackers() -> list:
|
241 |
+
return list(BEAT_TRACKER_REGISTRY.keys())
|
242 |
+
|
243 |
+
|
244 |
+
def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
|
245 |
+
if beat_tracker not in BEAT_TRACKER_REGISTRY:
|
246 |
+
raise ValueError(
|
247 |
+
f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
|
248 |
+
)
|
249 |
+
|
250 |
+
return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
|
vampnet/interface.py
ADDED
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
import math
|
4 |
+
import logging
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from audiotools import AudioSignal
|
9 |
+
import tqdm
|
10 |
+
|
11 |
+
from .modules.transformer import VampNet
|
12 |
+
from .beats import WaveBeat
|
13 |
+
from .mask import *
|
14 |
+
|
15 |
+
# from dac.model.dac import DAC
|
16 |
+
from lac.model.lac import LAC as DAC
|
17 |
+
|
18 |
+
|
19 |
+
def signal_concat(
|
20 |
+
audio_signals: list,
|
21 |
+
):
|
22 |
+
audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1)
|
23 |
+
|
24 |
+
return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
|
25 |
+
|
26 |
+
|
27 |
+
def _load_model(
|
28 |
+
ckpt: str,
|
29 |
+
lora_ckpt: str = None,
|
30 |
+
device: str = "cpu",
|
31 |
+
chunk_size_s: int = 10,
|
32 |
+
):
|
33 |
+
# we need to set strict to False if the model has lora weights to add later
|
34 |
+
model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)
|
35 |
+
|
36 |
+
# load lora weights if needed
|
37 |
+
if lora_ckpt is not None:
|
38 |
+
if not Path(lora_ckpt).exists():
|
39 |
+
should_cont = input(
|
40 |
+
f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
|
41 |
+
)
|
42 |
+
if should_cont != "y":
|
43 |
+
raise Exception("aborting")
|
44 |
+
else:
|
45 |
+
model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)
|
46 |
+
|
47 |
+
model.to(device)
|
48 |
+
model.eval()
|
49 |
+
model.chunk_size_s = chunk_size_s
|
50 |
+
return model
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
class Interface(torch.nn.Module):
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
coarse_ckpt: str = None,
|
58 |
+
coarse_lora_ckpt: str = None,
|
59 |
+
coarse2fine_ckpt: str = None,
|
60 |
+
coarse2fine_lora_ckpt: str = None,
|
61 |
+
codec_ckpt: str = None,
|
62 |
+
wavebeat_ckpt: str = None,
|
63 |
+
device: str = "cpu",
|
64 |
+
coarse_chunk_size_s: int = 10,
|
65 |
+
coarse2fine_chunk_size_s: int = 3,
|
66 |
+
compile=True,
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
70 |
+
self.codec = DAC.load(Path(codec_ckpt))
|
71 |
+
self.codec.eval()
|
72 |
+
self.codec.to(device)
|
73 |
+
self.codec_path = Path(codec_ckpt)
|
74 |
+
|
75 |
+
assert coarse_ckpt is not None, "must provide a coarse checkpoint"
|
76 |
+
self.coarse = _load_model(
|
77 |
+
ckpt=coarse_ckpt,
|
78 |
+
lora_ckpt=coarse_lora_ckpt,
|
79 |
+
device=device,
|
80 |
+
chunk_size_s=coarse_chunk_size_s,
|
81 |
+
)
|
82 |
+
self.coarse_path = Path(coarse_ckpt)
|
83 |
+
|
84 |
+
# check if we have a coarse2fine ckpt
|
85 |
+
if coarse2fine_ckpt is not None:
|
86 |
+
self.c2f_path = Path(coarse2fine_ckpt)
|
87 |
+
self.c2f = _load_model(
|
88 |
+
ckpt=coarse2fine_ckpt,
|
89 |
+
lora_ckpt=coarse2fine_lora_ckpt,
|
90 |
+
device=device,
|
91 |
+
chunk_size_s=coarse2fine_chunk_size_s,
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
self.c2f_path = None
|
95 |
+
self.c2f = None
|
96 |
+
|
97 |
+
if wavebeat_ckpt is not None:
|
98 |
+
logging.debug(f"loading wavebeat from {wavebeat_ckpt}")
|
99 |
+
self.beat_tracker = WaveBeat(wavebeat_ckpt)
|
100 |
+
self.beat_tracker.model.to(device)
|
101 |
+
else:
|
102 |
+
self.beat_tracker = None
|
103 |
+
|
104 |
+
self.device = device
|
105 |
+
self.loudness = -24.0
|
106 |
+
|
107 |
+
if compile:
|
108 |
+
logging.debug(f"compiling models")
|
109 |
+
self.coarse = torch.compile(self.coarse)
|
110 |
+
if self.c2f is not None:
|
111 |
+
self.c2f = torch.compile(self.c2f)
|
112 |
+
self.codec = torch.compile(self.codec)
|
113 |
+
|
114 |
+
|
115 |
+
@classmethod
|
116 |
+
def default(cls):
|
117 |
+
from . import download_codec, download_default
|
118 |
+
print(f"loading default vampnet")
|
119 |
+
codec_path = download_codec()
|
120 |
+
coarse_path, c2f_path = download_default()
|
121 |
+
|
122 |
+
return Interface(
|
123 |
+
coarse_ckpt=coarse_path,
|
124 |
+
coarse2fine_ckpt=c2f_path,
|
125 |
+
codec_ckpt=codec_path,
|
126 |
+
)
|
127 |
+
|
128 |
+
@classmethod
|
129 |
+
def available_models(cls):
|
130 |
+
from . import list_finetuned
|
131 |
+
return list_finetuned()
|
132 |
+
|
133 |
+
|
134 |
+
def load_finetuned(self, name: str):
|
135 |
+
assert name in self.available_models(), f"{name} is not a valid model name"
|
136 |
+
from . import download_finetuned
|
137 |
+
coarse_path, c2f_path = download_finetuned(name)
|
138 |
+
self.reload(
|
139 |
+
coarse_ckpt=coarse_path,
|
140 |
+
c2f_ckpt=c2f_path,
|
141 |
+
)
|
142 |
+
|
143 |
+
def reload(
|
144 |
+
self,
|
145 |
+
coarse_ckpt: str = None,
|
146 |
+
c2f_ckpt: str = None,
|
147 |
+
):
|
148 |
+
if coarse_ckpt is not None:
|
149 |
+
# check if we already loaded, if so, don't reload
|
150 |
+
if self.coarse_path == Path(coarse_ckpt):
|
151 |
+
logging.debug(f"already loaded {coarse_ckpt}")
|
152 |
+
else:
|
153 |
+
self.coarse = _load_model(
|
154 |
+
ckpt=coarse_ckpt,
|
155 |
+
device=self.device,
|
156 |
+
chunk_size_s=self.coarse.chunk_size_s,
|
157 |
+
)
|
158 |
+
self.coarse_path = Path(coarse_ckpt)
|
159 |
+
logging.debug(f"loaded {coarse_ckpt}")
|
160 |
+
|
161 |
+
if c2f_ckpt is not None:
|
162 |
+
if self.c2f_path == Path(c2f_ckpt):
|
163 |
+
logging.debug(f"already loaded {c2f_ckpt}")
|
164 |
+
else:
|
165 |
+
self.c2f = _load_model(
|
166 |
+
ckpt=c2f_ckpt,
|
167 |
+
device=self.device,
|
168 |
+
chunk_size_s=self.c2f.chunk_size_s,
|
169 |
+
)
|
170 |
+
self.c2f_path = Path(c2f_ckpt)
|
171 |
+
logging.debug(f"loaded {c2f_ckpt}")
|
172 |
+
|
173 |
+
def s2t(self, seconds: float):
|
174 |
+
"""seconds to tokens"""
|
175 |
+
if isinstance(seconds, np.ndarray):
|
176 |
+
return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
|
177 |
+
else:
|
178 |
+
return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
|
179 |
+
|
180 |
+
def s2t2s(self, seconds: float):
|
181 |
+
"""seconds to tokens to seconds"""
|
182 |
+
return self.t2s(self.s2t(seconds))
|
183 |
+
|
184 |
+
def t2s(self, tokens: int):
|
185 |
+
"""tokens to seconds"""
|
186 |
+
return tokens * self.codec.hop_length / self.codec.sample_rate
|
187 |
+
|
188 |
+
def to(self, device):
|
189 |
+
self.device = device
|
190 |
+
self.coarse.to(device)
|
191 |
+
self.codec.to(device)
|
192 |
+
|
193 |
+
if self.c2f is not None:
|
194 |
+
self.c2f.to(device)
|
195 |
+
|
196 |
+
if self.beat_tracker is not None:
|
197 |
+
self.beat_tracker.model.to(device)
|
198 |
+
return self
|
199 |
+
|
200 |
+
def decode(self, z: torch.Tensor):
|
201 |
+
return self.coarse.decode(z, self.codec)
|
202 |
+
|
203 |
+
def _preprocess(self, signal: AudioSignal):
|
204 |
+
signal = (
|
205 |
+
signal.clone()
|
206 |
+
.resample(self.codec.sample_rate)
|
207 |
+
.to_mono()
|
208 |
+
.normalize(self.loudness)
|
209 |
+
.ensure_max_of_audio(1.0)
|
210 |
+
)
|
211 |
+
logging.debug(f"length before codec preproc: {signal.samples.shape}")
|
212 |
+
signal.samples, length = self.codec.preprocess(signal.samples, signal.sample_rate)
|
213 |
+
logging.debug(f"length after codec preproc: {signal.samples.shape}")
|
214 |
+
return signal
|
215 |
+
|
216 |
+
@torch.inference_mode()
|
217 |
+
def encode(self, signal: AudioSignal):
|
218 |
+
signal = signal.to(self.device)
|
219 |
+
signal = self._preprocess(signal)
|
220 |
+
z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
221 |
+
return z
|
222 |
+
|
223 |
+
def snap_to_beats(
|
224 |
+
self,
|
225 |
+
signal: AudioSignal
|
226 |
+
):
|
227 |
+
assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
|
228 |
+
beats, downbeats = self.beat_tracker.extract_beats(signal)
|
229 |
+
|
230 |
+
# trim the signa around the first beat time
|
231 |
+
samples_begin = int(beats[0] * signal.sample_rate )
|
232 |
+
samples_end = int(beats[-1] * signal.sample_rate)
|
233 |
+
logging.debug(beats[0])
|
234 |
+
signal = signal.clone().trim(samples_begin, signal.length - samples_end)
|
235 |
+
|
236 |
+
return signal
|
237 |
+
|
238 |
+
def make_beat_mask(self,
|
239 |
+
signal: AudioSignal,
|
240 |
+
before_beat_s: float = 0.0,
|
241 |
+
after_beat_s: float = 0.02,
|
242 |
+
mask_downbeats: bool = True,
|
243 |
+
mask_upbeats: bool = True,
|
244 |
+
downbeat_downsample_factor: int = None,
|
245 |
+
beat_downsample_factor: int = None,
|
246 |
+
dropout: float = 0.0,
|
247 |
+
invert: bool = True,
|
248 |
+
):
|
249 |
+
"""make a beat synced mask. that is, make a mask that
|
250 |
+
places 1s at and around the beat, and 0s everywhere else.
|
251 |
+
"""
|
252 |
+
assert self.beat_tracker is not None, "No beat tracker loaded"
|
253 |
+
|
254 |
+
# get the beat times
|
255 |
+
beats, downbeats = self.beat_tracker.extract_beats(signal)
|
256 |
+
|
257 |
+
# get the beat indices in z
|
258 |
+
beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
|
259 |
+
|
260 |
+
# remove downbeats from beats
|
261 |
+
beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
|
262 |
+
beats_z = beats_z.tolist()
|
263 |
+
downbeats_z = downbeats_z.tolist()
|
264 |
+
|
265 |
+
# make the mask
|
266 |
+
seq_len = self.s2t(signal.duration)
|
267 |
+
mask = torch.zeros(seq_len, device=self.device)
|
268 |
+
|
269 |
+
mask_b4 = self.s2t(before_beat_s)
|
270 |
+
mask_after = self.s2t(after_beat_s)
|
271 |
+
|
272 |
+
if beat_downsample_factor is not None:
|
273 |
+
if beat_downsample_factor < 1:
|
274 |
+
raise ValueError("mask_beat_downsample_factor must be >= 1 or None")
|
275 |
+
else:
|
276 |
+
beat_downsample_factor = 1
|
277 |
+
|
278 |
+
if downbeat_downsample_factor is not None:
|
279 |
+
if downbeat_downsample_factor < 1:
|
280 |
+
raise ValueError("mask_beat_downsample_factor must be >= 1 or None")
|
281 |
+
else:
|
282 |
+
downbeat_downsample_factor = 1
|
283 |
+
|
284 |
+
beats_z = beats_z[::beat_downsample_factor]
|
285 |
+
downbeats_z = downbeats_z[::downbeat_downsample_factor]
|
286 |
+
logging.debug(f"beats_z: {len(beats_z)}")
|
287 |
+
logging.debug(f"downbeats_z: {len(downbeats_z)}")
|
288 |
+
|
289 |
+
if mask_upbeats:
|
290 |
+
for beat_idx in beats_z:
|
291 |
+
_slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
|
292 |
+
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
293 |
+
_m = torch.ones(num_steps, device=self.device)
|
294 |
+
_m_mask = torch.bernoulli(_m * (1 - dropout))
|
295 |
+
_m = _m * _m_mask.long()
|
296 |
+
|
297 |
+
mask[_slice[0]:_slice[1]] = _m
|
298 |
+
|
299 |
+
if mask_downbeats:
|
300 |
+
for downbeat_idx in downbeats_z:
|
301 |
+
_slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
|
302 |
+
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
303 |
+
_m = torch.ones(num_steps, device=self.device)
|
304 |
+
_m_mask = torch.bernoulli(_m * (1 - dropout))
|
305 |
+
_m = _m * _m_mask.long()
|
306 |
+
|
307 |
+
mask[_slice[0]:_slice[1]] = _m
|
308 |
+
|
309 |
+
mask = mask.clamp(0, 1)
|
310 |
+
if invert:
|
311 |
+
mask = 1 - mask
|
312 |
+
|
313 |
+
mask = mask[None, None, :].bool().long()
|
314 |
+
if self.c2f is not None:
|
315 |
+
mask = mask.repeat(1, self.c2f.n_codebooks, 1)
|
316 |
+
else:
|
317 |
+
mask = mask.repeat(1, self.coarse.n_codebooks, 1)
|
318 |
+
return mask
|
319 |
+
|
320 |
+
def set_chunk_size(self, chunk_size_s: float):
|
321 |
+
self.coarse.chunk_size_s = chunk_size_s
|
322 |
+
|
323 |
+
@torch.inference_mode()
|
324 |
+
def coarse_to_fine(
|
325 |
+
self,
|
326 |
+
z: torch.Tensor,
|
327 |
+
mask: torch.Tensor = None,
|
328 |
+
return_mask: bool = False,
|
329 |
+
**kwargs
|
330 |
+
):
|
331 |
+
assert self.c2f is not None, "No coarse2fine model loaded"
|
332 |
+
length = z.shape[-1]
|
333 |
+
chunk_len = self.s2t(self.c2f.chunk_size_s)
|
334 |
+
n_chunks = math.ceil(z.shape[-1] / chunk_len)
|
335 |
+
|
336 |
+
# zero pad to chunk_len
|
337 |
+
if length % chunk_len != 0:
|
338 |
+
pad_len = chunk_len - (length % chunk_len)
|
339 |
+
z = torch.nn.functional.pad(z, (0, pad_len))
|
340 |
+
mask = torch.nn.functional.pad(mask, (0, pad_len), value=1) if mask is not None else None
|
341 |
+
|
342 |
+
n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
|
343 |
+
if n_codebooks_to_append > 0:
|
344 |
+
z = torch.cat([
|
345 |
+
z,
|
346 |
+
torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
|
347 |
+
], dim=1)
|
348 |
+
logging.debug(f"appended {n_codebooks_to_append} codebooks to z")
|
349 |
+
|
350 |
+
# set the mask to 0 for all conditioning codebooks
|
351 |
+
if mask is not None:
|
352 |
+
mask = mask.clone()
|
353 |
+
mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
|
354 |
+
|
355 |
+
fine_z = []
|
356 |
+
for i in range(n_chunks):
|
357 |
+
chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
|
358 |
+
mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
|
359 |
+
|
360 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
361 |
+
chunk = self.c2f.generate(
|
362 |
+
codec=self.codec,
|
363 |
+
time_steps=chunk_len,
|
364 |
+
start_tokens=chunk,
|
365 |
+
return_signal=False,
|
366 |
+
mask=mask_chunk,
|
367 |
+
cfg_guidance=None,
|
368 |
+
**kwargs
|
369 |
+
)
|
370 |
+
fine_z.append(chunk)
|
371 |
+
|
372 |
+
fine_z = torch.cat(fine_z, dim=-1)
|
373 |
+
if return_mask:
|
374 |
+
return fine_z[:, :, :length].clone(), apply_mask(fine_z, mask, self.c2f.mask_token)[0][:, :, :length].clone()
|
375 |
+
|
376 |
+
return fine_z[:, :, :length].clone()
|
377 |
+
|
378 |
+
@torch.inference_mode()
|
379 |
+
def coarse_vamp(
|
380 |
+
self,
|
381 |
+
z,
|
382 |
+
mask,
|
383 |
+
return_mask=False,
|
384 |
+
gen_fn=None,
|
385 |
+
**kwargs
|
386 |
+
):
|
387 |
+
# coarse z
|
388 |
+
cz = z[:, : self.coarse.n_codebooks, :].clone()
|
389 |
+
mask = mask[:, : self.coarse.n_codebooks, :]
|
390 |
+
# assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
|
391 |
+
|
392 |
+
# cut into chunks, keep the last chunk separate if it's too small
|
393 |
+
chunk_len = self.s2t(self.coarse.chunk_size_s)
|
394 |
+
n_chunks = math.ceil(cz.shape[-1] / chunk_len)
|
395 |
+
last_chunk_len = cz.shape[-1] % chunk_len
|
396 |
+
|
397 |
+
cz_chunks = []
|
398 |
+
mask_chunks = []
|
399 |
+
for i in range(n_chunks):
|
400 |
+
chunk = cz[:, :, i * chunk_len : (i + 1) * chunk_len]
|
401 |
+
mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len]
|
402 |
+
|
403 |
+
# make sure that the very first and last timestep of each chunk is 0 so that we don't get a weird
|
404 |
+
# discontinuity when we stitch the chunks back together
|
405 |
+
# only if there's already a 0 somewhere in the chunk
|
406 |
+
if torch.any(mask_chunk == 0):
|
407 |
+
mask_chunk[:, :, 0] = 0
|
408 |
+
mask_chunk[:, :, -1] = 0
|
409 |
+
|
410 |
+
cz_chunks.append(chunk)
|
411 |
+
mask_chunks.append(mask_chunk)
|
412 |
+
|
413 |
+
# now vamp each chunk
|
414 |
+
cz_masked_chunks = []
|
415 |
+
cz_vamped_chunks = []
|
416 |
+
for chunk, mask_chunk in zip(cz_chunks, mask_chunks):
|
417 |
+
cz_masked_chunk, mask_chunk = apply_mask(chunk, mask_chunk, self.coarse.mask_token)
|
418 |
+
cz_masked_chunk = cz_masked_chunk[:, : self.coarse.n_codebooks, :]
|
419 |
+
cz_masked_chunks.append(cz_masked_chunk)
|
420 |
+
|
421 |
+
|
422 |
+
gen_fn = gen_fn or self.coarse.generate
|
423 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
424 |
+
c_vamp_chunk = gen_fn(
|
425 |
+
codec=self.codec,
|
426 |
+
time_steps=chunk_len,
|
427 |
+
start_tokens=cz_masked_chunk,
|
428 |
+
return_signal=False,
|
429 |
+
mask=mask_chunk,
|
430 |
+
**kwargs
|
431 |
+
)
|
432 |
+
cz_vamped_chunks.append(c_vamp_chunk)
|
433 |
+
|
434 |
+
# stitch the chunks back together
|
435 |
+
cz_masked = torch.cat(cz_masked_chunks, dim=-1)
|
436 |
+
c_vamp = torch.cat(cz_vamped_chunks, dim=-1)
|
437 |
+
|
438 |
+
# add the fine codes back in
|
439 |
+
c_vamp = torch.cat(
|
440 |
+
[c_vamp, z[:, self.coarse.n_codebooks :, :]],
|
441 |
+
dim=1
|
442 |
+
)
|
443 |
+
|
444 |
+
if return_mask:
|
445 |
+
return c_vamp, cz_masked
|
446 |
+
|
447 |
+
return c_vamp
|
448 |
+
|
449 |
+
def build_mask(self,
|
450 |
+
z: torch.Tensor,
|
451 |
+
sig: AudioSignal = None,
|
452 |
+
rand_mask_intensity: float = 1.0,
|
453 |
+
prefix_s: float = 0.0,
|
454 |
+
suffix_s: float = 0.0,
|
455 |
+
periodic_prompt: int = 7,
|
456 |
+
periodic_prompt_width: int = 1,
|
457 |
+
onset_mask_width: int = 0,
|
458 |
+
_dropout: float = 0.0,
|
459 |
+
upper_codebook_mask: int = 3,
|
460 |
+
ncc: int = 0,
|
461 |
+
):
|
462 |
+
mask = linear_random(z, rand_mask_intensity)
|
463 |
+
mask = mask_and(
|
464 |
+
mask,
|
465 |
+
inpaint(z, self.s2t(prefix_s), self.s2t(suffix_s)),
|
466 |
+
)
|
467 |
+
|
468 |
+
pmask = periodic_mask(z, periodic_prompt, periodic_prompt_width, random_roll=True)
|
469 |
+
mask = mask_and(mask, pmask)
|
470 |
+
|
471 |
+
if onset_mask_width > 0:
|
472 |
+
assert sig is not None, f"must provide a signal to use onset mask"
|
473 |
+
mask = mask_and(
|
474 |
+
mask, onset_mask(
|
475 |
+
sig, z, self,
|
476 |
+
width=onset_mask_width
|
477 |
+
)
|
478 |
+
)
|
479 |
+
|
480 |
+
mask = dropout(mask, _dropout)
|
481 |
+
mask = codebook_unmask(mask, ncc)
|
482 |
+
|
483 |
+
mask = codebook_mask(mask, int(upper_codebook_mask), None)
|
484 |
+
return mask
|
485 |
+
|
486 |
+
def vamp(
|
487 |
+
self,
|
488 |
+
codes: torch.Tensor,
|
489 |
+
mask: torch.Tensor,
|
490 |
+
batch_size: int = 1,
|
491 |
+
feedback_steps: int = 1,
|
492 |
+
time_stretch_factor: int = 1,
|
493 |
+
return_mask: bool = False,
|
494 |
+
**kwargs,
|
495 |
+
):
|
496 |
+
z = codes
|
497 |
+
|
498 |
+
# expand z to batch size
|
499 |
+
z = z.expand(batch_size, -1, -1)
|
500 |
+
mask = mask.expand(batch_size, -1, -1)
|
501 |
+
|
502 |
+
# stretch mask and z to match the time stretch factor
|
503 |
+
# we'll add (stretch_factor - 1) mask tokens in between each timestep of z
|
504 |
+
# and we'll make the mask 1 in all the new slots we added
|
505 |
+
if time_stretch_factor > 1:
|
506 |
+
z = z.repeat_interleave(time_stretch_factor, dim=-1)
|
507 |
+
mask = mask.repeat_interleave(time_stretch_factor, dim=-1)
|
508 |
+
added_mask = torch.ones_like(mask)
|
509 |
+
added_mask[:, :, ::time_stretch_factor] = 0
|
510 |
+
mask = mask.bool() | added_mask.bool()
|
511 |
+
mask = mask.long()
|
512 |
+
|
513 |
+
# the forward pass
|
514 |
+
logging.debug(z.shape)
|
515 |
+
logging.debug("coarse!")
|
516 |
+
zv, mask_z = self.coarse_vamp(
|
517 |
+
z,
|
518 |
+
mask=mask,
|
519 |
+
return_mask=True,
|
520 |
+
**kwargs
|
521 |
+
)
|
522 |
+
|
523 |
+
# add the top codebooks back in
|
524 |
+
if zv.shape[1] < z.shape[1]:
|
525 |
+
logging.debug(f"adding {z.shape[1] - zv.shape[1]} codebooks back in")
|
526 |
+
zv = torch.cat(
|
527 |
+
[zv, z[:, self.coarse.n_codebooks :, :]],
|
528 |
+
dim=1
|
529 |
+
)
|
530 |
+
|
531 |
+
# now, coarse2fine
|
532 |
+
logging.debug(f"coarse2fine!")
|
533 |
+
zv, fine_zv_mask = self.coarse_to_fine(
|
534 |
+
zv,
|
535 |
+
mask=mask,
|
536 |
+
typical_filtering=True,
|
537 |
+
_sampling_steps=[2],
|
538 |
+
return_mask=True
|
539 |
+
)
|
540 |
+
mask_z = torch.cat(
|
541 |
+
[mask_z[:, :self.coarse.n_codebooks, :], fine_zv_mask[:, self.coarse.n_codebooks:, :]],
|
542 |
+
dim=1
|
543 |
+
)
|
544 |
+
|
545 |
+
z = zv
|
546 |
+
|
547 |
+
if return_mask:
|
548 |
+
return z, mask_z.cpu(),
|
549 |
+
else:
|
550 |
+
return z
|
551 |
+
|
552 |
+
def visualize_codes(self, z: torch.Tensor):
|
553 |
+
import matplotlib.pyplot as plt
|
554 |
+
# make sure the figsize is square when imshow is called
|
555 |
+
fig = plt.figure(figsize=(10, 7))
|
556 |
+
# in subplots, plot z[0] and the mask
|
557 |
+
# set title to "codes" and "mask"
|
558 |
+
fig.add_subplot(2, 1, 1)
|
559 |
+
plt.imshow(z[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
|
560 |
+
plt.title("codes")
|
561 |
+
plt.ylabel("codebook index")
|
562 |
+
# set the xticks to seconds
|
563 |
+
|
564 |
+
|
565 |
+
if __name__ == "__main__":
|
566 |
+
import audiotools as at
|
567 |
+
import logging
|
568 |
+
logger = logging.getLogger()
|
569 |
+
logger.setLevel(logging.INFO)
|
570 |
+
torch.set_logging.debugoptions(threshold=10000)
|
571 |
+
at.util.seed(42)
|
572 |
+
|
573 |
+
interface = Interface(
|
574 |
+
coarse_ckpt="./models/vampnet/coarse.pth",
|
575 |
+
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
576 |
+
codec_ckpt="./models/vampnet/codec.pth",
|
577 |
+
device="cuda",
|
578 |
+
wavebeat_ckpt="./models/wavebeat.pth"
|
579 |
+
)
|
580 |
+
|
581 |
+
|
582 |
+
sig = at.AudioSignal('assets/example.wav')
|
583 |
+
|
584 |
+
z = interface.encode(sig)
|
585 |
+
|
586 |
+
|
587 |
+
mask = interface.build_mask(
|
588 |
+
z=z,
|
589 |
+
sig=sig,
|
590 |
+
rand_mask_intensity=1.0,
|
591 |
+
prefix_s=0.0,
|
592 |
+
suffix_s=0.0,
|
593 |
+
periodic_prompt=7,
|
594 |
+
periodic_prompt2=7,
|
595 |
+
periodic_prompt_width=1,
|
596 |
+
onset_mask_width=5,
|
597 |
+
_dropout=0.0,
|
598 |
+
upper_codebook_mask=3,
|
599 |
+
upper_codebook_mask_2=None,
|
600 |
+
ncc=0,
|
601 |
+
)
|
602 |
+
|
603 |
+
zv, mask_z = interface.coarse_vamp(
|
604 |
+
z,
|
605 |
+
mask=mask,
|
606 |
+
return_mask=True,
|
607 |
+
gen_fn=interface.coarse.generate
|
608 |
+
)
|
609 |
+
|
610 |
+
|
611 |
+
use_coarse2fine = True
|
612 |
+
if use_coarse2fine:
|
613 |
+
zv = interface.coarse_to_fine(zv, mask=mask)
|
614 |
+
breakpoint()
|
615 |
+
|
616 |
+
mask = interface.decode(mask_z).cpu()
|
617 |
+
|
618 |
+
sig = interface.decode(zv).cpu()
|
619 |
+
|
620 |
+
|
621 |
+
logging.debug("done")
|
622 |
+
|
623 |
+
|
vampnet/mask.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from audiotools import AudioSignal
|
5 |
+
|
6 |
+
from .util import scalar_to_batch_tensor
|
7 |
+
|
8 |
+
def _gamma(r):
|
9 |
+
return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
|
10 |
+
|
11 |
+
def _invgamma(y):
|
12 |
+
if not torch.is_tensor(y):
|
13 |
+
y = torch.tensor(y)[None]
|
14 |
+
return 2 * y.acos() / torch.pi
|
15 |
+
|
16 |
+
def full_mask(x: torch.Tensor):
|
17 |
+
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
18 |
+
return torch.ones_like(x).long()
|
19 |
+
|
20 |
+
def empty_mask(x: torch.Tensor):
|
21 |
+
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
22 |
+
return torch.zeros_like(x).long()
|
23 |
+
|
24 |
+
def apply_mask(
|
25 |
+
x: torch.Tensor,
|
26 |
+
mask: torch.Tensor,
|
27 |
+
mask_token: int
|
28 |
+
):
|
29 |
+
assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
|
30 |
+
assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
|
31 |
+
assert mask.dtype == torch.long, "mask must be long dtype, but got {mask.dtype}"
|
32 |
+
assert ~torch.any(mask > 1), "mask must be binary"
|
33 |
+
assert ~torch.any(mask < 0), "mask must be binary"
|
34 |
+
|
35 |
+
fill_x = torch.full_like(x, mask_token)
|
36 |
+
x = x * (1 - mask) + fill_x * mask
|
37 |
+
|
38 |
+
return x, mask
|
39 |
+
|
40 |
+
def random(
|
41 |
+
x: torch.Tensor,
|
42 |
+
r: torch.Tensor
|
43 |
+
):
|
44 |
+
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
45 |
+
if not isinstance(r, torch.Tensor):
|
46 |
+
r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
|
47 |
+
|
48 |
+
r = _gamma(r)[:, None, None]
|
49 |
+
probs = torch.ones_like(x) * r
|
50 |
+
|
51 |
+
mask = torch.bernoulli(probs)
|
52 |
+
mask = mask.round().long()
|
53 |
+
|
54 |
+
return mask
|
55 |
+
|
56 |
+
def linear_random(
|
57 |
+
x: torch.Tensor,
|
58 |
+
r: torch.Tensor,
|
59 |
+
):
|
60 |
+
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
61 |
+
if not isinstance(r, torch.Tensor):
|
62 |
+
r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
|
63 |
+
r = r[:, None, None]
|
64 |
+
|
65 |
+
probs = torch.ones_like(x).to(x.device).float()
|
66 |
+
# expand to batch and codebook dims
|
67 |
+
probs = probs.expand(x.shape[0], x.shape[1], -1)
|
68 |
+
probs = probs * r
|
69 |
+
|
70 |
+
mask = torch.bernoulli(probs)
|
71 |
+
mask = mask.round().long()
|
72 |
+
|
73 |
+
return mask
|
74 |
+
|
75 |
+
def inpaint(x: torch.Tensor,
|
76 |
+
n_prefix,
|
77 |
+
n_suffix,
|
78 |
+
):
|
79 |
+
assert n_prefix is not None
|
80 |
+
assert n_suffix is not None
|
81 |
+
|
82 |
+
mask = full_mask(x)
|
83 |
+
|
84 |
+
# if we have a prefix or suffix, set their mask prob to 0
|
85 |
+
if n_prefix > 0:
|
86 |
+
if not isinstance(n_prefix, torch.Tensor):
|
87 |
+
n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
|
88 |
+
for i, n in enumerate(n_prefix):
|
89 |
+
if n > 0:
|
90 |
+
mask[i, :, :n] = 0.0
|
91 |
+
if n_suffix > 0:
|
92 |
+
if not isinstance(n_suffix, torch.Tensor):
|
93 |
+
n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
|
94 |
+
for i, n in enumerate(n_suffix):
|
95 |
+
if n > 0:
|
96 |
+
mask[i, :, -n:] = 0.0
|
97 |
+
|
98 |
+
|
99 |
+
return mask
|
100 |
+
|
101 |
+
def periodic_mask(x: torch.Tensor,
|
102 |
+
period: int,width: int = 1,
|
103 |
+
random_roll=False,
|
104 |
+
):
|
105 |
+
mask = full_mask(x)
|
106 |
+
if period == 0:
|
107 |
+
return mask
|
108 |
+
|
109 |
+
if not isinstance(period, torch.Tensor):
|
110 |
+
period = scalar_to_batch_tensor(period, x.shape[0])
|
111 |
+
for i, factor in enumerate(period):
|
112 |
+
if factor == 0:
|
113 |
+
continue
|
114 |
+
for j in range(mask.shape[-1]):
|
115 |
+
if j % factor == 0:
|
116 |
+
# figure out how wide the mask should be
|
117 |
+
j_start = max(0, j - width // 2 )
|
118 |
+
j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1
|
119 |
+
# flip a coin for each position in the mask
|
120 |
+
j_mask = torch.bernoulli(torch.ones(j_end - j_start))
|
121 |
+
assert torch.all(j_mask == 1)
|
122 |
+
j_fill = torch.ones_like(j_mask) * (1 - j_mask)
|
123 |
+
assert torch.all(j_fill == 0)
|
124 |
+
# fill
|
125 |
+
mask[i, :, j_start:j_end] = j_fill
|
126 |
+
if random_roll:
|
127 |
+
# add a random offset to the mask
|
128 |
+
offset = torch.randint(0, period[0], (1,))
|
129 |
+
mask = torch.roll(mask, offset.item(), dims=-1)
|
130 |
+
|
131 |
+
return mask
|
132 |
+
|
133 |
+
def codebook_unmask(
|
134 |
+
mask: torch.Tensor,
|
135 |
+
n_conditioning_codebooks: int
|
136 |
+
):
|
137 |
+
if n_conditioning_codebooks == None:
|
138 |
+
return mask
|
139 |
+
# if we have any conditioning codebooks, set their mask to 0
|
140 |
+
mask = mask.clone()
|
141 |
+
mask[:, :n_conditioning_codebooks, :] = 0
|
142 |
+
return mask
|
143 |
+
|
144 |
+
def codebook_mask(mask: torch.Tensor, val1: int, val2: int = None):
|
145 |
+
mask = mask.clone()
|
146 |
+
mask[:, val1:, :] = 1
|
147 |
+
# val2 = val2 or val1
|
148 |
+
# vs = torch.linspace(val1, val2, mask.shape[1])
|
149 |
+
# for t, v in enumerate(vs):
|
150 |
+
# v = int(v)
|
151 |
+
# mask[:, v:, t] = 1
|
152 |
+
|
153 |
+
return mask
|
154 |
+
|
155 |
+
def mask_and(
|
156 |
+
mask1: torch.Tensor,
|
157 |
+
mask2: torch.Tensor
|
158 |
+
):
|
159 |
+
assert mask1.shape == mask2.shape, "masks must be same shape"
|
160 |
+
return torch.min(mask1, mask2)
|
161 |
+
|
162 |
+
def dropout(
|
163 |
+
mask: torch.Tensor,
|
164 |
+
p: float,
|
165 |
+
):
|
166 |
+
assert 0 <= p <= 1, "p must be between 0 and 1"
|
167 |
+
assert mask.max() <= 1, "mask must be binary"
|
168 |
+
assert mask.min() >= 0, "mask must be binary"
|
169 |
+
mask = (~mask.bool()).float()
|
170 |
+
mask = torch.bernoulli(mask * (1 - p))
|
171 |
+
mask = ~mask.round().bool()
|
172 |
+
return mask.long()
|
173 |
+
|
174 |
+
def mask_or(
|
175 |
+
mask1: torch.Tensor,
|
176 |
+
mask2: torch.Tensor
|
177 |
+
):
|
178 |
+
assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
|
179 |
+
assert mask1.max() <= 1, "mask1 must be binary"
|
180 |
+
assert mask2.max() <= 1, "mask2 must be binary"
|
181 |
+
assert mask1.min() >= 0, "mask1 must be binary"
|
182 |
+
assert mask2.min() >= 0, "mask2 must be binary"
|
183 |
+
return (mask1 + mask2).clamp(0, 1)
|
184 |
+
|
185 |
+
def time_stretch_mask(
|
186 |
+
x: torch.Tensor,
|
187 |
+
stretch_factor: int,
|
188 |
+
):
|
189 |
+
assert stretch_factor >= 1, "stretch factor must be >= 1"
|
190 |
+
c_seq_len = x.shape[-1]
|
191 |
+
x = x.repeat_interleave(stretch_factor, dim=-1)
|
192 |
+
|
193 |
+
# trim cz to the original length
|
194 |
+
x = x[:, :, :c_seq_len]
|
195 |
+
|
196 |
+
mask = periodic_mask(x, stretch_factor, width=1)
|
197 |
+
return mask
|
198 |
+
|
199 |
+
def onset_mask(
|
200 |
+
sig: AudioSignal,
|
201 |
+
z: torch.Tensor,
|
202 |
+
interface,
|
203 |
+
width: int = 1,
|
204 |
+
):
|
205 |
+
import librosa
|
206 |
+
|
207 |
+
onset_frame_idxs = librosa.onset.onset_detect(
|
208 |
+
y=sig.samples[0][0].detach().cpu().numpy(), sr=sig.sample_rate,
|
209 |
+
hop_length=interface.codec.hop_length,
|
210 |
+
backtrack=True,
|
211 |
+
)
|
212 |
+
if len(onset_frame_idxs) == 0:
|
213 |
+
print("no onsets detected")
|
214 |
+
print("onset_frame_idxs", onset_frame_idxs)
|
215 |
+
print("mask shape", z.shape)
|
216 |
+
|
217 |
+
mask = torch.ones_like(z)
|
218 |
+
for idx in onset_frame_idxs:
|
219 |
+
mask[:, :, idx-width:idx+width] = 0
|
220 |
+
|
221 |
+
return mask
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
if __name__ == "__main__":
|
226 |
+
sig = AudioSignal("assets/example.wav")
|
vampnet/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import audiotools
|
2 |
+
|
3 |
+
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
|
4 |
+
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
|
5 |
+
|
6 |
+
from .transformer import VampNet
|
vampnet/modules/activations.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
|
9 |
+
class NewGELU(nn.Module):
|
10 |
+
"""
|
11 |
+
Implementation of the GELU activation function currently in Google BERT repo
|
12 |
+
(identical to OpenAI GPT). Also see the Gaussian Error Linear Units
|
13 |
+
paper: https://arxiv.org/abs/1606.08415
|
14 |
+
"""
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
return (
|
18 |
+
0.5
|
19 |
+
* x
|
20 |
+
* (
|
21 |
+
1.0
|
22 |
+
+ torch.tanh(
|
23 |
+
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
|
24 |
+
)
|
25 |
+
)
|
26 |
+
)
|
27 |
+
|
28 |
+
class GatedGELU(nn.Module):
|
29 |
+
def __init__(self):
|
30 |
+
super().__init__()
|
31 |
+
self.gelu = NewGELU()
|
32 |
+
|
33 |
+
def forward(self, x, dim: int = -1):
|
34 |
+
p1, p2 = x.chunk(2, dim=dim)
|
35 |
+
return p1 * self.gelu(p2)
|
36 |
+
|
37 |
+
class Snake1d(nn.Module):
|
38 |
+
def __init__(self, channels):
|
39 |
+
super().__init__()
|
40 |
+
self.alpha = nn.Parameter(torch.ones(channels))
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
|
44 |
+
|
45 |
+
def get_activation(name: str = "relu"):
|
46 |
+
if name == "relu":
|
47 |
+
return nn.ReLU
|
48 |
+
elif name == "gelu":
|
49 |
+
return NewGELU
|
50 |
+
elif name == "geglu":
|
51 |
+
return GatedGELU
|
52 |
+
elif name == "snake":
|
53 |
+
return Snake1d
|
54 |
+
else:
|
55 |
+
raise ValueError(f"Unrecognized activation {name}")
|
vampnet/modules/layers.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import Optional
|
3 |
+
from typing import Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
|
11 |
+
# Scripting this brings model speed up 1.4x
|
12 |
+
@torch.jit.script
|
13 |
+
def snake(x, alpha):
|
14 |
+
shape = x.shape
|
15 |
+
x = x.reshape(shape[0], shape[1], -1)
|
16 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
17 |
+
x = x.reshape(shape)
|
18 |
+
return x
|
19 |
+
|
20 |
+
|
21 |
+
class Snake1d(nn.Module):
|
22 |
+
def __init__(self, channels):
|
23 |
+
super().__init__()
|
24 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
return snake(x, self.alpha)
|
28 |
+
|
29 |
+
|
30 |
+
def num_params(model):
|
31 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
32 |
+
|
33 |
+
|
34 |
+
def recurse_children(module, fn):
|
35 |
+
for child in module.children():
|
36 |
+
if isinstance(child, nn.ModuleList):
|
37 |
+
for c in child:
|
38 |
+
yield recurse_children(c, fn)
|
39 |
+
if isinstance(child, nn.ModuleDict):
|
40 |
+
for c in child.values():
|
41 |
+
yield recurse_children(c, fn)
|
42 |
+
|
43 |
+
yield recurse_children(child, fn)
|
44 |
+
yield fn(child)
|
45 |
+
|
46 |
+
|
47 |
+
def WNConv1d(*args, **kwargs):
|
48 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
49 |
+
|
50 |
+
|
51 |
+
def WNConvTranspose1d(*args, **kwargs):
|
52 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
53 |
+
|
54 |
+
|
55 |
+
class SequentialWithFiLM(nn.Module):
|
56 |
+
"""
|
57 |
+
handy wrapper for nn.Sequential that allows FiLM layers to be
|
58 |
+
inserted in between other layers.
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(self, *layers):
|
62 |
+
super().__init__()
|
63 |
+
self.layers = nn.ModuleList(layers)
|
64 |
+
|
65 |
+
@staticmethod
|
66 |
+
def has_film(module):
|
67 |
+
mod_has_film = any(
|
68 |
+
[res for res in recurse_children(module, lambda c: isinstance(c, FiLM))]
|
69 |
+
)
|
70 |
+
return mod_has_film
|
71 |
+
|
72 |
+
def forward(self, x, cond):
|
73 |
+
for layer in self.layers:
|
74 |
+
if self.has_film(layer):
|
75 |
+
x = layer(x, cond)
|
76 |
+
else:
|
77 |
+
x = layer(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class FiLM(nn.Module):
|
82 |
+
def __init__(self, input_dim: int, output_dim: int):
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
self.input_dim = input_dim
|
86 |
+
self.output_dim = output_dim
|
87 |
+
|
88 |
+
if input_dim > 0:
|
89 |
+
self.beta = nn.Linear(input_dim, output_dim)
|
90 |
+
self.gamma = nn.Linear(input_dim, output_dim)
|
91 |
+
|
92 |
+
def forward(self, x, r):
|
93 |
+
if self.input_dim == 0:
|
94 |
+
return x
|
95 |
+
else:
|
96 |
+
beta, gamma = self.beta(r), self.gamma(r)
|
97 |
+
beta, gamma = (
|
98 |
+
beta.view(x.size(0), self.output_dim, 1),
|
99 |
+
gamma.view(x.size(0), self.output_dim, 1),
|
100 |
+
)
|
101 |
+
x = x * (gamma + 1) + beta
|
102 |
+
return x
|
103 |
+
|
104 |
+
|
105 |
+
class CodebookEmbedding(nn.Module):
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
vocab_size: int,
|
109 |
+
latent_dim: int,
|
110 |
+
n_codebooks: int,
|
111 |
+
emb_dim: int,
|
112 |
+
special_tokens: Optional[Tuple[str]] = None,
|
113 |
+
):
|
114 |
+
super().__init__()
|
115 |
+
self.n_codebooks = n_codebooks
|
116 |
+
self.emb_dim = emb_dim
|
117 |
+
self.latent_dim = latent_dim
|
118 |
+
self.vocab_size = vocab_size
|
119 |
+
|
120 |
+
if special_tokens is not None:
|
121 |
+
for tkn in special_tokens:
|
122 |
+
self.special = nn.ParameterDict(
|
123 |
+
{
|
124 |
+
tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
|
125 |
+
for tkn in special_tokens
|
126 |
+
}
|
127 |
+
)
|
128 |
+
self.special_idxs = {
|
129 |
+
tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
|
130 |
+
}
|
131 |
+
|
132 |
+
self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
|
133 |
+
|
134 |
+
def from_codes(self, codes: torch.Tensor, codec):
|
135 |
+
"""
|
136 |
+
get a sequence of continuous embeddings from a sequence of discrete codes.
|
137 |
+
unlike it's counterpart in the original VQ-VAE, this function adds for any special tokens
|
138 |
+
necessary for the language model, like <MASK>.
|
139 |
+
"""
|
140 |
+
n_codebooks = codes.shape[1]
|
141 |
+
latent = []
|
142 |
+
for i in range(n_codebooks):
|
143 |
+
c = codes[:, i, :]
|
144 |
+
|
145 |
+
lookup_table = codec.quantizer.quantizers[i].codebook.weight
|
146 |
+
if hasattr(self, "special"):
|
147 |
+
special_lookup = torch.cat(
|
148 |
+
[self.special[tkn][i : i + 1] for tkn in self.special], dim=0
|
149 |
+
)
|
150 |
+
lookup_table = torch.cat([lookup_table, special_lookup], dim=0)
|
151 |
+
|
152 |
+
l = F.embedding(c, lookup_table).transpose(1, 2)
|
153 |
+
latent.append(l)
|
154 |
+
|
155 |
+
latent = torch.cat(latent, dim=1)
|
156 |
+
return latent
|
157 |
+
|
158 |
+
def forward(self, latents: torch.Tensor):
|
159 |
+
"""
|
160 |
+
project a sequence of latents to a sequence of embeddings
|
161 |
+
"""
|
162 |
+
x = self.out_proj(latents)
|
163 |
+
return x
|
164 |
+
|
vampnet/modules/transformer.py
ADDED
@@ -0,0 +1,965 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import logging
|
3 |
+
from typing import Optional, Tuple, Union, List
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange
|
10 |
+
import loralib as lora
|
11 |
+
import audiotools as at
|
12 |
+
|
13 |
+
from .activations import get_activation
|
14 |
+
from .layers import CodebookEmbedding
|
15 |
+
from .layers import FiLM
|
16 |
+
from .layers import SequentialWithFiLM
|
17 |
+
from .layers import WNConv1d
|
18 |
+
from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
|
19 |
+
from ..mask import _gamma
|
20 |
+
|
21 |
+
LORA_R = 8
|
22 |
+
|
23 |
+
# def log(t, eps=1e-20):
|
24 |
+
# return torch.log(t + eps)
|
25 |
+
|
26 |
+
|
27 |
+
def gumbel_noise_like(t):
|
28 |
+
noise = torch.zeros_like(t).uniform_(1e-20, 1)
|
29 |
+
return -torch.log(-torch.log(noise))
|
30 |
+
|
31 |
+
|
32 |
+
def gumbel_sample(t, temperature=1.0, dim=-1):
|
33 |
+
return ((t / max(temperature, 1e-10)) + gumbel_noise_like(t)).argmax(dim=dim)
|
34 |
+
|
35 |
+
|
36 |
+
class RMSNorm(nn.Module):
|
37 |
+
def __init__(self, hidden_size: int, eps=1e-6):
|
38 |
+
super().__init__()
|
39 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
40 |
+
self.var_eps = eps
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
"""Returns root mean square normalized version of input `x`
|
44 |
+
# T5 uses a layer_norm which only scales and doesn't shift, which is also known
|
45 |
+
# as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467
|
46 |
+
# thus varience is calculated w/o mean and there is no bias
|
47 |
+
Parameters
|
48 |
+
----------
|
49 |
+
x : Tensor[B x T x D]
|
50 |
+
Returns
|
51 |
+
-------
|
52 |
+
Tensor[B x T x D]
|
53 |
+
"""
|
54 |
+
var = x.pow(2).mean(-1, keepdim=True)
|
55 |
+
x = x * torch.rsqrt(var + self.var_eps)
|
56 |
+
|
57 |
+
return self.weight * x
|
58 |
+
|
59 |
+
|
60 |
+
class FeedForward(nn.Module):
|
61 |
+
def __init__(
|
62 |
+
self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu"
|
63 |
+
):
|
64 |
+
super().__init__()
|
65 |
+
factor = 2 if activation == "geglu" else 1
|
66 |
+
self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R)
|
67 |
+
self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R)
|
68 |
+
self.drop = nn.Dropout(dropout)
|
69 |
+
self.act = get_activation(activation)()
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
"""Computes position-wise feed-forward layer
|
73 |
+
Parameters
|
74 |
+
----------
|
75 |
+
x : Tensor[B x T x D]
|
76 |
+
Returns
|
77 |
+
-------
|
78 |
+
Tensor[B x T x D]
|
79 |
+
"""
|
80 |
+
x = self.w_1(x)
|
81 |
+
x = self.act(x)
|
82 |
+
x = self.drop(x)
|
83 |
+
x = self.w_2(x)
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
class MultiHeadRelativeAttention(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
n_head: int = 8,
|
91 |
+
d_model: int = 512,
|
92 |
+
dropout: float = 0.1,
|
93 |
+
bidirectional: bool = True,
|
94 |
+
has_relative_attention_bias: bool = True,
|
95 |
+
attention_num_buckets: int = 32,
|
96 |
+
attention_max_distance: int = 128,
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
d_head = d_model // n_head
|
100 |
+
self.n_head = n_head
|
101 |
+
self.d_head = d_head
|
102 |
+
self.bidirectional = bidirectional
|
103 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
104 |
+
self.attention_num_buckets = attention_num_buckets
|
105 |
+
self.attention_max_distance = attention_max_distance
|
106 |
+
|
107 |
+
# Create linear query, key, value projections
|
108 |
+
self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
|
109 |
+
self.w_ks = nn.Linear(d_model, d_model, bias=False)
|
110 |
+
self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
|
111 |
+
|
112 |
+
# Create linear final output projection
|
113 |
+
self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
|
114 |
+
|
115 |
+
# Dropout for attention output weights
|
116 |
+
self.dropout = nn.Dropout(dropout)
|
117 |
+
|
118 |
+
# Create relative positional embeddings (if turned on)
|
119 |
+
if has_relative_attention_bias:
|
120 |
+
self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head)
|
121 |
+
|
122 |
+
def _relative_position_bucket(self, relative_position):
|
123 |
+
"""Converts unbounded relative position into bounded set of buckets
|
124 |
+
with half "exact" buckets (1 position = 1 bucket) and half "log-spaced"
|
125 |
+
buckets
|
126 |
+
Parameters
|
127 |
+
----------
|
128 |
+
relative_position : Tensor[T_q x T_kv]
|
129 |
+
Relative positions between queries and key_value items
|
130 |
+
Returns
|
131 |
+
-------
|
132 |
+
Tensor[T_q x T_kv]
|
133 |
+
Input relative positions converted into buckets
|
134 |
+
"""
|
135 |
+
relative_buckets = 0
|
136 |
+
num_buckets = self.attention_num_buckets
|
137 |
+
max_distance = self.attention_max_distance
|
138 |
+
|
139 |
+
# Convert relative position for (-inf, inf) to [0, inf]
|
140 |
+
# Negative relative positions correspond to past
|
141 |
+
# Positive relative positions correspond to future
|
142 |
+
if self.bidirectional:
|
143 |
+
# use half buckets for each side (past / future)
|
144 |
+
num_buckets //= 2
|
145 |
+
|
146 |
+
# Shift the position positions by `num_buckets` to wrap around
|
147 |
+
# negative positions
|
148 |
+
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
149 |
+
relative_position = torch.abs(relative_position)
|
150 |
+
else:
|
151 |
+
# If not bidirectional, ignore positive positions and wrap
|
152 |
+
# negative positions to positive
|
153 |
+
relative_position = -torch.min(
|
154 |
+
relative_position, torch.zeros_like(relative_position)
|
155 |
+
)
|
156 |
+
|
157 |
+
# Allocate half of the buckets are for exact increments in positions
|
158 |
+
max_exact = num_buckets // 2
|
159 |
+
is_small = relative_position < max_exact
|
160 |
+
|
161 |
+
# The other half of the buckets are for logarithmically bigger bins in
|
162 |
+
# positions up to `max_distance`
|
163 |
+
relative_postion_if_large = max_exact + (
|
164 |
+
torch.log(relative_position.float() / max_exact)
|
165 |
+
/ math.log(max_distance / max_exact)
|
166 |
+
* (num_buckets - max_exact)
|
167 |
+
).to(torch.long)
|
168 |
+
|
169 |
+
# Clip the max relative position to `num_buckets - 1`
|
170 |
+
relative_postion_if_large = torch.min(
|
171 |
+
relative_postion_if_large,
|
172 |
+
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
173 |
+
)
|
174 |
+
|
175 |
+
# Choose relative buckets based on small or large positions
|
176 |
+
relative_buckets += torch.where(
|
177 |
+
is_small, relative_position, relative_postion_if_large
|
178 |
+
)
|
179 |
+
|
180 |
+
return relative_buckets
|
181 |
+
|
182 |
+
def compute_bias(self, query_length, key_length):
|
183 |
+
"""Computes a position bias scalar for each index in query_length x key_length
|
184 |
+
Parameters
|
185 |
+
----------
|
186 |
+
query_length : int
|
187 |
+
key_length : int
|
188 |
+
Returns
|
189 |
+
-------
|
190 |
+
Tensor[heads x 1 x T_q x T_kv]
|
191 |
+
Position bias to be applied on attention logits
|
192 |
+
"""
|
193 |
+
|
194 |
+
query_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
195 |
+
key_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
196 |
+
relative_position = key_position - query_position
|
197 |
+
|
198 |
+
# Convert relative position to buckets
|
199 |
+
relative_position_bucket = self._relative_position_bucket(relative_position)
|
200 |
+
relative_position_bucket = relative_position_bucket.to(
|
201 |
+
self.relative_attention_bias.weight.device
|
202 |
+
)
|
203 |
+
|
204 |
+
# Index attention bias values
|
205 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
206 |
+
values = rearrange(values, "q k h -> h 1 q k")
|
207 |
+
|
208 |
+
return values
|
209 |
+
|
210 |
+
def forward(self, q, k, v, mask=None, position_bias=None):
|
211 |
+
"""Computes attention over (keys, values) for every timestep in query
|
212 |
+
Parameters
|
213 |
+
----------
|
214 |
+
q : Tensor[B x T_q x d_model]
|
215 |
+
Query vectors
|
216 |
+
k : Tensor[B x T_kv x d_model]
|
217 |
+
Key vectors to compute attention over
|
218 |
+
v : Tensor[B x T_kv x d_model]
|
219 |
+
Value vectors corresponding to the keys
|
220 |
+
mask : Tensor[B x T_q x T_kv], optional
|
221 |
+
position_bias: Tensor[head x 1 x T_q x T_kv]
|
222 |
+
Returns
|
223 |
+
-------
|
224 |
+
Tensor[B x T_q x d_model]
|
225 |
+
Outputs after attending (key, value) using queries
|
226 |
+
"""
|
227 |
+
# Compute query, key, value projections
|
228 |
+
q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head)
|
229 |
+
k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head)
|
230 |
+
v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head)
|
231 |
+
|
232 |
+
# Compute attention matrix
|
233 |
+
attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1])
|
234 |
+
|
235 |
+
# Add relative position bias to attention scores
|
236 |
+
if position_bias is None:
|
237 |
+
if self.has_relative_attention_bias:
|
238 |
+
position_bias = self.compute_bias(q.size(-2), k.size(-2))
|
239 |
+
else:
|
240 |
+
position_bias = torch.zeros_like(attn)
|
241 |
+
attn += position_bias
|
242 |
+
|
243 |
+
# Apply mask to attention scores to prevent looking up invalid locations
|
244 |
+
if mask is not None:
|
245 |
+
attn = attn.masked_fill(mask[None] == 0, -1e9)
|
246 |
+
|
247 |
+
# Normalize attention scores and add dropout
|
248 |
+
attn = torch.softmax(attn, dim=3)
|
249 |
+
attn = self.dropout(attn)
|
250 |
+
|
251 |
+
# Compute attended outputs (product of attention matrix and values)
|
252 |
+
output = torch.einsum("hblt,hbtv->hblv", [attn, v])
|
253 |
+
output = rearrange(output, "head b l v -> b l (head v)")
|
254 |
+
output = self.fc(output)
|
255 |
+
|
256 |
+
return output, position_bias
|
257 |
+
|
258 |
+
|
259 |
+
class TransformerLayer(nn.Module):
|
260 |
+
def __init__(
|
261 |
+
self,
|
262 |
+
d_model: int = 512,
|
263 |
+
d_cond: int = 64,
|
264 |
+
n_heads: int = 8,
|
265 |
+
bidirectional: bool = True,
|
266 |
+
is_decoder: bool = False,
|
267 |
+
has_relative_attention_bias: bool = False,
|
268 |
+
flash_attn: bool = False,
|
269 |
+
dropout: float = 0.1,
|
270 |
+
):
|
271 |
+
super().__init__()
|
272 |
+
# Store args
|
273 |
+
self.is_decoder = is_decoder
|
274 |
+
|
275 |
+
# Create self-attention layer
|
276 |
+
self.norm_1 = RMSNorm(d_model)
|
277 |
+
self.film_1 = FiLM(d_cond, d_model)
|
278 |
+
self.flash_attn = flash_attn
|
279 |
+
|
280 |
+
if flash_attn:
|
281 |
+
from flash_attn.flash_attention import FlashMHA
|
282 |
+
self.self_attn = FlashMHA(
|
283 |
+
embed_dim=d_model,
|
284 |
+
num_heads=n_heads,
|
285 |
+
attention_dropout=dropout,
|
286 |
+
causal=False,
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
self.self_attn = MultiHeadRelativeAttention(
|
290 |
+
n_heads, d_model, dropout, bidirectional, has_relative_attention_bias
|
291 |
+
)
|
292 |
+
|
293 |
+
# (Optional) Create cross-attention layer
|
294 |
+
if is_decoder:
|
295 |
+
self.norm_2 = RMSNorm(d_model)
|
296 |
+
self.film_2 = FiLM(d_cond, d_model)
|
297 |
+
self.cross_attn = MultiHeadRelativeAttention(
|
298 |
+
n_heads,
|
299 |
+
d_model,
|
300 |
+
dropout,
|
301 |
+
bidirectional=True,
|
302 |
+
has_relative_attention_bias=False,
|
303 |
+
)
|
304 |
+
|
305 |
+
# Create last feed-forward layer
|
306 |
+
self.norm_3 = RMSNorm(d_model)
|
307 |
+
self.film_3 = FiLM(d_cond, d_model)
|
308 |
+
self.feed_forward = FeedForward(d_model=d_model, dropout=dropout)
|
309 |
+
|
310 |
+
# Create dropout
|
311 |
+
self.dropout = nn.Dropout(dropout)
|
312 |
+
|
313 |
+
def forward(
|
314 |
+
self,
|
315 |
+
x,
|
316 |
+
x_mask,
|
317 |
+
cond,
|
318 |
+
src=None,
|
319 |
+
src_mask=None,
|
320 |
+
position_bias=None,
|
321 |
+
encoder_decoder_position_bias=None,
|
322 |
+
):
|
323 |
+
"""Computes one transformer layer consisting of self attention, (op) cross attention
|
324 |
+
and feedforward layer
|
325 |
+
Parameters
|
326 |
+
----------
|
327 |
+
x : Tensor[B x T_q x D]
|
328 |
+
x_mask : Tensor[B x T_q]
|
329 |
+
src : Tensor[B x T_kv x D], optional
|
330 |
+
src_mask : Tensor[B x T_kv x D], optional
|
331 |
+
position_bias : Tensor[heads x B x T_q x T_q], optional
|
332 |
+
Relative position bias for self attention layer
|
333 |
+
encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional
|
334 |
+
Relative position bias for cross attention layer
|
335 |
+
Returns
|
336 |
+
-------
|
337 |
+
Tensor[B x T_q x D]
|
338 |
+
"""
|
339 |
+
y = self.norm_1(x)
|
340 |
+
y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1)
|
341 |
+
if self.flash_attn:
|
342 |
+
with torch.autocast(y.device.type, dtype=torch.bfloat16):
|
343 |
+
y = self.self_attn(y)[0]
|
344 |
+
else:
|
345 |
+
y, position_bias = self.self_attn(y, y, y, x_mask, position_bias)
|
346 |
+
x = x + self.dropout(y)
|
347 |
+
|
348 |
+
if self.is_decoder:
|
349 |
+
y = self.norm_2(x)
|
350 |
+
y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1)
|
351 |
+
y, encoder_decoder_position_bias = self.cross_attn(
|
352 |
+
y, src, src, src_mask, encoder_decoder_position_bias
|
353 |
+
)
|
354 |
+
x = x + self.dropout(y)
|
355 |
+
|
356 |
+
y = self.norm_3(x)
|
357 |
+
y = self.film_3(
|
358 |
+
y.permute(
|
359 |
+
0,
|
360 |
+
2,
|
361 |
+
1,
|
362 |
+
),
|
363 |
+
cond,
|
364 |
+
).permute(0, 2, 1)
|
365 |
+
y = self.feed_forward(y)
|
366 |
+
x = x + self.dropout(y)
|
367 |
+
|
368 |
+
return x, position_bias, encoder_decoder_position_bias
|
369 |
+
|
370 |
+
|
371 |
+
class TransformerStack(nn.Module):
|
372 |
+
def __init__(
|
373 |
+
self,
|
374 |
+
d_model: int = 512,
|
375 |
+
d_cond: int = 64,
|
376 |
+
n_heads: int = 8,
|
377 |
+
n_layers: int = 8,
|
378 |
+
last_layer: bool = True,
|
379 |
+
bidirectional: bool = True,
|
380 |
+
flash_attn: bool = False,
|
381 |
+
is_decoder: bool = False,
|
382 |
+
dropout: float = 0.1,
|
383 |
+
):
|
384 |
+
super().__init__()
|
385 |
+
# Store args
|
386 |
+
self.bidirectional = bidirectional
|
387 |
+
self.is_decoder = is_decoder
|
388 |
+
|
389 |
+
# Create transformer layers
|
390 |
+
# In T5, relative attention bias is shared by all layers in the stack
|
391 |
+
self.layers = nn.ModuleList(
|
392 |
+
[
|
393 |
+
TransformerLayer(
|
394 |
+
d_model,
|
395 |
+
d_cond,
|
396 |
+
n_heads,
|
397 |
+
bidirectional,
|
398 |
+
is_decoder,
|
399 |
+
has_relative_attention_bias=True if (i == 0) else False,
|
400 |
+
flash_attn=flash_attn,
|
401 |
+
dropout=dropout,
|
402 |
+
)
|
403 |
+
for i in range(n_layers)
|
404 |
+
]
|
405 |
+
)
|
406 |
+
|
407 |
+
# Perform last normalization
|
408 |
+
self.norm = RMSNorm(d_model) if last_layer else None
|
409 |
+
|
410 |
+
def subsequent_mask(self, size):
|
411 |
+
return torch.ones(1, size, size).tril().bool()
|
412 |
+
|
413 |
+
def forward(self, x, x_mask, cond=None, src=None, src_mask=None,
|
414 |
+
return_activations: bool = False
|
415 |
+
):
|
416 |
+
"""Computes a full transformer stack
|
417 |
+
Parameters
|
418 |
+
----------
|
419 |
+
x : Tensor[B x T_q x D]
|
420 |
+
x_mask : Tensor[B x T_q]
|
421 |
+
src : Tensor[B x T_kv x D], optional
|
422 |
+
src_mask : Tensor[B x T_kv], optional
|
423 |
+
Returns
|
424 |
+
-------
|
425 |
+
Tensor[B x T_q x D]
|
426 |
+
"""
|
427 |
+
|
428 |
+
# Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking
|
429 |
+
if self.is_decoder:
|
430 |
+
src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2)
|
431 |
+
|
432 |
+
# Convert `x_mask` to (B x T_q x T_q) shape for self attention masking
|
433 |
+
x_mask = x_mask.unsqueeze(-2)
|
434 |
+
if not self.bidirectional:
|
435 |
+
x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device)
|
436 |
+
|
437 |
+
# Initialize position biases
|
438 |
+
position_bias = None
|
439 |
+
encoder_decoder_position_bias = None
|
440 |
+
|
441 |
+
# Compute transformer layers
|
442 |
+
if return_activations:
|
443 |
+
activations = []
|
444 |
+
for layer in self.layers:
|
445 |
+
x, position_bias, encoder_decoder_position_bias = layer(
|
446 |
+
x=x,
|
447 |
+
x_mask=x_mask,
|
448 |
+
cond=cond,
|
449 |
+
src=src,
|
450 |
+
src_mask=src_mask,
|
451 |
+
position_bias=position_bias,
|
452 |
+
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
453 |
+
)
|
454 |
+
if return_activations:
|
455 |
+
activations.append(x.detach())
|
456 |
+
|
457 |
+
|
458 |
+
out = self.norm(x) if self.norm is not None else x
|
459 |
+
if return_activations:
|
460 |
+
return out, torch.stack(activations)
|
461 |
+
else:
|
462 |
+
return out
|
463 |
+
|
464 |
+
|
465 |
+
class VampNet(at.ml.BaseModel):
|
466 |
+
def __init__(
|
467 |
+
self,
|
468 |
+
n_heads: int = 20,
|
469 |
+
n_layers: int = 16,
|
470 |
+
r_cond_dim: int = 0,
|
471 |
+
n_codebooks: int = 9,
|
472 |
+
n_conditioning_codebooks: int = 0,
|
473 |
+
latent_dim: int = 8,
|
474 |
+
embedding_dim: int = 1280,
|
475 |
+
vocab_size: int = 1024,
|
476 |
+
flash_attn: bool = True,
|
477 |
+
noise_mode: str = "mask",
|
478 |
+
dropout: float = 0.1
|
479 |
+
):
|
480 |
+
super().__init__()
|
481 |
+
assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
|
482 |
+
self.n_heads = n_heads
|
483 |
+
self.n_layers = n_layers
|
484 |
+
self.r_cond_dim = r_cond_dim
|
485 |
+
self.n_codebooks = n_codebooks
|
486 |
+
self.n_conditioning_codebooks = n_conditioning_codebooks
|
487 |
+
self.embedding_dim = embedding_dim
|
488 |
+
self.vocab_size = vocab_size
|
489 |
+
self.latent_dim = latent_dim
|
490 |
+
self.flash_attn = flash_attn
|
491 |
+
self.noise_mode = noise_mode
|
492 |
+
|
493 |
+
assert self.noise_mode == "mask", "deprecated"
|
494 |
+
|
495 |
+
self.embedding = CodebookEmbedding(
|
496 |
+
latent_dim=latent_dim,
|
497 |
+
n_codebooks=n_codebooks,
|
498 |
+
vocab_size=vocab_size,
|
499 |
+
emb_dim=embedding_dim,
|
500 |
+
special_tokens=["MASK"],
|
501 |
+
)
|
502 |
+
self.mask_token = self.embedding.special_idxs["MASK"]
|
503 |
+
|
504 |
+
self.transformer = TransformerStack(
|
505 |
+
d_model=embedding_dim,
|
506 |
+
d_cond=r_cond_dim,
|
507 |
+
n_heads=n_heads,
|
508 |
+
n_layers=n_layers,
|
509 |
+
last_layer=True,
|
510 |
+
bidirectional=True,
|
511 |
+
flash_attn=flash_attn,
|
512 |
+
is_decoder=False,
|
513 |
+
dropout=dropout,
|
514 |
+
)
|
515 |
+
|
516 |
+
# Add final conv layer
|
517 |
+
self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks
|
518 |
+
self.classifier = SequentialWithFiLM(
|
519 |
+
WNConv1d(
|
520 |
+
embedding_dim,
|
521 |
+
vocab_size * self.n_predict_codebooks,
|
522 |
+
kernel_size=1,
|
523 |
+
padding="same",
|
524 |
+
# groups=self.n_predict_codebooks,
|
525 |
+
),
|
526 |
+
)
|
527 |
+
|
528 |
+
def forward(self, x, return_activations: bool = False):
|
529 |
+
x = self.embedding(x)
|
530 |
+
x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
|
531 |
+
|
532 |
+
x = rearrange(x, "b d n -> b n d")
|
533 |
+
out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
|
534 |
+
if return_activations:
|
535 |
+
out, activations = out
|
536 |
+
|
537 |
+
out = rearrange(out, "b n d -> b d n")
|
538 |
+
|
539 |
+
out = self.classifier(out, None) # no cond here!
|
540 |
+
|
541 |
+
out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
|
542 |
+
|
543 |
+
if return_activations:
|
544 |
+
return out, activations
|
545 |
+
else:
|
546 |
+
return out
|
547 |
+
|
548 |
+
def r_embed(self, r, max_positions=10000):
|
549 |
+
if self.r_cond_dim > 0:
|
550 |
+
dtype = r.dtype
|
551 |
+
|
552 |
+
r = _gamma(r) * max_positions
|
553 |
+
half_dim = self.r_cond_dim // 2
|
554 |
+
|
555 |
+
emb = math.log(max_positions) / (half_dim - 1)
|
556 |
+
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
|
557 |
+
|
558 |
+
emb = r[:, None] * emb[None, :]
|
559 |
+
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
|
560 |
+
|
561 |
+
if self.r_cond_dim % 2 == 1: # zero pad
|
562 |
+
emb = nn.functional.pad(emb, (0, 1), mode="constant")
|
563 |
+
|
564 |
+
return emb.to(dtype)
|
565 |
+
else:
|
566 |
+
return r
|
567 |
+
|
568 |
+
@torch.no_grad()
|
569 |
+
def decode(self, z, codec):
|
570 |
+
"""
|
571 |
+
convert a sequence of latents to a signal.
|
572 |
+
"""
|
573 |
+
assert z.ndim == 3
|
574 |
+
|
575 |
+
# remove mask token
|
576 |
+
z = z.masked_fill(z == self.mask_token, 0)
|
577 |
+
signal = at.AudioSignal(
|
578 |
+
codec.decode(
|
579 |
+
codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
|
580 |
+
)["audio"],
|
581 |
+
codec.sample_rate,
|
582 |
+
)
|
583 |
+
|
584 |
+
# find where the mask token is and replace it with silence in the audio
|
585 |
+
for tstep in range(z.shape[-1]):
|
586 |
+
if torch.all(z[:, :, tstep] == self.mask_token):
|
587 |
+
sample_idx_0 = tstep * codec.hop_length
|
588 |
+
sample_idx_1 = sample_idx_0 + codec.hop_length
|
589 |
+
signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
|
590 |
+
|
591 |
+
return signal
|
592 |
+
|
593 |
+
@torch.inference_mode()
|
594 |
+
def generate(
|
595 |
+
self,
|
596 |
+
codec,
|
597 |
+
time_steps: int = 300,
|
598 |
+
_sampling_steps: List[int] = [12],
|
599 |
+
start_tokens: Optional[torch.Tensor] = None,
|
600 |
+
temperature: float = 1.0,
|
601 |
+
mask: Optional[torch.Tensor] = None,
|
602 |
+
mask_temperature: float = 10.5,
|
603 |
+
typical_filtering=True,
|
604 |
+
typical_mass=0.2,
|
605 |
+
typical_min_tokens=64,
|
606 |
+
top_p=None,
|
607 |
+
seed: int = None,
|
608 |
+
sample_cutoff: float = 1.0,
|
609 |
+
return_signal=True,
|
610 |
+
debug=False,
|
611 |
+
causal_weight: float = 0.0,
|
612 |
+
cfg_guidance: float = None,
|
613 |
+
):
|
614 |
+
if seed is not None:
|
615 |
+
at.util.seed(seed)
|
616 |
+
sampling_steps = sum(_sampling_steps)
|
617 |
+
logging.debug(f"beginning generation with {sampling_steps} steps")
|
618 |
+
|
619 |
+
#####################
|
620 |
+
# resolve initial z #
|
621 |
+
#####################
|
622 |
+
z = start_tokens
|
623 |
+
nb = z.shape[0]
|
624 |
+
|
625 |
+
if z is None:
|
626 |
+
z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
|
627 |
+
self.device
|
628 |
+
)
|
629 |
+
|
630 |
+
|
631 |
+
|
632 |
+
#################
|
633 |
+
# resolve mask #
|
634 |
+
#################
|
635 |
+
|
636 |
+
if mask is None:
|
637 |
+
mask = torch.ones_like(z).to(self.device).int()
|
638 |
+
mask[:, : self.n_conditioning_codebooks, :] = 0.0
|
639 |
+
if mask.ndim == 2:
|
640 |
+
mask = mask[:, None, :].repeat(1, z.shape[1], 1)
|
641 |
+
# init_mask = mask.clone()
|
642 |
+
|
643 |
+
|
644 |
+
|
645 |
+
###########
|
646 |
+
# set up #
|
647 |
+
##########
|
648 |
+
# apply the mask to z
|
649 |
+
z_masked = z.masked_fill(mask.bool(), self.mask_token)
|
650 |
+
# logging.debug(f"z_masked: {z_masked}")
|
651 |
+
|
652 |
+
# how many mask tokens to begin with?
|
653 |
+
num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
|
654 |
+
|
655 |
+
# how many codebooks are we inferring vs conditioning on?
|
656 |
+
n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
|
657 |
+
|
658 |
+
if cfg_guidance is not None:
|
659 |
+
# we need to repeat our tensors
|
660 |
+
z_uncond = torch.full_like(z, self.mask_token)
|
661 |
+
|
662 |
+
z_masked = torch.cat(
|
663 |
+
(z_masked, z_uncond), dim=0
|
664 |
+
)
|
665 |
+
z = torch.cat(
|
666 |
+
(z, z_uncond), dim=0
|
667 |
+
)
|
668 |
+
mask = torch.cat(
|
669 |
+
(mask, torch.full_like(mask, 1)), dim=0
|
670 |
+
)
|
671 |
+
|
672 |
+
#################
|
673 |
+
# begin sampling #
|
674 |
+
#################
|
675 |
+
from tqdm import tqdm
|
676 |
+
for i in range(sampling_steps):
|
677 |
+
|
678 |
+
# our current schedule step
|
679 |
+
r = scalar_to_batch_tensor(
|
680 |
+
(i + 1) / sampling_steps,
|
681 |
+
z.shape[0]
|
682 |
+
).to(z.device)
|
683 |
+
|
684 |
+
# get latents
|
685 |
+
latents = self.embedding.from_codes(z_masked, codec)
|
686 |
+
|
687 |
+
|
688 |
+
# infer from latents
|
689 |
+
# NOTE: this collapses the codebook dimension into the sequence dimension
|
690 |
+
logits = self.forward(latents) # b, prob, seq
|
691 |
+
|
692 |
+
if cfg_guidance is not None:
|
693 |
+
logits_cond, logits_uncond = logits[:nb], logits[nb:]
|
694 |
+
logits_cond = cfg_guidance * logits_cond + cfg_guidance * (1 - logits_uncond)
|
695 |
+
|
696 |
+
logits = logits.permute(0, 2, 1) # b, seq, prob
|
697 |
+
b = logits.shape[0]
|
698 |
+
|
699 |
+
sampled_z, selected_probs = sample_from_logits(
|
700 |
+
logits, sample=(
|
701 |
+
(i / sampling_steps) <= sample_cutoff
|
702 |
+
),
|
703 |
+
temperature=temperature,
|
704 |
+
typical_filtering=typical_filtering, typical_mass=typical_mass,
|
705 |
+
typical_min_tokens=typical_min_tokens,
|
706 |
+
top_k=None, top_p=top_p, return_probs=True,
|
707 |
+
)
|
708 |
+
|
709 |
+
|
710 |
+
# flatten z_masked and mask, so we can deal with the sampling logic
|
711 |
+
# we'll unflatten them at the end of the loop for the next forward pass
|
712 |
+
# remove conditioning codebooks, we'll add them back at the end
|
713 |
+
z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
|
714 |
+
|
715 |
+
mask = (z_masked == self.mask_token).int()
|
716 |
+
|
717 |
+
# update the mask, remove conditioning codebooks from the mask
|
718 |
+
# add z back into sampled z where the mask was false
|
719 |
+
sampled_z = torch.where(
|
720 |
+
mask.bool(), sampled_z, z_masked
|
721 |
+
)
|
722 |
+
|
723 |
+
# ignore any tokens that weren't masked
|
724 |
+
selected_probs = torch.where(
|
725 |
+
mask.bool(), selected_probs, torch.inf
|
726 |
+
)
|
727 |
+
|
728 |
+
# get the num tokens to mask, according to the schedule
|
729 |
+
num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
|
730 |
+
logging.debug(f"num to mask: {num_to_mask}")
|
731 |
+
|
732 |
+
if i != (sampling_steps - 1):
|
733 |
+
num_to_mask = torch.maximum(
|
734 |
+
torch.tensor(1),
|
735 |
+
torch.minimum(
|
736 |
+
mask.sum(dim=-1, keepdim=True) - 1,
|
737 |
+
num_to_mask
|
738 |
+
)
|
739 |
+
)
|
740 |
+
|
741 |
+
|
742 |
+
# get our new mask
|
743 |
+
mask = mask_by_random_topk(
|
744 |
+
num_to_mask, selected_probs, mask_temperature * (1-r)
|
745 |
+
)
|
746 |
+
|
747 |
+
# update the mask
|
748 |
+
z_masked = torch.where(
|
749 |
+
mask.bool(), self.mask_token, sampled_z
|
750 |
+
)
|
751 |
+
|
752 |
+
z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
|
753 |
+
mask = codebook_unflatten(mask, n_infer_codebooks)
|
754 |
+
|
755 |
+
# add conditioning codebooks back to z_masked
|
756 |
+
z_masked = torch.cat(
|
757 |
+
(z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
|
758 |
+
)
|
759 |
+
|
760 |
+
# add conditioning codebooks back to sampled_z
|
761 |
+
sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
|
762 |
+
sampled_z = torch.cat(
|
763 |
+
(z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
|
764 |
+
)
|
765 |
+
|
766 |
+
if cfg_guidance is not None:
|
767 |
+
sampled_z = sampled_z[:nb]
|
768 |
+
|
769 |
+
if return_signal:
|
770 |
+
return self.decode(sampled_z, codec)
|
771 |
+
else:
|
772 |
+
return sampled_z
|
773 |
+
|
774 |
+
|
775 |
+
|
776 |
+
|
777 |
+
|
778 |
+
def sample_from_logits(
|
779 |
+
logits,
|
780 |
+
sample: bool = True,
|
781 |
+
temperature: float = 1.0,
|
782 |
+
top_k: int = None,
|
783 |
+
top_p: float = None,
|
784 |
+
typical_filtering: bool = False,
|
785 |
+
typical_mass: float = 0.2,
|
786 |
+
typical_min_tokens: int = 1,
|
787 |
+
return_probs: bool = False
|
788 |
+
):
|
789 |
+
"""Convenience function to sample from a categorial distribution with input as
|
790 |
+
unnormalized logits.
|
791 |
+
|
792 |
+
Parameters
|
793 |
+
----------
|
794 |
+
logits : Tensor[..., vocab_size]
|
795 |
+
config: SamplingConfig
|
796 |
+
The set of hyperparameters to be used for sampling
|
797 |
+
sample : bool, optional
|
798 |
+
Whether to perform multinomial sampling, by default True
|
799 |
+
temperature : float, optional
|
800 |
+
Scaling parameter when multinomial samping, by default 1.0
|
801 |
+
top_k : int, optional
|
802 |
+
Restricts sampling to only `top_k` values acc. to probability,
|
803 |
+
by default None
|
804 |
+
top_p : float, optional
|
805 |
+
Restricts sampling to only those values with cumulative
|
806 |
+
probability = `top_p`, by default None
|
807 |
+
|
808 |
+
Returns
|
809 |
+
-------
|
810 |
+
Tensor[...]
|
811 |
+
Sampled tokens
|
812 |
+
"""
|
813 |
+
shp = logits.shape[:-1]
|
814 |
+
|
815 |
+
if typical_filtering:
|
816 |
+
typical_filter(logits,
|
817 |
+
typical_mass=typical_mass,
|
818 |
+
typical_min_tokens=typical_min_tokens
|
819 |
+
)
|
820 |
+
|
821 |
+
# Apply top_k sampling
|
822 |
+
if top_k is not None:
|
823 |
+
v, _ = logits.topk(top_k)
|
824 |
+
logits[logits < v[..., [-1]]] = -float("inf")
|
825 |
+
|
826 |
+
# Apply top_p (nucleus) sampling
|
827 |
+
if top_p is not None and top_p < 1.0:
|
828 |
+
v, sorted_indices = logits.sort(descending=True)
|
829 |
+
cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
|
830 |
+
|
831 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
832 |
+
# Right shift indices_to_remove to keep 1st token over threshold
|
833 |
+
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
|
834 |
+
..., :-1
|
835 |
+
]
|
836 |
+
|
837 |
+
# Compute indices_to_remove in unsorted array
|
838 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
839 |
+
-1, sorted_indices, sorted_indices_to_remove
|
840 |
+
)
|
841 |
+
|
842 |
+
logits[indices_to_remove] = -float("inf")
|
843 |
+
|
844 |
+
# Perform multinomial sampling after normalizing logits
|
845 |
+
probs = (
|
846 |
+
F.softmax(logits / temperature, dim=-1)
|
847 |
+
if temperature > 0
|
848 |
+
else logits.softmax(dim=-1)
|
849 |
+
)
|
850 |
+
token = (
|
851 |
+
probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
|
852 |
+
if sample
|
853 |
+
else logits.argmax(-1)
|
854 |
+
)
|
855 |
+
|
856 |
+
if return_probs:
|
857 |
+
token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
|
858 |
+
return token, token_probs
|
859 |
+
else:
|
860 |
+
return token
|
861 |
+
|
862 |
+
|
863 |
+
|
864 |
+
def mask_by_random_topk(
|
865 |
+
num_to_mask: int,
|
866 |
+
probs: torch.Tensor,
|
867 |
+
temperature: float = 1.0,
|
868 |
+
):
|
869 |
+
"""
|
870 |
+
Args:
|
871 |
+
num_to_mask (int): number of tokens to mask
|
872 |
+
probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
|
873 |
+
temperature (float, optional): temperature. Defaults to 1.0.
|
874 |
+
"""
|
875 |
+
logging.debug(f"masking by random topk")
|
876 |
+
logging.debug(f"num to mask: {num_to_mask}")
|
877 |
+
logging.debug(f"probs shape: {probs.shape}")
|
878 |
+
logging.debug(f"temperature: {temperature}")
|
879 |
+
logging.debug("")
|
880 |
+
|
881 |
+
noise = gumbel_noise_like(probs)
|
882 |
+
temperature = temperature.unsqueeze(-1)
|
883 |
+
confidence = torch.log(probs) + temperature * noise
|
884 |
+
logging.debug(f"confidence shape: {confidence.shape}")
|
885 |
+
|
886 |
+
sorted_confidence, sorted_idx = confidence.sort(dim=-1)
|
887 |
+
logging.debug(f"sorted confidence shape: {sorted_confidence.shape}")
|
888 |
+
logging.debug(f"sorted idx shape: {sorted_idx.shape}")
|
889 |
+
|
890 |
+
# get the cut off threshold, given the mask length
|
891 |
+
cut_off = torch.take_along_dim(
|
892 |
+
sorted_confidence, num_to_mask, axis=-1
|
893 |
+
)
|
894 |
+
logging.debug(f"cut off shape: {cut_off.shape}")
|
895 |
+
|
896 |
+
# mask out the tokens
|
897 |
+
mask = confidence < cut_off
|
898 |
+
logging.debug(f"mask shape: {mask.shape}")
|
899 |
+
|
900 |
+
return mask
|
901 |
+
|
902 |
+
def typical_filter(
|
903 |
+
logits,
|
904 |
+
typical_mass: float = 0.95,
|
905 |
+
typical_min_tokens: int = 1,):
|
906 |
+
nb, nt, _ = logits.shape
|
907 |
+
x_flat = rearrange(logits, "b t l -> (b t ) l")
|
908 |
+
x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
|
909 |
+
x_flat_norm_p = torch.exp(x_flat_norm)
|
910 |
+
entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
|
911 |
+
|
912 |
+
c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
|
913 |
+
c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
|
914 |
+
x_flat_cumsum = (
|
915 |
+
x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
|
916 |
+
)
|
917 |
+
|
918 |
+
last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
|
919 |
+
sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
|
920 |
+
1, last_ind.view(-1, 1)
|
921 |
+
)
|
922 |
+
if typical_min_tokens > 1:
|
923 |
+
sorted_indices_to_remove[..., :typical_min_tokens] = 0
|
924 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
925 |
+
1, x_flat_indices, sorted_indices_to_remove
|
926 |
+
)
|
927 |
+
x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
|
928 |
+
logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
|
929 |
+
return logits
|
930 |
+
|
931 |
+
|
932 |
+
if __name__ == "__main__":
|
933 |
+
# import argbind
|
934 |
+
from .layers import num_params
|
935 |
+
|
936 |
+
VampNet = argbind.bind(VampNet)
|
937 |
+
|
938 |
+
@argbind.bind(without_prefix=True)
|
939 |
+
def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0):
|
940 |
+
seq_len = int(32000 / 512 * seq_len_s)
|
941 |
+
|
942 |
+
model = VampNet().to(device)
|
943 |
+
|
944 |
+
z = torch.randint(
|
945 |
+
0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len)
|
946 |
+
).to(device)
|
947 |
+
|
948 |
+
r = torch.zeros(batch_size).to(device)
|
949 |
+
|
950 |
+
z_mask_latent = torch.rand(
|
951 |
+
batch_size, model.latent_dim * model.n_codebooks, seq_len
|
952 |
+
).to(device)
|
953 |
+
z_hat = model(z_mask_latent)
|
954 |
+
|
955 |
+
pred = z_hat.argmax(dim=1)
|
956 |
+
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
|
957 |
+
|
958 |
+
logging.debug(f"model has {num_params(model)/1e6:<.3f}M parameters")
|
959 |
+
logging.debug(f"prediction has shape {pred.shape}")
|
960 |
+
|
961 |
+
args = argbind.parse_args()
|
962 |
+
with argbind.scope(args):
|
963 |
+
try_model()
|
964 |
+
|
965 |
+
|
vampnet/scheduler.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
class NoamScheduler:
|
7 |
+
"""OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
|
8 |
+
Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
optimizer: torch.optim.Optimizer,
|
14 |
+
d_model: int = 512,
|
15 |
+
factor: float = 1.0,
|
16 |
+
warmup: int = 4000,
|
17 |
+
):
|
18 |
+
# Store hparams
|
19 |
+
self.warmup = warmup
|
20 |
+
self.factor = factor
|
21 |
+
self.d_model = d_model
|
22 |
+
|
23 |
+
# Initialize variables `lr` and `steps`
|
24 |
+
self.lr = None
|
25 |
+
self.steps = 0
|
26 |
+
|
27 |
+
# Store the optimizer
|
28 |
+
self.optimizer = optimizer
|
29 |
+
|
30 |
+
def state_dict(self):
|
31 |
+
return {
|
32 |
+
key: value for key, value in self.__dict__.items() if key != "optimizer"
|
33 |
+
}
|
34 |
+
|
35 |
+
def load_state_dict(self, state_dict):
|
36 |
+
self.__dict__.update(state_dict)
|
37 |
+
|
38 |
+
def step(self):
|
39 |
+
self.steps += 1
|
40 |
+
self.lr = self.factor * (
|
41 |
+
self.d_model ** (-0.5)
|
42 |
+
* min(self.steps ** (-0.5), self.steps * self.warmup ** (-1.5))
|
43 |
+
)
|
44 |
+
|
45 |
+
for p in self.optimizer.param_groups:
|
46 |
+
p["lr"] = self.lr
|
47 |
+
|
vampnet/util.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tqdm
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
def scalar_to_batch_tensor(x, batch_size):
|
7 |
+
return torch.tensor(x).repeat(batch_size)
|
8 |
+
|
9 |
+
|
10 |
+
def parallelize(
|
11 |
+
fn,
|
12 |
+
*iterables,
|
13 |
+
parallel: str = "thread_map",
|
14 |
+
**kwargs
|
15 |
+
):
|
16 |
+
if parallel == "thread_map":
|
17 |
+
from tqdm.contrib.concurrent import thread_map
|
18 |
+
return thread_map(
|
19 |
+
fn,
|
20 |
+
*iterables,
|
21 |
+
**kwargs
|
22 |
+
)
|
23 |
+
elif parallel == "process_map":
|
24 |
+
from tqdm.contrib.concurrent import process_map
|
25 |
+
return process_map(
|
26 |
+
fn,
|
27 |
+
*iterables,
|
28 |
+
**kwargs
|
29 |
+
)
|
30 |
+
elif parallel == "single":
|
31 |
+
return [fn(x) for x in tqdm.tqdm(*iterables)]
|
32 |
+
else:
|
33 |
+
raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
|
34 |
+
|
35 |
+
def codebook_flatten(tokens: torch.Tensor):
|
36 |
+
"""
|
37 |
+
flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
|
38 |
+
"""
|
39 |
+
return rearrange(tokens, "b c t -> b (t c)")
|
40 |
+
|
41 |
+
def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
|
42 |
+
"""
|
43 |
+
unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
|
44 |
+
"""
|
45 |
+
tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
|
46 |
+
return tokens
|