Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import random | |
from typing import List | |
import tempfile | |
import subprocess | |
import argbind | |
from tqdm import tqdm | |
import torch | |
from vampnet.interface import Interface | |
from vampnet import mask as pmask | |
import audiotools as at | |
# Rebind the imported Interface class through argbind so that `Interface()` in
# `main` can take its constructor arguments from the argbind scope opened in
# the `__main__` block below.
Interface: Interface = argbind.bind(Interface)
def calculate_bitrate(
    interface, num_codebooks,
    downsample_factor,
    *,
    bit_width=10,
):
    """Return the token bitrate (bits per second) of a codec configuration.

    The codec produces ``sample_rate / hop_size`` token frames per second.
    Each frame carries ``num_codebooks`` tokens of ``bit_width`` bits, and
    keeping only one frame in every ``downsample_factor`` divides the rate
    by that factor.

    Args:
        interface: object exposing ``codec.sample_rate`` and ``codec.hop_size``.
        num_codebooks: number of codebooks whose tokens are counted.
        downsample_factor: temporal downsampling applied to the token frames.
        bit_width: bits per token (keyword-only). Defaults to 10, the value
            previously hard-coded here (presumably a 1024-entry codebook —
            confirm against the codec config).

    Returns:
        The bitrate as a float.
    """
    sr = interface.codec.sample_rate
    hop = interface.codec.hop_size
    frames_per_second = sr / hop
    bits_per_frame = bit_width * num_codebooks
    return frames_per_second * (bits_per_frame / downsample_factor)
def baseline(sig, interface):
    """Return ``sig`` after only the interface's preprocessing (no codec round-trip)."""
    preprocessed = interface.preprocess(sig)
    return preprocessed
def reconstructed(sig, interface):
    """Round-trip ``sig`` through the codec: encode it, then decode the result."""
    tokens = interface.encode(sig)
    return interface.decode(tokens)
def coarse2fine(sig, interface):
    """Keep only the c2f conditioning codebooks and run coarse-to-fine on them.

    The signal is encoded, every codebook past
    ``interface.c2f.n_conditioning_codebooks`` is dropped, and the truncated
    codes go through ``interface.coarse_to_fine`` (presumably regenerating
    the dropped fine codebooks) before decoding.
    """
    codes = interface.encode(sig)
    n_cond = interface.c2f.n_conditioning_codebooks
    coarse_codes = codes[:, :n_cond, :]
    refined = interface.coarse_to_fine(coarse_codes)
    return interface.decode(refined)
class CoarseCond:
    """Coarse-vamp condition built from a codebook unmask plus a periodic mask.

    Calling an instance encodes the signal and starts from ``pmask.full_mask``.
    It then applies ``pmask.codebook_unmask`` with ``num_conditioning_codebooks``
    and ``pmask.periodic_mask`` with ``downsample_factor``. The coarse model
    vamps under that mask, ``coarse_to_fine`` runs on the result, and it is
    decoded.

    NOTE(review): the outcome depends on whether ``pmask.periodic_mask``
    combines with the mask it is given or builds a fresh one from its shape;
    in the latter case the codebook unmask is overwritten — confirm against
    ``vampnet.mask``.
    """

    def __init__(self, num_conditioning_codebooks, downsample_factor):
        # Forwarded to pmask.codebook_unmask.
        self.num_conditioning_codebooks = num_conditioning_codebooks
        # Forwarded to pmask.periodic_mask as its period.
        self.downsample_factor = downsample_factor

    def __call__(self, sig, interface):
        codes = interface.encode(sig)

        token_mask = pmask.full_mask(codes)
        token_mask = pmask.codebook_unmask(token_mask, self.num_conditioning_codebooks)
        token_mask = pmask.periodic_mask(token_mask, self.downsample_factor)

        vamped = interface.coarse_vamp(codes, token_mask)
        refined = interface.coarse_to_fine(vamped)
        return interface.decode(refined)
def opus(sig, interface, bitrate=128):
    """Round-trip ``sig`` through an Opus encode/decode using ffmpeg.

    The signal is preprocessed by the interface and written to a temporary
    WAV file. It is encoded with ``libopus`` at ``bitrate``, then decoded
    back to WAV at the preprocessed signal's sample rate.

    Args:
        sig: audio signal to compress.
        interface: object providing ``preprocess``.
        bitrate: value passed verbatim to ffmpeg's ``-b:a`` option.
            NOTE(review): ffmpeg reads a bare number as bits/s, so the
            default of 128 requests 128 bps, which looks far too low for
            libopus; confirm whether kbps (``"128k"``) was intended.

    Returns:
        A new ``at.AudioSignal`` loaded from the decoded file.

    Raises:
        subprocess.CalledProcessError: if either ffmpeg invocation fails.
    """
    sig = interface.preprocess(sig)
    # Keep the source WAV, the Opus file and the decoded WAV in one private
    # directory so all of them are removed on exit. Previously the .opus and
    # decoded .wav were left behind in the system temp dir on every call,
    # and the source WAV was rewritten while its handle was still open.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)
        input_name = tmp_dir / "input.wav"
        opus_name = tmp_dir / "input.opus"
        output_name = tmp_dir / "output.wav"

        sig.write(input_name)

        # convert to opus
        cmd = [
            "ffmpeg", "-y", "-i", str(input_name),
            "-c:a", "libopus",
            "-b:a", f"{bitrate}",
            str(opus_name),
        ]
        subprocess.run(cmd, check=True)

        # Convert back to wav. ffmpeg decodes Opus at 48 kHz by default, and
        # loading from a path does not appear to resample to the
        # `sample_rate=` argument (audiotools behaviour — confirm). So resample
        # here to keep the output at the input's rate.
        cmd = [
            "ffmpeg", "-y", "-i", str(opus_name),
            "-ar", str(sig.sample_rate),
            str(output_name),
        ]
        subprocess.run(cmd, check=True)

        # Load before the temporary directory is deleted.
        return at.AudioSignal(
            output_name,
            sample_rate=sig.sample_rate,
        )
def mask_ratio_1_step(ratio=1.0):
    """Build a condition that re-samples ``pmask.linear_random`` tokens in one step.

    Args:
        ratio: value forwarded to ``pmask.linear_random``.

    Returns:
        A ``(sig, interface)`` callable. It encodes ``sig``, runs
        ``coarse_vamp`` with ``sampling_steps=1`` and decodes the result
        directly, with no ``coarse_to_fine`` pass.
    """
    def wrapper(sig, interface):
        codes = interface.encode(sig)
        token_mask = pmask.linear_random(codes, ratio)
        vamped = interface.coarse_vamp(codes, token_mask, sampling_steps=1)
        return interface.decode(vamped)

    return wrapper
def num_sampling_steps(num_steps=1):
    """Build a condition that vamps under a period-16 mask with ``num_steps`` steps.

    Args:
        num_steps: forwarded to ``coarse_vamp`` as ``sampling_steps``.

    Returns:
        A ``(sig, interface)`` callable. It encodes ``sig``, masks it with
        ``pmask.periodic_mask(codes, 16)``, vamps, applies ``coarse_to_fine``
        and decodes.
    """
    def wrapper(sig, interface: Interface):
        codes = interface.encode(sig)
        token_mask = pmask.periodic_mask(codes, 16)
        coarse = interface.coarse_vamp(codes, token_mask, sampling_steps=num_steps)
        refined = interface.coarse_to_fine(coarse)
        return interface.decode(refined)

    return wrapper
def beat_mask(ctx_time):
    """Build a condition whose coarse-vamp mask comes from the interface's beat tracker.

    Args:
        ctx_time: seconds of context per beat; half goes before and half
            after each beat when calling ``interface.make_beat_mask``
            (with ``invert=True``).

    Returns:
        A ``(sig, interface)`` callable that builds the beat mask from ``sig``,
        encodes ``sig``, vamps under the mask, applies ``coarse_to_fine`` and
        decodes.
    """
    def wrapper(sig, interface):
        half_window = ctx_time / 2
        token_mask = interface.make_beat_mask(
            sig,
            before_beat_s=half_window,
            after_beat_s=half_window,
            invert=True,
        )
        codes = interface.encode(sig)
        vamped = interface.coarse_vamp(codes, token_mask)
        refined = interface.coarse_to_fine(vamped)
        return interface.decode(refined)

    return wrapper
def inpaint(ctx_time):
    """Build a condition that vamps under a ``pmask.inpaint`` mask.

    Args:
        ctx_time: seconds converted via ``interface.s2t`` and passed as both
            the prefix and suffix amounts of ``pmask.inpaint`` (presumably
            context kept at each end of the excerpt — confirm).

    Returns:
        A ``(sig, interface)`` callable that encodes ``sig``, vamps under the
        inpaint mask, applies ``coarse_to_fine`` and decodes.
    """
    def wrapper(sig, interface: Interface):
        codes = interface.encode(sig)
        n_prefix = interface.s2t(ctx_time)
        n_suffix = interface.s2t(ctx_time)
        token_mask = pmask.inpaint(codes, n_prefix, n_suffix)
        vamped = interface.coarse_vamp(codes, token_mask)
        refined = interface.coarse_to_fine(vamped)
        return interface.decode(refined)

    return wrapper
def token_noise(noise_amt):
    """Build a condition that replaces a random subset of tokens with random tokens.

    Args:
        noise_amt: forwarded to ``pmask.random`` to choose which tokens are
            replaced.

    Returns:
        A ``(sig, interface)`` callable. It encodes ``sig`` and swaps the
        selected tokens for integers drawn from
        ``[0, interface.coarse.vocab_size)``. The result is decoded directly,
        without any vamping.
    """
    def wrapper(sig, interface: Interface):
        z = interface.encode(sig)
        mask = pmask.random(z, noise_amt)
        # torch.where requires a boolean condition. `.bool()` accepts both
        # bool masks and 0/1 integer masks; it is a no-op for bool masks.
        # NOTE(review): pmask.random may return an integer mask — confirm.
        z = torch.where(
            mask.bool(),
            torch.randint_like(z, 0, interface.coarse.vocab_size),
            z,
        )
        return interface.decode(z)

    return wrapper
# Experiment registry: maps an `exp_type` (selected in `main`) to a dict of
# condition name -> callable `(sig, interface)` returning an audio signal.
# `main` writes each condition's output to a sub-directory named after it.
EXP_REGISTRY = {}

EXP_REGISTRY["gen-compression"] = {
    "baseline": baseline,
    "reconstructed": reconstructed,
    "coarse2fine": coarse2fine,
    **{
        f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x)
        for (n, x) in (
            (1, 1),  # 1 codebook, no downsampling
            (4, 4),  # 4 codebooks, downsampled 4x
            (4, 16),  # 4 codebooks, downsampled 16x
            (4, 32),  # 4 codebooks, downsampled 32x
        )
    },
    **{
        # NOTE(review): these entries are named "token_noise" but use
        # mask_ratio_1_step rather than token_noise() — confirm intent.
        f"token_noise_{x}": mask_ratio_1_step(ratio=x)
        for x in [0.25, 0.5, 0.75]
    },
}

EXP_REGISTRY["sampling-steps"] = {
    # "codec": reconstructed,
    **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
}

EXP_REGISTRY["musical-sampling"] = {
    **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
    # Inpaint context is applied on both the left and the right, so the
    # total context is twice each value listed here.
    **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0]},
}
def main(
    sources=[
        "/media/CHONK/hugo/spotdl/val",
    ],
    output_dir: str = "./samples",
    max_excerpts: int = 2000,
    exp_type: str = "gen-compression",
    seed: int = 0,
    ext: List[str] = [".mp3"],
):
    """Render every condition of one experiment for a set of dataset excerpts.

    Excerpts are visited in shuffled order. For each of the ``max_excerpts``
    excerpts, every condition in ``EXP_REGISTRY[exp_type]`` is applied and
    its output is written to ``<output_dir>/<condition name>/<index>.wav``.
    An excerpt is skipped when the outputs for all conditions already exist,
    so an interrupted run can be resumed.

    Args:
        sources: source locations passed to ``AudioLoader`` (not mutated here).
        output_dir: root directory for the rendered samples.
        max_excerpts: number of excerpts drawn from the dataset.
        exp_type: key into ``EXP_REGISTRY``.
        seed: passed to ``at.util.seed`` and used as the loader's shuffle state.
        ext: list of file extensions passed to ``AudioLoader``.

    Raises:
        ValueError: if ``exp_type`` is not a key of ``EXP_REGISTRY``.
    """
    at.util.seed(seed)

    # Validate before loading models or scanning the dataset, so a typo fails
    # immediately instead of after the expensive setup.
    if exp_type not in EXP_REGISTRY:
        raise ValueError(f"Unknown exp_type {exp_type}")
    sample_conds = EXP_REGISTRY[exp_type]

    interface = Interface()

    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    from audiotools.data.datasets import AudioLoader, AudioDataset
    loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
    dataset = AudioDataset(
        loader,
        sample_rate=interface.codec.sample_rate,
        duration=interface.coarse.chunk_size_s,
        n_examples=max_excerpts,
        without_replacement=True,
    )

    indices = list(range(max_excerpts))
    random.shuffle(indices)
    for i in tqdm(indices):
        # Skip excerpts whose outputs already exist for every condition.
        if all((output_dir / name / f"{i}.wav").exists() for name in sample_conds):
            continue

        sig = dataset[i]["signal"]
        results = {
            name: cond(sig, interface).cpu()
            for name, cond in sample_conds.items()
        }

        for name, out_sig in results.items():
            o_dir = output_dir / name
            o_dir.mkdir(exist_ok=True, parents=True)
            out_sig.write(o_dir / f"{i}.wav")
if __name__ == "__main__":
    # Parse command-line arguments for the argbind-bound objects (Interface),
    # then run main() inside that argument scope.
    cli_args = argbind.parse_args()
    with argbind.scope(cli_args):
        main()