File size: 7,522 Bytes
d248698
6c50c01
6c780ba
d248698
 
 
c03e124
d248698
2319c67
 
 
 
d248698
 
 
 
2319c67
d248698
 
 
 
 
2319c67
d248698
 
 
 
 
 
 
 
 
 
 
a22a294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9acf377
a22a294
 
 
 
 
 
 
 
 
 
 
 
 
d248698
0b8602d
5eca949
a22a294
 
3a35271
d248698
 
 
5eca949
 
a22a294
d248698
5eca949
d248698
2319c67
d248698
 
09934b1
 
5eca949
d248698
5eca949
d248698
 
 
 
 
5eca949
6c50c01
 
 
 
 
 
5eca949
d248698
 
 
5eca949
 
d248698
 
 
a22a294
 
 
d248698
5eca949
 
 
 
 
6c50c01
 
 
5eca949
 
 
 
ebaba50
3cd6b34
702dedf
 
 
 
 
 
 
09934b1
deb6407
df8718b
deb6407
a940384
deb6407
 
09934b1
2c4b552
09934b1
 
 
 
 
 
 
 
 
 
 
 
 
 
6c50c01
09934b1
 
 
 
ebaba50
0f91c64
9673b59
37b7d1c
 
ced44b1
52683b7
6aa5a17
f0d857e
 
 
09934b1
2439bcd
 
09934b1
5eca949
d248698
 
5eca949
a22a294
d248698
dbb8510
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import gradio as gr
import spaces
import huggingface_hub
import numpy as np
import pandas as pd
import os
import shutil
import torch
from audiocraft.data.audio import audio_write
import audiocraft.models


# Fetch the two MusiConGen checkpoint files into the local ckpt directory.
for _ckpt_filename in ('compression_state_dict.bin', 'state_dict.bin'):
    huggingface_hub.hf_hub_download(
        repo_id='Cyan0731/MusiConGen',
        filename=_ckpt_filename,
        local_dir='./ckpt/musicongen'
    )

def print_directory_contents(path):
    """Print a recursive, indented tree listing of *path* to stdout.

    Each directory is printed as ``name/`` indented 4 spaces per depth
    level; its files are printed one level deeper.
    """
    for root, _dirs, files in os.walk(path):
        # depth relative to the starting path, measured in path separators
        depth = root.replace(path, '').count(os.sep)
        print(f"{'    ' * depth}{os.path.basename(root)}/")
        for name in files:
            print(f"{'    ' * (depth + 1)}{name}")

def check_outputs_folder(folder_path):
    """Remove every entry inside *folder_path*, keeping the folder itself.

    Files and symlinks are unlinked; subdirectories are removed
    recursively. Failures are reported to stdout rather than raised.
    If the folder is missing, a message is printed and nothing happens.
    """
    if not (os.path.exists(folder_path) and os.path.isdir(folder_path)):
        print(f'The folder {folder_path} does not exist.')
        return
    for entry in os.listdir(folder_path):
        entry_path = os.path.join(folder_path, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)  # plain file or symlink
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)  # whole subtree
        except Exception as exc:
            print(f'Failed to delete {entry_path}. Reason: {exc}')

def check_for_wav_in_outputs():
    """Return the path to a generated .wav file, or None if there is none.

    Looks in the fixed output folder './output_samples/example_1' used by
    infer(). Returns None when the folder does not exist or contains no
    .wav files; otherwise returns the path to one matching file.
    """
    outputs_folder = './output_samples/example_1'

    if not os.path.exists(outputs_folder):
        return None

    # Fixed misleading names/comments: previous code referred to ".mp4"
    # files while it has always matched ".wav" — behavior is unchanged.
    wav_files = [f for f in os.listdir(outputs_folder) if f.endswith('.wav')]

    if wav_files:
        return os.path.join(outputs_folder, wav_files[0])
    return None

@spaces.GPU(duration=80)
def infer(prompt_in, chords, duration, bpms):
    """Generate one chord- and beat-conditioned music sample with MusiConGen.

    Args:
        prompt_in: Text description of the music to generate.
        chords: Chord progression string, e.g. 'B:min D F#:min E'.
        duration: Sample length in seconds (from the UI slider).
        bpms: Tempo in beats per minute (from the UI slider).

    Returns:
        Path to the generated .wav file under ./output_samples/example_1,
        or None if no file was produced.
    """
    # check if 'outputs' dir exists and empty it if necessary
    check_outputs_folder('./output_samples/example_1')

    # set hparams (removed a redundant no-op `duration = duration` assignment)
    output_dir = 'example_1' ### change this output directory
    num_samples = 1
    bs = 1  # batch size

    # load the checkpoint downloaded at module import time
    musicgen = audiocraft.models.MusicGen.get_pretrained('./ckpt/musicongen') ### change this path
    # extend_stride controls how far each continuation window advances
    musicgen.set_generation_params(duration=duration, extend_stride=duration//2, top_k = 250)

    # one conditioning entry per sample (num_samples == 1 here)
    chords = [chords]
    descriptions = [prompt_in] * num_samples
    bpms = [bpms] * num_samples
    meters = [4] * num_samples  # presumably 4/4 time — TODO confirm meter semantics

    wav = []
    for i in range(num_samples//bs):
        print(f"starting {i} batch...")
        temp = musicgen.generate_with_chords_and_beats(
            descriptions[i*bs:(i+1)*bs],
            chords[i*bs:(i+1)*bs],
            bpms[i*bs:(i+1)*bs],
            meters[i*bs:(i+1)*bs]
        )
        wav.extend(temp.cpu())

    # save generated audio; the filename encodes the chords and description
    # NOTE(review): indexing chords[idx] assumes num_samples == 1 — with more
    # samples this would raise IndexError since chords has a single entry.
    for idx, one_wav in enumerate(wav):
        sav_path = os.path.join('./output_samples', output_dir, chords[idx] + "|" + descriptions[idx]).replace(" ", "_")
        audio_write(sav_path, one_wav.cpu(), musicgen.sample_rate, strategy='loudness', loudness_compressor=True)

    # Print the outputs directory contents
    print_directory_contents('./output_samples')
    wav_file_path = check_for_wav_in_outputs()
    print(wav_file_path)
    return wav_file_path

# CSS for the Gradio UI: centers the main column (max 800px wide) and
# enlarges the chord-progression example buttons for readability.
css="""
#col-container{
    max-width: 800px;
    margin: 0 auto;
}
#chords-examples button{
    font-size: 20px;
}
"""

# Build and launch the Gradio UI: a single centered column holding the
# header, the generation controls, example pickers, and the audio output.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Title, subtitle, and project-page badge
        gr.Markdown("# MusiConGen")
        gr.Markdown("## Rhythm and Chord Control for Transformer-Based Text-to-Music Generation")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href='https://musicongen.github.io/musicongen_demo/'>
                <img src='https://img.shields.io/badge/Project-Page-Green'>
            </a>
        </div>
        """)
        with gr.Column():
            with gr.Group():
                # Generation inputs: free-text prompt plus chords/duration/BPM
                prompt_in = gr.Textbox(label="Music description", value="A smooth acid jazz track with a laid-back groove, silky electric piano, and a cool bass, providing a modern take on jazz. Instruments: electric piano, bass, drums.")
                with gr.Row():
                    chords = gr.Textbox(label="Chords progression", value='B:min D F#:min E', scale=1.75)
                    duration = gr.Slider(label="Sample duration", minimum=4, maximum=30, step=1, value=30)
                    bpms = gr.Slider(label="BPMs", minimum=50, maximum=220, step=1, value=120)
            submit_btn = gr.Button("Submit")
            # Output player, pre-loaded with a bundled default sample
            wav_out = gr.Audio(label="Wav Result", value="./MusiConGen_default_sample_space_example.wav")
            with gr.Row():
                # Clickable example prompts that fill the description textbox
                gr.Examples(
                    label = "Audio description examples",
                    examples = [
                        ["A laid-back blues shuffle with a relaxed tempo, warm guitar tones, and a comfortable groove, perfect for a slow dance or a night in. Instruments: electric guitar, bass, drums."],
                        ["A smooth acid jazz track with a laid-back groove, silky electric piano, and a cool bass, providing a modern take on jazz. Instruments: electric piano, bass, drums."],
                        ["A classic rock n' roll tune with catchy guitar riffs, driving drums, and a pulsating bass line, reminiscent of the golden era of rock. Instruments: electric guitar, bass, drums."],
                        ["A high-energy funk tune with slap bass, rhythmic guitar riffs, and a tight horn section, guaranteed to get you grooving. Instruments: bass, guitar, trumpet, saxophone, drums."],
                        ["A heavy metal onslaught with double kick drum madness, aggressive guitar riffs, and an unrelenting bass, embodying the spirit of metal. Instruments: electric guitar, bass guitar, drums."]
                    ],
                    inputs = [prompt_in]
                )
                # Clickable chord progressions that fill the chords textbox
                gr.Examples(
                    label = "Chords progression examples",
                    elem_id = "chords-examples",
                    examples = ['C G A:min F',
                                'A:min F C G',
                                'C F G F',
                                'C A:min F G',
                                'D:min G C A:min',
                                'D:min7 G:7 C:maj7 C:maj7',
                                'F G E:min A:min',
                                'B:min D F#:min E',
                                'F G E A:min',
                                'C Bb F C',
                                'A:min C D F',
                                'B:min F#:min E:min B:min',
                                'B:min7 E:9 A:maj7 A:maj7 C#:7 F#:min7',
                                'F:min G:min Ab Bb',
                                'A:min G F D:min'
                               ],
                    inputs = [chords],
                    examples_per_page = 16
                )
        
    # Wire the Submit button to infer(); its returned path feeds the player
    submit_btn.click(
        fn = infer,
        inputs = [prompt_in, chords, duration, bpms],
        outputs = [wav_out]
    )
demo.launch(show_api=False, show_error=True)