OpenSound committed on
Commit
d9a7330
1 Parent(s): b9d6819

Update api.py

Files changed (1)
  1. api.py +117 -117
api.py CHANGED
@@ -1,117 +1,117 @@
- import os
- import torch
- import random
- import numpy as np
- import gradio as gr
- import soundfile as sf
- from transformers import T5Tokenizer, T5EncoderModel
- from diffusers import DDIMScheduler
- from src.models.conditioners import MaskDiT
- from src.modules.autoencoder_wrapper import Autoencoder
- from src.inference import inference
- from src.utils import load_yaml_with_includes
-
-
- # Load model and configs
- def load_models(config_name, ckpt_path, vae_path, device):
-     params = load_yaml_with_includes(config_name)
-
-     # Load codec model
-     autoencoder = Autoencoder(ckpt_path=vae_path,
-                               model_type=params['autoencoder']['name'],
-                               quantization_first=params['autoencoder']['q_first']).to(device)
-     autoencoder.eval()
-
-     # Load text encoder
-     tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
-     text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device)
-     text_encoder.eval()
-
-     # Load main U-Net model
-     unet = MaskDiT(**params['model']).to(device)
-     unet.load_state_dict(torch.load(ckpt_path)['model'])
-     unet.eval()
-
-     # Load noise scheduler
-     noise_scheduler = DDIMScheduler(**params['diff'])
-
-     return autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params
-
- MAX_SEED = np.iinfo(np.int32).max
-
- # Model and config paths
- config_name = 'ckpts/ezaudio-xl.yml'
- ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt'
- vae_path = 'ckpts/vae/1m.pt'
- save_path = 'output/'
- os.makedirs(save_path, exist_ok=True)
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path,
-                                                                                   device)
-
- latents = torch.randn((1, 128, 128), device=device)
- noise = torch.randn_like(latents)
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device)
- _ = noise_scheduler.add_noise(latents, noise, timesteps)
-
-
- # Inference function
- def generate_audio(text, length,
-                    guidance_scale, guidance_rescale, ddim_steps, eta,
-                    random_seed, randomize_seed):
-     neg_text = None
-     length = length * params['autoencoder']['latent_sr']
-
-     if randomize_seed:
-         random_seed = random.randint(0, MAX_SEED)
-
-     pred = inference(autoencoder, unet, None, None,
-                      tokenizer, text_encoder,
-                      params, noise_scheduler,
-                      text, neg_text,
-                      length,
-                      guidance_scale, guidance_rescale,
-                      ddim_steps, eta, random_seed,
-                      device)
-
-     pred = pred.cpu().numpy().squeeze(0).squeeze(0)
-     # output_file = f"{save_path}/{text}.wav"
-     # sf.write(output_file, pred, samplerate=params['autoencoder']['sr'])
-
-     return params['autoencoder']['sr'], pred
-
-
- # Gradio Interface
- def gradio_interface():
-     # Input components
-     text_input = gr.Textbox(label="Text Prompt", value="the sound of dog barking")
-     length_input = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)")
-
-     # Advanced settings
-     guidance_scale_input = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5, label="Guidance Scale")
-     guidance_rescale_input = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
-     ddim_steps_input = gr.Slider(minimum=25, maximum=200, step=5, value=100, label="DDIM Steps")
-     eta_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Eta")
-     random_seed_input = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0,)
-
-     randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
-
-     # Output component
-     output_audio = gr.Audio(label="Converted Audio", type="numpy")
-
-     # Interface
-     gr.Interface(
-         fn=generate_audio,
-         inputs=[text_input, length_input, guidance_scale_input, guidance_rescale_input, ddim_steps_input, eta_input,
-                 random_seed_input, randomize_seed],
-         outputs=output_audio,
-         title="EzAudio Text-to-Audio Generator",
-         description="Generate audio from text using a diffusion model. Adjust advanced settings for more control.",
-         allow_flagging="never"
-     ).launch()
-
-
- if __name__ == "__main__":
-     gradio_interface()
 
+ import os
+ import torch
+ import random
+ import spaces
+ import numpy as np
+ import gradio as gr
+ import soundfile as sf
+ from transformers import T5Tokenizer, T5EncoderModel
+ from diffusers import DDIMScheduler
+ from src.models.conditioners import MaskDiT
+ from src.modules.autoencoder_wrapper import Autoencoder
+ from src.inference import inference
+ from src.utils import load_yaml_with_includes
+
+
+ # Load model and configs
+ def load_models(config_name, ckpt_path, vae_path, device):
+     params = load_yaml_with_includes(config_name)
+
+     # Load codec model
+     autoencoder = Autoencoder(ckpt_path=vae_path,
+                               model_type=params['autoencoder']['name'],
+                               quantization_first=params['autoencoder']['q_first']).to(device)
+     autoencoder.eval()
+
+     # Load text encoder
+     tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
+     text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device)
+     text_encoder.eval()
+
+     # Load main U-Net model
+     unet = MaskDiT(**params['model']).to(device)
+     unet.load_state_dict(torch.load(ckpt_path)['model'])
+     unet.eval()
+
+     # Load noise scheduler
+     noise_scheduler = DDIMScheduler(**params['diff'])
+
+     latents = torch.randn((1, 128, 128), device=device)
+     noise = torch.randn_like(latents)
+     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device)
+     _ = noise_scheduler.add_noise(latents, noise, timesteps)
+
+     return autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ # Model and config paths
+ config_name = 'ckpts/ezaudio-xl.yml'
+ ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt'
+ vae_path = 'ckpts/vae/1m.pt'
+ save_path = 'output/'
+ os.makedirs(save_path, exist_ok=True)
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path,
+                                                                                   device)
+
+ @spaces.GPU
+ def generate_audio(text, length,
+                    guidance_scale, guidance_rescale, ddim_steps, eta,
+                    random_seed, randomize_seed):
+     neg_text = None
+     length = length * params['autoencoder']['latent_sr']
+
+     if randomize_seed:
+         random_seed = random.randint(0, MAX_SEED)
+
+     pred = inference(autoencoder, unet, None, None,
+                      tokenizer, text_encoder,
+                      params, noise_scheduler,
+                      text, neg_text,
+                      length,
+                      guidance_scale, guidance_rescale,
+                      ddim_steps, eta, random_seed,
+                      device)
+
+     pred = pred.cpu().numpy().squeeze(0).squeeze(0)
+     # output_file = f"{save_path}/{text}.wav"
+     # sf.write(output_file, pred, samplerate=params['autoencoder']['sr'])
+
+     return params['autoencoder']['sr'], pred
+
+
+ # Gradio Interface
+ def gradio_interface():
+     # Input components
+     text_input = gr.Textbox(label="Text Prompt", value="the sound of dog barking")
+     length_input = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)")
+
+     # Advanced settings
+     guidance_scale_input = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5, label="Guidance Scale")
+     guidance_rescale_input = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
+     ddim_steps_input = gr.Slider(minimum=25, maximum=200, step=5, value=100, label="DDIM Steps")
+     eta_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Eta")
+     random_seed_input = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0,)
+
+     randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+
+     # Output component
+     output_audio = gr.Audio(label="Converted Audio", type="numpy")
+
+     # Interface
+     gr.Interface(
+         fn=generate_audio,
+         inputs=[text_input, length_input, guidance_scale_input, guidance_rescale_input, ddim_steps_input, eta_input,
+                 random_seed_input, randomize_seed],
+         outputs=output_audio,
+         title="EzAudio Text-to-Audio Generator",
+         description="Generate audio from text using a diffusion model. Adjust advanced settings for more control.",
+         allow_flagging="never"
+     ).launch()
+
+
+ if __name__ == "__main__":
+     gradio_interface()
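
For a quick check of the updated file outside the Gradio UI, generate_audio can also be called directly in the same Python environment. The snippet below is a minimal illustrative sketch, not part of this commit: the prompt, length, and sample.wav filename are placeholders, and it assumes the module-level models in api.py have already been loaded.

import soundfile as sf

# Call the decorated generate_audio() from api.py and save the returned waveform.
# Argument values mirror the defaults exposed by the UI sliders.
sr, waveform = generate_audio(
    text="the sound of dog barking",
    length=5,                 # seconds
    guidance_scale=5.0,
    guidance_rescale=0.75,
    ddim_steps=100,
    eta=1.0,
    random_seed=0,
    randomize_seed=True,      # ignore random_seed and draw a fresh one
)
sf.write(f"{save_path}/sample.wav", waveform, samplerate=sr)  # placeholder filename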