# Hugging Face Spaces status: Runtime error
# A simple Gradio app that converts music tokens to and from audio, using JukeboxVQVAE as the model and Gradio as the UI.
import os
import sys
import tempfile

import gradio as gr
import torch as t
from transformers import JukeboxVQVAE
model_id = 'openai/jukebox-5b-lyrics'  #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']

# Choose where downloaded model files are cached: a folder on the mounted
# Google Drive when running inside Colab, the user's cache directory otherwise.
if 'google.colab' not in sys.modules:
    cache_path = '~/.cache/'
else:
    # Connect to your Google Drive
    from google.colab import drive
    drive.mount('/content/drive')
    cache_path = '/content/drive/My Drive/jukebox-webui/_data/'  #@param {type:"string"}
class Convert:
    """Converters between tokens lists, tokens files and audio.

    The nested classes only group static functions by their *input* type:
    ``Convert.<From>.to_<target>(value)``. A "tokens list" is the 3-tensor
    list accepted by ``validate_tokens_list``; a "tokens file" is such a list
    written with ``torch.save`` (``.jt`` suffix). Every function that runs the
    codec uses the module-level ``model`` created by ``init()``.
    """

    class TokenList:
        @staticmethod
        def to_tokens_file(tokens_list):
            """Validate *tokens_list*, save it to a fresh ``.jt`` temp file and return its path."""
            tokens_list = validate_tokens_list(tokens_list)
            # A unique path in a directory that is guaranteed to exist;
            # delete=False so the file outlives this call and can be downloaded.
            with tempfile.NamedTemporaryFile(suffix=".jt", delete=False) as f:
                t.save(tokens_list, f)
            return f.name

        @staticmethod
        def to_audio(tokens_list):
            """Decode *tokens_list* into a waveform tensor.

            Only the last tensor (index 2) is decoded, via ``start_level=2``;
            a trailing axis of size 1 on the decoder output is squeezed away.
            """
            tokens_list = validate_tokens_list(tokens_list)
            # TODO: Implement converting other levels besides 2
            with t.no_grad():  # inference only: keep the result free of an autograd graph
                return model.decode(tokens_list[2:], start_level=2).squeeze(-1)

    class TokensFile:
        @staticmethod
        def to_tokens_list(file):
            """Load a tokens file written with ``torch.save`` and return the validated list."""
            # NOTE(review): torch.load unpickles arbitrary objects and this file can
            # come from a user upload; consider `weights_only=True` (torch >= 1.13)
            # once the deployed torch version is confirmed.
            return validate_tokens_list(t.load(file))

        @staticmethod
        def to_audio(file):
            """Load a tokens file and decode it to a waveform tensor."""
            return Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))

    class Audio:
        @staticmethod
        def to_tokens_list(audio):
            """Encode a waveform tensor into a tokens list covering every level.

            A batch axis of size 1 is prepended before encoding. The code
            presumably expects *audio* to be ``(time, channels)``; TODO confirm
            this against ``JukeboxVQVAE.encode``.
            """
            # start_level=0 (not 2) so every level is returned. The result must have
            # the 3-tensor layout that validate_tokens_list, and therefore
            # to_tokens_file, require.
            with t.no_grad():
                return model.encode(audio.unsqueeze(0), start_level=0)

        @staticmethod
        def to_tokens_file(audio):
            """Encode a waveform tensor, save the tokens to a ``.jt`` temp file and return its path."""
            return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))
def init():
    """Load the Jukebox VQ-VAE into the module-level ``model``, once.

    Weights for ``model_id`` are loaded in float16 and cached under
    ``<cache_path>/jukebox/models``. A leading ``~`` in ``cache_path`` (the
    non-Colab default is ``'~/.cache/'``) is expanded here explicitly rather
    than relying on the loader to do it.

    Returns:
        The loaded model, or the existing one if it was already initialized.
    """
    global model
    if globals().get("model") is not None:
        print("Model already initialized")
        return model
    cache_dir = os.path.join(os.path.expanduser(cache_path), "jukebox", "models")
    model = JukeboxVQVAE.from_pretrained(
        model_id,
        # NOTE(review): some half-precision ops may be unsupported on CPU;
        # confirm this works on CPU-only deployments.
        torch_dtype=t.float16,
        cache_dir=cache_dir,
    )
    return model
def validate_tokens_list(tokens_list):
    """Check that *tokens_list* has the layout of a tokens file and return it.

    The expected layout is exactly 3 tensor-like objects (anything with a
    ``.shape``) such that:

    - all have the same number of dimensions, and at least 2;
    - their size along dimension 0 is the same;
    - their size along dimension 1 decreases (or stays the same) going from
      index 0 to index 2.

    Returns:
        *tokens_list* itself, unchanged, so calls can be chained.

    Raises:
        AssertionError: if any check fails. It is raised explicitly instead of
        via ``assert`` so validation still runs under ``python -O``. The
        exception type is kept for backward compatibility.
    """
    def _require(condition, message):
        # Explicit raise: unlike `assert`, this is never stripped by -O.
        if not condition:
            raise AssertionError(message)

    # - tokens_list is a list of exactly 3 tensors
    _require(len(tokens_list) == 3, "Invalid file format: expecting a list of 3 tensors")
    _require(
        all(hasattr(item, "shape") for item in tokens_list),
        "Invalid file format: every item in the list should be a tensor",
    )
    # - each has the same number of dimensions (at least 2, since dims 0 and 1 are checked below)
    _require(
        len(tokens_list[0].shape) == len(tokens_list[1].shape) == len(tokens_list[2].shape),
        "Invalid file format: each tensor in the list should have the same number of dimensions",
    )
    _require(
        len(tokens_list[0].shape) >= 2,
        "Invalid file format: each tensor in the list should have at least 2 dimensions",
    )
    # - the shape along dimension 0 is the same
    _require(
        tokens_list[0].shape[0] == tokens_list[1].shape[0] == tokens_list[2].shape[0],
        "Invalid file format: the shape along dimension 0 should be the same for all tensors in the list",
    )
    # - the shape along dimension 1 decreases (or stays the same) as we go from 0 to 2
    _require(
        tokens_list[0].shape[1] >= tokens_list[1].shape[1] >= tokens_list[2].shape[1],
        "Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2",
    )
    return tokens_list
# Sample rate used for decoded output and as the resampling target for input.
# NOTE(review): 44.1 kHz is the rate of the released Jukebox models; confirm
# it against the model config if another checkpoint is used.
JUKEBOX_SAMPLE_RATE = 44100


def _file_path(file):
    """Return a filesystem path for a ``gr.File`` value.

    Depending on the Gradio version, the component passes either a path string
    or a temp-file wrapper that exposes the path as ``.name``.
    """
    if file is None:
        raise ValueError("Please upload a music tokens file first.")
    return file if isinstance(file, str) else getattr(file, "name", file)


def _tokens_file_to_audio(file):
    """Gradio handler: tokens file -> ``(sample_rate, waveform)`` for ``gr.Audio``."""
    with t.no_grad():
        waveform = Convert.TokensFile.to_audio(_file_path(file))
    # gr.Audio cannot play a torch tensor; hand it a float32 numpy array in [-1, 1].
    waveform = waveform.detach().float().cpu()
    if waveform.dim() > 1:
        # Keep the first clip of the batch; this assumes one decoded clip per file.
        waveform = waveform[0]
    return JUKEBOX_SAMPLE_RATE, waveform.clamp(-1, 1).numpy()


def _audio_to_tokens_file(audio):
    """Gradio handler: ``(sample_rate, samples)`` from ``gr.Audio`` -> tokens file path."""
    if audio is None:
        raise ValueError("Please record or upload some audio first.")
    sample_rate, samples = audio
    waveform = t.as_tensor(samples)
    if not waveform.is_floating_point():
        # Integer PCM (e.g. int16) -> float roughly in [-1, 1].
        scale = t.iinfo(waveform.dtype).max
        waveform = waveform.float() / scale
    waveform = waveform.float()
    # (time,) or (time, channels) -> mono (time, 1), the layout Convert.Audio expects.
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(-1)
    else:
        waveform = waveform.mean(dim=-1, keepdim=True)
    if sample_rate != JUKEBOX_SAMPLE_RATE:
        # Linear resampling: crude, but it avoids a torchaudio dependency.
        target_len = max(1, round(waveform.shape[0] * JUKEBOX_SAMPLE_RATE / sample_rate))
        waveform = t.nn.functional.interpolate(
            waveform.T.unsqueeze(0), size=target_len, mode="linear", align_corners=False
        ).squeeze(0).T
    with t.no_grad():
        return Convert.Audio.to_tokens_file(waveform)


with gr.Blocks() as ui:
    # File input to upload or download the music tokens file
    tokens = gr.File(label='music_tokens_file')
    # Audio output to play or upload the generated audio
    audio = gr.Audio(label='audio')
    # Buttons to convert from music tokens to audio (primary) and vice versa (secondary).
    # The handlers adapt Gradio's value formats to the tensors Convert works with.
    gr.Button("Convert tokens to audio", variant='primary').click(_tokens_file_to_audio, tokens, audio)
    gr.Button("Convert audio to tokens", variant='secondary').click(_audio_to_tokens_file, audio, tokens)
if __name__ == "__main__":
    # Load the VQ-VAE into the global `model` first; the UI callbacks rely on it.
    init()
    ui.launch()