asigalov61 commited on
Commit
50f85db
1 Parent(s): bf37476

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -31,8 +31,8 @@ def GenerateSong(input_melody_seed_number):
31
 
32
  print('Loading model...')
33
 
34
- SEQ_LEN = 2560
35
- PAD_IDX = 514
36
  DEVICE = 'cuda' # 'cuda'
37
 
38
  # instantiate the model
@@ -40,7 +40,7 @@ def GenerateSong(input_melody_seed_number):
40
  model = TransformerWrapper(
41
  num_tokens = PAD_IDX+1,
42
  max_seq_len = SEQ_LEN,
43
- attn_layers = Decoder(dim = 1024, depth = 24, heads = 16, attn_flash = True)
44
  )
45
 
46
  model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
@@ -51,7 +51,7 @@ def GenerateSong(input_melody_seed_number):
51
  print('Loading model checkpoint...')
52
 
53
  model.load_state_dict(
54
- torch.load('Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth',
55
  map_location=DEVICE))
56
  print('=' * 70)
57
 
@@ -201,7 +201,7 @@ if __name__ == "__main__":
201
  soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
202
 
203
  print('Loading seed meldoies data...')
204
- seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data')
205
  print('=' * 70)
206
 
207
  app = gr.Blocks()
 
31
 
32
  print('Loading model...')
33
 
34
+ SEQ_LEN = 1024
35
+ PAD_IDX = 14627
36
  DEVICE = 'cuda' # 'cuda'
37
 
38
  # instantiate the model
 
40
  model = TransformerWrapper(
41
  num_tokens = PAD_IDX+1,
42
  max_seq_len = SEQ_LEN,
43
+ attn_layers = Decoder(dim = 1024, depth = 12, heads = 16, attn_flash = True)
44
  )
45
 
46
  model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
 
51
  print('Loading model checkpoint...')
52
 
53
  model.load_state_dict(
54
+ torch.load('Annotated_MIDI_Dataset_Classifier_Trained_Model_21269_steps_0.4335_loss_0.8716_acc.pth',
55
  map_location=DEVICE))
56
  print('=' * 70)
57
 
 
201
  soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
202
 
203
  print('Loading seed meldoies data...')
204
+ seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('processed_scores')
205
  print('=' * 70)
206
 
207
  app = gr.Blocks()