AescF committed
Commit fc74109
1 Parent(s): 33d8db0

Update app.py

Files changed (1)
  1. app.py +50 -31
app.py CHANGED
@@ -1,13 +1,11 @@
 import gradio as gr
 import librosa
 import numpy as np
-from transformers import AutoFeatureExtractor
-import os
+from timeit import default_timer as timer
+import torch
+from transformers import pipeline


-model_id = "AescF/hubert-base-ls960-finetuned-common_language"
-processor = Wav2Vec2Processor.from_pretrained(model_id)
-model = Wav2Vec2ForClassification.from_pretrained(model_id)
 language_classes = {
     0: "Arabic",
     1: "Basque",
@@ -57,31 +54,52 @@ language_classes = {
 }


-def predict_language(audio):
-    # Read audio file
-    audio_input, sr = librosa.load(audio, sr=16000)
-
-    # Convert to suitable format
-    input_values = processor(audio_input, return_tensors="pt", padding=True).input_values
-
-    # Make prediction
-    with torch.no_grad():
-        logits = model(input_values).logits
-
-    # Compute probabilities
-    probabilities = torch.softmax(logits, dim=1)
-
-    # Retrieve label
-    predicted_language_idx = torch.argmax(probabilities[0]).item()
-
-    return {language_classes[predicted_language_idx]: float(probabilities[0][predicted_language_idx])}

-iface = gr.Interface(
-    predict_language,
-    inputs=gr.inputs.Audio(type="filepath", label="Upload Language Audio file"),
-    outputs=gr.outputs.Label(),
-    title="Language Classifier",
-    live=True
+
+username = "AescF"  # Complete your username
+model_id = "AescF/hubert-base-ls960-finetuned-common_language"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+pipe = pipeline("audio-classification", model=model_id, device=device)
+
+# def predict_trunc(filepath):
+#     preprocessed = pipe.preprocess(filepath)
+#     truncated = pipe.feature_extractor.pad(preprocessed, truncation=True, max_length=16_000 * 30)
+#     model_outputs = pipe.forward(truncated)
+#     outputs = pipe.postprocess(model_outputs)
+
+#     return outputs
+
+
+def classify_audio(filepath):
+    """
+    Goes from
+    [{'score': 0.8339303731918335, 'label': 'country'},
+     {'score': 0.11914275586605072, 'label': 'rock'}]
+    to
+    ({"country": 0.8339303731918335, "rock": 0.11914275586605072}, prediction_time)
+    """
+    start_time = timer()
+    preds = pipe(filepath)
+    # preds = predict_trunc(filepath)
+    outputs = {}
+    pred_time = round(timer() - start_time, 5)
+    for p in preds:
+        outputs[p["label"]] = p["score"]
+    return outputs, pred_time
+
+
+title = "🎵 Music Genre Classifier"
+description = """
+Demo for a music genre classifier trained on [GTZAN](https://huggingface.co/datasets/marsyas/gtzan).
+For more info, check out [GitHub](https://github.com/AEscF).
+"""
+filenames = None  # placeholder: optionally set to a list of example audio file paths
+demo = gr.Interface(
+    fn=classify_audio,
+    inputs=gr.Audio(type="filepath"),
+    outputs=[gr.Label(label="Predictions"), gr.Number(label="Prediction time (s)")],
+    title=title,
+    description=description,
+    examples=filenames,
 )
-script_dir = os.path.abspath(os.path.join(os.path.abspath(''), os.pardir))
-iface.launch()
+demo.launch()
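
For reference, a minimal sketch of how the pipeline-based inference added in this commit can be smoke-tested outside the Gradio app. The model id is the one used in the diff; "sample.wav" is a placeholder path, not a file shipped with the repo.

from timeit import default_timer as timer

import torch
from transformers import pipeline

model_id = "AescF/hubert-base-ls960-finetuned-common_language"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
pipe = pipeline("audio-classification", model=model_id, device=device)

start = timer()
preds = pipe("sample.wav")  # placeholder: any local audio file
pred_time = round(timer() - start, 5)

# The pipeline returns a list of {"score", "label"} dicts; flatten it into
# the {label: score} mapping that gr.Label expects.
scores = {p["label"]: p["score"] for p in preds}
print(scores, pred_time)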