ardneebwar
commited on
Commit
•
38bff40
1
Parent(s):
13c1e51
Added code it use it locally.
Browse files
README.md
CHANGED
@@ -82,4 +82,47 @@ The following hyperparameters were used during training:
|
|
82 |
|
83 |
### Github Repository
|
84 |
|
85 |
-
[Animal Sound Classification](https://github.com/rawbeen248/audio_classification_finetuning)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
### Github Repository
|
84 |
|
85 |
+
[Animal Sound Classification](https://github.com/rawbeen248/audio_classification_finetuning)
|
86 |
+
|
87 |
+
|
88 |
+
### To try it locally
|
89 |
+
|
90 |
+
```
|
91 |
+
import librosa
|
92 |
+
import torch
|
93 |
+
from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor
|
94 |
+
|
95 |
+
# Load the fine-tuned model and feature extractor
|
96 |
+
model_name = "ardneebwar/wav2vec2-animal-sounds-finetuned-hubert-finetuned-animals"
|
97 |
+
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
|
98 |
+
model = HubertForSequenceClassification.from_pretrained(model_name)
|
99 |
+
|
100 |
+
# Prepare the device
|
101 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
102 |
+
model.to(device)
|
103 |
+
model.eval() # Set the model to evaluation mode
|
104 |
+
|
105 |
+
# Function to predict the class of an audio file
|
106 |
+
def predict_audio_class(audio_file, feature_extractor, model, device):
|
107 |
+
# Load and preprocess the audio file
|
108 |
+
speech, sr = librosa.load(audio_file, sr=16000)
|
109 |
+
input_values = feature_extractor(speech, return_tensors="pt", sampling_rate=16000).input_values
|
110 |
+
input_values = input_values.to(device)
|
111 |
+
|
112 |
+
# Predict
|
113 |
+
with torch.no_grad():
|
114 |
+
logits = model(input_values).logits
|
115 |
+
|
116 |
+
# Get the predicted class ID
|
117 |
+
predicted_id = torch.argmax(logits, dim=-1)
|
118 |
+
# Convert the predicted ID to the class name
|
119 |
+
predicted_class = model.config.id2label[predicted_id.item()]
|
120 |
+
|
121 |
+
return predicted_class
|
122 |
+
|
123 |
+
# Replace 'path_to_your_new_audio_file.wav' with the actual path to the new audio file
|
124 |
+
audio_file_path = "path_to_audio_file.wav"
|
125 |
+
predicted_class = predict_audio_class(audio_file_path, feature_extractor, model, device)
|
126 |
+
print(f"Predicted class: {predicted_class}")
|
127 |
+
|
128 |
+
```
|