Spaces:
Sleeping
Sleeping
import spaces | |
from pathlib import Path | |
import yaml | |
import time | |
import uuid | |
import numpy as np | |
import audiotools as at | |
import argbind | |
import shutil | |
import torch | |
from datetime import datetime | |
import gradio as gr | |
from vampnet.interface import Interface, signal_concat | |
from vampnet import mask as pmask | |
# NOTE(review): `device` is not referenced elsewhere in the visible file;
# _vamp re-checks CUDA availability on every call instead.
device = "cuda" if torch.cuda.is_available() else "cpu"
interface = Interface.default()

# Checkpoint entries an interface.yml must provide (as existing paths) for its
# model to be offered in the "model choice" dropdown.
_REQUIRED_CKPT_KEYS = (
    "Interface.coarse_ckpt",
    "Interface.coarse2fine_ckpt",
    "Interface.codec_ckpt",
)

# populate the model choices with any interface.yml files in the generated confs
MODEL_CHOICES = {
    "default": {
        "Interface.coarse_ckpt": str(interface.coarse_path),
        "Interface.coarse2fine_ckpt": str(interface.c2f_path),
        "Interface.codec_ckpt": str(interface.codec_path),
    }
}

generated_confs = Path("conf/generated")
# sorted() so the dropdown order does not depend on filesystem listing order
for conf_file in sorted(generated_confs.glob("*/interface.yml")):
    with open(conf_file, encoding="utf-8") as f:
        _conf = yaml.safe_load(f)
    # An empty file loads as None; skip anything that is not a mapping rather
    # than crashing the whole app at startup.
    if not isinstance(_conf, dict):
        continue
    # check that the coarse, c2f, and codec ckpts are all given and exist;
    # otherwise, don't add this model choice
    _ckpt_paths = [_conf.get(key) for key in _REQUIRED_CKPT_KEYS]
    if not all(isinstance(p, str) and Path(p).exists() for p in _ckpt_paths):
        continue
    MODEL_CHOICES[conf_file.parent.name] = _conf
def to_output(sig):
    """Turn a signal into the ``(sample_rate, samples)`` pair used by gr.Audio.

    Only the first batch item's first channel is kept (``numpy()[0][0]``).
    """
    rate = sig.sample_rate
    batch = sig.cpu().detach().numpy()
    first_item = batch[0]
    return rate, first_item[0]
# Suggested excerpt length in seconds for callers that pass ``max_duration_s``
# to load_audio; it is not applied by default.
MAX_DURATION_S = 5


def load_audio(file, max_duration_s=None):
    """Load audio from a path, an upload object, or an in-memory gr.Audio value.

    Args:
        file: one of
            * ``tuple`` -- a ``(sample_rate, samples)`` pair (e.g. the value of a
              ``gr.Audio(type="numpy")`` component). Integer PCM samples are
              scaled by the dtype's maximum value; float samples are returned
              unchanged.
            * ``str`` -- a path to an audio file.
            * anything else -- an object whose ``.name`` is the file path
              (e.g. a Gradio upload temp file).
        max_duration_s: when given, load only a salient excerpt of this many
            seconds via ``AudioSignal.salient_excerpt``. The default ``None``
            loads the whole file.

    Returns:
        ``(sample_rate, samples)`` for a gr.Audio output.
    """
    print(file)
    if isinstance(file, tuple):
        # not a file: the audio is already in memory
        sr, samples = file
        # np.iinfo only accepts integer dtypes; float audio is assumed to be
        # normalized already, so it passes through untouched.
        if np.issubdtype(samples.dtype, np.integer):
            samples = samples / np.iinfo(samples.dtype).max
        return sr, samples

    filepath = file if isinstance(file, str) else file.name
    if max_duration_s is None:
        sig = at.AudioSignal(filepath)
    else:
        sig = at.AudioSignal.salient_excerpt(filepath, duration=max_duration_s)
    return to_output(sig)
def load_example_audio():
    """Load the bundled example clip through ``load_audio``."""
    example_path = "./assets/example.wav"
    return load_audio(example_path)
from torch_pitch_shift import pitch_shift, get_fast_shifts | |
def shift_pitch(signal, interval: int):
    """Pitch-shift ``signal`` by ``interval`` using torch_pitch_shift.

    ``interval`` is forwarded as ``pitch_shift``'s ``shift`` argument (the UI
    labels it in semitones). The signal's ``samples`` are replaced in place, and
    the same object is returned.
    """
    shifted_samples = pitch_shift(
        signal.samples,
        shift=interval,
        sample_rate=signal.sample_rate,
    )
    signal.samples = shifted_samples
    return signal
def _vamp(
    seed, input_audio, model_choice,
    pitch_shift_amt, periodic_p,
    n_mask_codebooks, periodic_w, onset_mask_width,
    dropout, sampletemp, typical_filtering,
    typical_mass, typical_min_tokens, top_p,
    sample_cutoff, stretch_factor, api=False
):
    """Run one VampNet generation on ``input_audio``.

    Args:
        seed: RNG seed; ``None`` or a value <= 0 picks a fresh random seed.
        input_audio: ``(sample_rate, samples)`` tuple. Integer PCM is scaled by
            the dtype's maximum value; float samples are used as-is.
        model_choice: key into ``MODEL_CHOICES`` selecting the coarse / c2f
            checkpoints to (re)load.
        pitch_shift_amt: shift applied via ``shift_pitch`` before vamping
            (0 = no shift).
        periodic_p, periodic_w, onset_mask_width, dropout, n_mask_codebooks:
            forwarded to the mask builder in ``build_mask_kwargs``.
        sampletemp, typical_filtering, typical_mass, typical_min_tokens:
            forwarded to the sampler in ``vamp_kwargs``.
        top_p: nucleus-sampling threshold; ``None`` or <= 0 (the UI's "off")
            is sent to the sampler as ``None``.
        sample_cutoff: forwarded to the sampler; ``None`` falls back to 1.0.
        stretch_factor: forwarded to ``ez_vamp`` as ``time_stretch_factor``.
        api: whether the call came from the API endpoint; currently unused.

    Returns:
        ``(sample_rate, samples)`` produced by ``to_output`` for the generated signal.
    """
    t0 = time.time()
    # Resolve the device on each call instead of reusing an import-time value.
    interface.to("cuda" if torch.cuda.is_available() else "cpu")
    print(f"using device {interface.device}")

    # None (cleared number field) or <= 0 means "pick a random seed"
    _seed = int(seed) if seed is not None and seed > 0 else None
    if _seed is None:
        _seed = int(torch.randint(0, 2**32, (1,)).item())
    at.util.seed(_seed)

    sr, samples = input_audio
    # np.iinfo rejects float dtypes; only integer PCM needs scaling.
    if np.issubdtype(samples.dtype, np.integer):
        samples = samples / np.iinfo(samples.dtype).max
    sig = at.AudioSignal(samples, sr)

    # reload the model if necessary
    interface.reload(
        coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
        c2f_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
    )

    if pitch_shift_amt != 0:
        sig = shift_pitch(sig, pitch_shift_amt)

    build_mask_kwargs = dict(
        rand_mask_intensity=1.0,
        prefix_s=0.0,
        suffix_s=0.0,
        periodic_prompt=int(periodic_p),
        periodic_prompt_width=periodic_w,
        onset_mask_width=onset_mask_width,
        _dropout=dropout,
        upper_codebook_mask=int(n_mask_codebooks),
    )

    vamp_kwargs = dict(
        temperature=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        # The UI slider uses 0.0 for "off"; the sampler gets None in that case.
        top_p=top_p if top_p is not None and top_p > 0 else None,
        seed=_seed,
        sample_cutoff=1.0 if sample_cutoff is None else sample_cutoff,
    )

    # NOTE(review): 10.0 is presumably seconds per chunk -- confirm in Interface.set_chunk_size.
    interface.set_chunk_size(10.0)
    sig, _mask, _codes = interface.ez_vamp(
        sig,
        batch_size=1,
        feedback_steps=1,
        time_stretch_factor=stretch_factor,
        build_mask_kwargs=build_mask_kwargs,
        vamp_kwargs=vamp_kwargs,
        return_mask=True,
    )

    print(f"vamp took {time.time() - t0} seconds")
    return to_output(sig)
def _vamp_from_inputs(data, api):
    """Unpack a Gradio ``{component: value}`` input dict and forward it to ``_vamp``.

    The handlers below are registered with ``_inputs``, a *set* of components,
    so Gradio calls them with a single dict keyed by component object. The
    component names used as keys (``seed``, ``input_audio``, ...) are the
    module-level widgets built inside the ``gr.Blocks`` context; they are looked
    up when this function runs, not when it is defined.
    """
    return _vamp(
        seed=data[seed],
        input_audio=data[input_audio],
        model_choice=data[model_choice],
        pitch_shift_amt=data[pitch_shift_amt],
        periodic_p=data[periodic_p],
        n_mask_codebooks=data[n_mask_codebooks],
        periodic_w=data[periodic_w],
        onset_mask_width=data[onset_mask_width],
        dropout=data[dropout],
        sampletemp=data[sampletemp],
        typical_filtering=data[typical_filtering],
        typical_mass=data[typical_mass],
        typical_min_tokens=data[typical_min_tokens],
        top_p=data[top_p],
        sample_cutoff=data[sample_cutoff],
        stretch_factor=data[stretch_factor],
        api=api,
    )


def vamp(data):
    """Handler for the "generate (vamp)!!!" button (``api=False``)."""
    return _vamp_from_inputs(data, api=False)


def api_vamp(data):
    """Handler for the "api vamp" button / ``vamp`` API endpoint (``api=True``)."""
    return _vamp_from_inputs(data, api=True)
with gr.Blocks() as demo:
    with gr.Row():
        # input column: upload / example loader and the current input audio
        with gr.Column():
            # load_audio loads the whole uploaded file (no trimming), so the
            # label no longer promises a trim.
            manual_audio_upload = gr.File(
                label="upload some audio",
                file_types=["audio"]
            )
            load_example_audio_button = gr.Button("or load example audio")

            input_audio = gr.Audio(
                label="input audio",
                interactive=False,
                type="numpy",
            )

            # NOTE(review): not used as an output of any event in this file,
            # so this player stays empty.
            audio_mask = gr.Audio(
                label="audio mask (listen to this to hear the mask hints)",
                interactive=False,
                type="numpy",
            )

            # connect widgets
            load_example_audio_button.click(
                fn=load_example_audio,
                inputs=[],
                outputs=[input_audio]
            )

            manual_audio_upload.change(
                fn=load_audio,
                inputs=[manual_audio_upload],
                outputs=[input_audio]
            )

        # mask settings
        with gr.Column():
            with gr.Accordion("manual controls", open=True):
                periodic_p = gr.Slider(
                    label="periodic prompt",
                    minimum=0,
                    maximum=13,
                    step=1,
                    value=3,
                )

                onset_mask_width = gr.Slider(
                    label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                    visible=False,
                )

                n_mask_codebooks = gr.Slider(
                    label="compression prompt ",
                    value=3,
                    minimum=1,
                    maximum=14,
                    step=1,
                )

                # NOTE(review): not used as an output of any event in this file.
                maskimg = gr.Image(
                    label="mask image",
                    interactive=False,
                    type="filepath"
                )

            with gr.Accordion("extras ", open=False):
                pitch_shift_amt = gr.Slider(
                    label="pitch shift amount (semitones)",
                    minimum=-12,
                    maximum=12,
                    step=1,
                    value=0,
                )

                stretch_factor = gr.Slider(
                    label="time stretch factor",
                    minimum=0,
                    maximum=8,
                    step=1,
                    value=1,
                )

                periodic_w = gr.Slider(
                    label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
                    minimum=1,
                    maximum=20,
                    step=1,
                    value=1,
                )

            with gr.Accordion("sampling settings", open=False):
                sampletemp = gr.Slider(
                    label="sample temperature",
                    minimum=0.1,
                    maximum=10.0,
                    value=1.0,
                    step=0.001
                )

                top_p = gr.Slider(
                    label="top p (0.0 = off)",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0
                )

                typical_filtering = gr.Checkbox(
                    label="typical filtering ",
                    value=True
                )

                typical_mass = gr.Slider(
                    label="typical mass (should probably stay between 0.1 and 0.5)",
                    minimum=0.01,
                    maximum=0.99,
                    value=0.15
                )

                typical_min_tokens = gr.Slider(
                    label="typical min tokens (should probably stay between 1 and 256)",
                    minimum=1,
                    maximum=256,
                    step=1,
                    value=64
                )

                # maximum is 1.0 so the default value (1.0) lies inside the range
                sample_cutoff = gr.Slider(
                    label="sample cutoff",
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.01
                )

            dropout = gr.Slider(
                label="mask dropout",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.0
            )

            seed = gr.Number(
                label="seed (0 for random)",
                value=0,
                precision=0,
            )

        # model selection, generate button, and output players
        with gr.Column():
            model_choice = gr.Dropdown(
                label="model choice",
                choices=list(MODEL_CHOICES.keys()),
                value="default",
                visible=True
            )

            vamp_button = gr.Button("generate (vamp)!!!")

            audio_outs = []
            use_as_input_btns = []
            for i in range(1):
                with gr.Column():
                    audio_outs.append(gr.Audio(
                        label=f"output audio {i+1}",
                        interactive=False,
                        type="numpy"
                    ))
                    use_as_input_btns.append(
                        gr.Button("use as input (feedback)")
                    )

            thank_you = gr.Markdown("")

    # A *set* (not a list) on purpose: Gradio then calls the handler with one
    # {component: value} dict, which vamp/api_vamp index by component.
    _inputs = {
        input_audio,
        sampletemp,
        top_p,
        periodic_p, periodic_w,
        dropout,
        stretch_factor,
        onset_mask_width,
        typical_filtering,
        typical_mass,
        typical_min_tokens,
        seed,
        model_choice,
        n_mask_codebooks,
        pitch_shift_amt,
        sample_cutoff,
    }

    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
        outputs=[audio_outs[0]],
    )

    api_vamp_button = gr.Button("api vamp", visible=True)
    api_vamp_button.click(
        fn=api_vamp,
        inputs=_inputs,
        outputs=[audio_outs[0]],
        api_name="vamp"
    )

    # feed each output back in as the new input audio
    for i, btn in enumerate(use_as_input_btns):
        btn.click(
            fn=load_audio,
            inputs=[audio_outs[i]],
            outputs=[input_audio]
        )
# Start the app with request queuing enabled. Blocks.queue() returns the
# Blocks object itself, so launch() can be chained onto it.
try:
    demo.queue().launch(share=True)
except KeyboardInterrupt:
    # On Ctrl-C, remove the "gradio-outputs" directory (a missing directory is
    # ignored), then let the interrupt propagate.
    shutil.rmtree("gradio-outputs", ignore_errors=True)
    raise