asigalov61
commited on
Commit
•
50f85db
1
Parent(s):
bf37476
Update app.py
Browse files
app.py
CHANGED
@@ -31,8 +31,8 @@ def GenerateSong(input_melody_seed_number):
|
|
31 |
|
32 |
print('Loading model...')
|
33 |
|
34 |
-
SEQ_LEN =
|
35 |
-
PAD_IDX =
|
36 |
DEVICE = 'cuda' # 'cuda'
|
37 |
|
38 |
# instantiate the model
|
@@ -40,7 +40,7 @@ def GenerateSong(input_melody_seed_number):
|
|
40 |
model = TransformerWrapper(
|
41 |
num_tokens = PAD_IDX+1,
|
42 |
max_seq_len = SEQ_LEN,
|
43 |
-
attn_layers = Decoder(dim = 1024, depth =
|
44 |
)
|
45 |
|
46 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
@@ -51,7 +51,7 @@ def GenerateSong(input_melody_seed_number):
|
|
51 |
print('Loading model checkpoint...')
|
52 |
|
53 |
model.load_state_dict(
|
54 |
-
torch.load('
|
55 |
map_location=DEVICE))
|
56 |
print('=' * 70)
|
57 |
|
@@ -201,7 +201,7 @@ if __name__ == "__main__":
|
|
201 |
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
|
202 |
|
203 |
print('Loading seed meldoies data...')
|
204 |
-
seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('
|
205 |
print('=' * 70)
|
206 |
|
207 |
app = gr.Blocks()
|
|
|
31 |
|
32 |
print('Loading model...')
|
33 |
|
34 |
+
SEQ_LEN = 1024
|
35 |
+
PAD_IDX = 14627
|
36 |
DEVICE = 'cuda' # 'cuda'
|
37 |
|
38 |
# instantiate the model
|
|
|
40 |
model = TransformerWrapper(
|
41 |
num_tokens = PAD_IDX+1,
|
42 |
max_seq_len = SEQ_LEN,
|
43 |
+
attn_layers = Decoder(dim = 1024, depth = 12, heads = 16, attn_flash = True)
|
44 |
)
|
45 |
|
46 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
|
|
51 |
print('Loading model checkpoint...')
|
52 |
|
53 |
model.load_state_dict(
|
54 |
+
torch.load('Annotated_MIDI_Dataset_Classifier_Trained_Model_21269_steps_0.4335_loss_0.8716_acc.pth',
|
55 |
map_location=DEVICE))
|
56 |
print('=' * 70)
|
57 |
|
|
|
201 |
soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
|
202 |
|
203 |
print('Loading seed meldoies data...')
|
204 |
+
seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('processed_scores')
|
205 |
print('=' * 70)
|
206 |
|
207 |
app = gr.Blocks()
|