Spaces: Runtime error
Epsilon617 committed on
Commit 92cd759 • 1 Parent(s): c2c7513
add genre prediction head
Prediction_Head/MTGGenre_head.py ADDED
@@ -0,0 +1,21 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class MLPProberBase(nn.Module):
+    def __init__(self, d=768, num_outputs=87):
+        super().__init__()
+        self.hidden_layer_sizes = [512, ]  # eval(self.cfg.hidden_layer_sizes)
+        self.num_layers = len(self.hidden_layer_sizes)
+        for i, ld in enumerate(self.hidden_layer_sizes):
+            setattr(self, f"hidden_{i}", nn.Linear(d, ld))
+            d = ld
+        self.output = nn.Linear(d, num_outputs)
+
+    def forward(self, x):
+        for i in range(self.num_layers):
+            x = getattr(self, f"hidden_{i}")(x)
+            # x = self.dropout(x)
+            x = F.relu(x)
+        output = self.output(x)
+        return output
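The prober above maps a 768-dim MERT feature vector to 87 genre logits through a single 512-unit hidden layer. A minimal standalone shape check (a sketch, not part of this commit):

import torch
from Prediction_Head.MTGGenre_head import MLPProberBase

# Batch of 4 mean-pooled MERT embeddings, one 768-dim vector each.
probe = MLPProberBase(d=768, num_outputs=87)
features = torch.randn(4, 768)
logits = probe(features)
print(logits.shape)  # torch.Size([4, 87])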
Prediction_Head/MTGGenre_id2class.json ADDED
@@ -0,0 +1 @@
+{"0": "genre---rock", "1": "genre---pop", "2": "genre---classical", "3": "genre---popfolk", "4": "genre---disco", "5": "genre---funk", "6": "genre---rnb", "7": "genre---ambient", "8": "genre---chillout", "9": "genre---downtempo", "10": "genre---easylistening", "11": "genre---electronic", "12": "genre---lounge", "13": "genre---triphop", "14": "genre---breakbeat", "15": "genre---techno", "16": "genre---newage", "17": "genre---jazz", "18": "genre---metal", "19": "genre---industrial", "20": "genre---instrumentalrock", "21": "genre---minimal", "22": "genre---alternative", "23": "genre---experimental", "24": "genre---drumnbass", "25": "genre---soul", "26": "genre---fusion", "27": "genre---soundtrack", "28": "genre---electropop", "29": "genre---world", "30": "genre---ethno", "31": "genre---trance", "32": "genre---orchestral", "33": "genre---grunge", "34": "genre---chanson", "35": "genre---worldfusion", "36": "genre---hiphop", "37": "genre---groove", "38": "genre---instrumentalpop", "39": "genre---blues", "40": "genre---reggae", "41": "genre---dance", "42": "genre---club", "43": "genre---punkrock", "44": "genre---folk", "45": "genre---synthpop", "46": "genre---poprock", "47": "genre---choir", "48": "genre---symphonic", "49": "genre---indie", "50": "genre---progressive", "51": "genre---acidjazz", "52": "genre---contemporary", "53": "genre---newwave", "54": "genre---dub", "55": "genre---rocknroll", "56": "genre---hard", "57": "genre---hardrock", "58": "genre---house", "59": "genre---atmospheric", "60": "genre---psychedelic", "61": "genre---improvisation", "62": "genre---country", "63": "genre---electronica", "64": "genre---rap", "65": "genre---60s", "66": "genre---70s", "67": "genre---darkambient", "68": "genre---idm", "69": "genre---latin", "70": "genre---postrock", "71": "genre---bossanova", "72": "genre---singersongwriter", "73": "genre---darkwave", "74": "genre---swing", "75": "genre---medieval", "76": "genre---celtic", "77": "genre---eurodance", "78": "genre---classicrock", "79": "genre---dubstep", "80": "genre---bluesrock", "81": "genre---edm", "82": "genre---deephouse", "83": "genre---jazzfusion", "84": "genre---alternativerock", "85": "genre---80s", "86": "genre---90s"}
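The keys are stringified class indices ("0" through "86") and the values carry a genre--- prefix that app.py strips before display. Looking up a tag (sketch):

import json

with open('Prediction_Head/MTGGenre_id2class.json') as f:
    id2cls = json.load(f)  # keys are strings, not ints

print(id2cls['17'])                           # genre---jazz
print(id2cls['17'].replace('genre---', ''))   # jazz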
Prediction_Head/__pycache__/MTGGenre_head.cpython-310.pyc ADDED
Binary file (1.08 kB)
Prediction_Head/best_MTGGenre.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83b7dcffde10a0dc7ba74341ea56dabec5c5de7cad6a0483708c80f1d893514a
+size 1759067
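This is a Git LFS pointer; the actual ~1.7 MB checkpoint is fetched when the repo is cloned with LFS. Restoring the prober weights from it (a sketch following the 'state_dict' key used in app.py below; map_location is an added precaution for CPU-only machines):

import torch
from Prediction_Head.MTGGenre_head import MLPProberBase

probe = MLPProberBase()
ckpt = torch.load('Prediction_Head/best_MTGGenre.ckpt', map_location='cpu')
probe.load_state_dict(ckpt['state_dict'])
probe.eval()  # inference mode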
__pycache__/app.cpython-310.pyc CHANGED
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
app.py CHANGED
@@ -8,9 +8,12 @@ import torchaudio
 import torchaudio.transforms as T
 import logging
 
+import json
+
 import importlib
 modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")
 
+from Prediction_Head.MTGGenre_head import MLPProberBase
 # input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
 
 
@@ -34,7 +37,7 @@ live_inputs = [
 ]
 # outputs = [gr.components.Textbox()]
 # outputs = [gr.components.Textbox(), transcription_df]
-title = "
+title = "Predict the top 5 possible genres of Music"
 description = "An example of using MERT-95M-public to conduct music tagging."
 article = ""
 audio_examples = [
@@ -48,9 +51,17 @@ audio_examples = [
 model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
 processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")
 
+MERT_LAYER_IDX = 7
+MTGGenre_classifier = MLPProberBase()
+MTGGenre_classifier.load_state_dict(torch.load('Prediction_Head/best_MTGGenre.ckpt')['state_dict'])
+
+with open('Prediction_Head/MTGGenre_id2class.json', 'r') as f:
+    id2cls = json.load(f)
+
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 model.to(device)
+MTGGenre_classifier.to(device)
 
 def convert_audio(inputs, microphone):
     if (microphone is not None):
@@ -75,10 +86,17 @@ def convert_audio(inputs, microphone):
     # take a look at the output shape, there are 13 layers of representation
     # each layer performs differently in different downstream tasks, you should choose empirically
     all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
-
+    print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
+
+    logits = MTGGenre_classifier(torch.mean(all_layer_hidden_states[MERT_LAYER_IDX], dim=0)) # [1, 87]
+    print(logits.shape)
+    sorted_idx = torch.argsort(logits, dim = -1, descending=True)
+
+    output_texts = "\n".join([id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])
     # logger.warning(all_layer_hidden_states.shape)
 
-    return f"device {device}, sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
+    # return f"device {device}, sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
+    return f"device: {device}\n" + output_texts
 
 def live_convert_audio(microphone):
     if (microphone is not None):
@@ -103,10 +121,17 @@ def live_convert_audio(microphone):
     # take a look at the output shape, there are 13 layers of representation
     # each layer performs differently in different downstream tasks, you should choose empirically
     all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
-
+    print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
+
+    logits = MTGGenre_classifier(torch.mean(all_layer_hidden_states[MERT_LAYER_IDX], dim=0)) # [1, 87]
+    print(logits.shape)
+    sorted_idx = torch.argsort(logits, dim = -1, descending=True)
+
+    output_texts = "\n".join([id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])
     # logger.warning(all_layer_hidden_states.shape)
 
-    return f"device {device}, sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
+    # return f"device {device}, sample reprensentation: {str(all_layer_hidden_states[12, 0, :10])}"
+    return f"device: {device}\n" + output_texts
 
 
 audio_chunked = gr.Interface(
|