jukwi-vqvae / app.py
vovahimself's picture
sys
156d0fd
raw
history blame
3.29 kB
# A simple gradio app that converts music tokens to and from audio using JukeboxVQVAE as the model and Gradio as the UI
import os
import sys
import tempfile

import torch as t
import gradio as gr
from transformers import JukeboxVQVAE
# Hugging Face repo the VQ-VAE weights are loaded from (`#@param` lines are Colab form fields).
model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']
if 'google.colab' in sys.modules:
    # On Colab, keep the download cache on the mounted Google Drive.
    cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:"string"}
    # Connect to your Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
else:
    # Expand "~" here: cache_path is handed on as a plain string, and filesystem
    # calls do not expand "~" themselves, so a literal "~" directory could be created.
    cache_path = os.path.expanduser('~/.cache/')
class Convert:
    """Converters between token lists, ``.jt`` token files and audio tensors.

    The methods use the module-level ``model`` (loaded by ``init()``) and check
    token lists with ``validate_tokens_list`` (3 tensors, one per level).
    """

    class TokenList:
        """Conversions from an in-memory list of per-level token tensors."""

        @staticmethod
        def to_tokens_file(tokens_list):
            """Save a validated ``tokens_list`` to a new ``.jt`` file and return its path."""
            tokens_list = validate_tokens_list(tokens_list)
            # NamedTemporaryFile picks a unique name in the system temp directory.
            # delete=False: the file must outlive this call so it can be downloaded.
            with tempfile.NamedTemporaryFile(suffix=".jt", delete=False) as tokens_file:
                t.save(tokens_list, tokens_file)
            return tokens_file.name

        @staticmethod
        def to_audio(tokens_list):
            """Decode the level-2 tokens of a validated ``tokens_list`` into audio."""
            # TODO: Implement converting other levels besides 2
            # `[2:]` keeps just the level-2 tensor, matching start_level=2; the last
            # axis of the decoded output (presumably channels) is squeezed away.
            return model.decode(validate_tokens_list(tokens_list)[2:], start_level=2).squeeze(-1)

    class TokensFile:
        """Conversions from a saved ``.jt`` tokens file (a path or a binary file object)."""

        @staticmethod
        def to_tokens_list(file):
            """Load the token list stored in ``file`` and validate it."""
            # The file is user-supplied and torch.load unpickles it: weights_only=True
            # restricts loading to tensors and plain containers, not arbitrary objects.
            return validate_tokens_list(t.load(file, weights_only=True))

        @staticmethod
        def to_audio(file):
            """Decode the token list stored in ``file`` into audio."""
            return Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))

    class Audio:
        """Conversions from an audio tensor."""

        @staticmethod
        def to_tokens_list(audio):
            """Encode ``audio`` into token tensors for all 3 levels."""
            # Encode every level (start_level defaults to 0) so the result has the
            # 3-tensor layout validate_tokens_list requires; encoding with
            # start_level=2 yields only the level-2 tokens, which it rejects.
            # NOTE(review): `audio` is presumably (samples, channels); unsqueeze(0)
            # adds the batch axis — confirm against the model's expected input.
            return model.encode(audio.unsqueeze(0))

        @staticmethod
        def to_tokens_file(audio):
            """Encode ``audio`` and save all 3 token levels to a new ``.jt`` file."""
            return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))
def init():
    """Load the Jukebox VQ-VAE into the module-level ``model`` global, once.

    If ``model`` already exists, only a notice is printed and nothing is reloaded.
    """
    global model
    if 'model' in globals():
        print("Model already initialized")
        return
    # NOTE(review): weights are requested in float16 — confirm the target device
    # supports half-precision inference for both encode() and decode().
    model = JukeboxVQVAE.from_pretrained(
        model_id,
        torch_dtype=t.float16,
        cache_dir=f"{cache_path}/jukebox/models",
    )
class InvalidTokensListError(ValueError, AssertionError):
    """Raised when a music-tokens list does not have the expected layout.

    Also subclasses AssertionError so callers written against the earlier
    assert-based checks keep catching it.
    """


def validate_tokens_list(tokens_list):
    """Check that ``tokens_list`` has the 3-level music-token layout; return it unchanged.

    Expected layout: a list (or tuple) of exactly 3 torch tensors that
    - all have the same number of dimensions, at least 2;
    - share the same size along dimension 0;
    - have a size along dimension 1 that decreases (or stays the same)
      going from index 0 to index 2.

    Raises:
        InvalidTokensListError: if any check fails. Checks use explicit raises,
            not ``assert``, so they still run under ``python -O``.
    """
    if (
        not isinstance(tokens_list, (list, tuple))
        or len(tokens_list) != 3
        or not all(isinstance(tokens, t.Tensor) for tokens in tokens_list)
    ):
        raise InvalidTokensListError("Invalid file format: expecting a list of 3 tensors")
    first, second, third = tokens_list
    if not first.ndim == second.ndim == third.ndim:
        raise InvalidTokensListError("Invalid file format: each tensor in the list should have the same number of dimensions")
    # Dimensions 0 and 1 are compared below, so both must exist.
    if first.ndim < 2:
        raise InvalidTokensListError("Invalid file format: each tensor in the list should have at least 2 dimensions")
    if not first.shape[0] == second.shape[0] == third.shape[0]:
        raise InvalidTokensListError("Invalid file format: the shape along dimension 0 should be the same for all tensors in the list")
    if not first.shape[1] >= second.shape[1] >= third.shape[1]:
        raise InvalidTokensListError("Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2")
    return tokens_list
# Sampling rate of the audio the Jukebox VQ-VAE works with.
# NOTE(review): 44.1 kHz is assumed from the Jukebox release — confirm against the model config.
JUKEBOX_SAMPLE_RATE = 44100


def _gradio_tokens_file_to_audio(file):
    """Gradio callback: decode an uploaded tokens file into a playable ``(rate, samples)`` tuple.

    ``gr.Audio`` (default ``type="numpy"``) displays ``(sample_rate, numpy_array)``
    tuples rather than torch tensors, so the decoded tensor is converted here.
    """
    if file is None:
        raise gr.Error("Please upload a music tokens file first")
    # Inference only: without no_grad the output carries autograd history and
    # cannot be converted to a NumPy array.
    with t.no_grad():
        waveform = Convert.TokensFile.to_audio(file)
    # Presumably (batch, samples) after the last axis is squeezed — play the first
    # batch item. TODO confirm for token files holding several items.
    return JUKEBOX_SAMPLE_RATE, waveform[0].float().cpu().numpy()


def _gradio_audio_to_tokens_file(audio):
    """Gradio callback: encode ``gr.Audio``'s ``(sample_rate, samples)`` tuple into a tokens file."""
    if audio is None:
        raise gr.Error("Please upload or record some audio first")
    sample_rate, data = audio
    if sample_rate != JUKEBOX_SAMPLE_RATE:
        # No resampler is among this app's dependencies, so reject other rates.
        raise gr.Error(f"Expected {JUKEBOX_SAMPLE_RATE} Hz audio, got {sample_rate} Hz")
    waveform = t.from_numpy(data)
    if not waveform.is_floating_point():
        # Integer PCM (e.g. int16) -> floats scaled to roughly [-1, 1].
        waveform = waveform.float() / t.iinfo(waveform.dtype).max
    if waveform.ndim == 2:
        waveform = waveform.mean(dim=1)  # (samples, channels) -> mono (samples,)
    with t.no_grad():
        # Convert.Audio adds the batch axis itself; pass (samples, 1).
        return Convert.Audio.to_tokens_file(waveform.unsqueeze(-1))


with gr.Blocks() as ui:
    # File input to upload or download the music tokens file
    tokens = gr.File(label='music_tokens_file')
    # Audio output to play or upload the generated audio
    audio = gr.Audio(label='audio')
    # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)
    gr.Button("Convert tokens to audio", variant='primary').click(_gradio_tokens_file_to_audio, tokens, audio)
    gr.Button("Convert audio to tokens", variant='secondary').click(_gradio_audio_to_tokens_file, audio, tokens)
def main():
    """Load the VQ-VAE model, then launch the Gradio UI."""
    init()
    ui.launch()


if __name__ == '__main__':
    main()