ardneebwar committed on
Commit
38bff40
1 Parent(s): 13c1e51

Added code to use it locally.

Browse files
Files changed (1) hide show
  1. README.md +44 -1
README.md CHANGED
@@ -82,4 +82,47 @@ The following hyperparameters were used during training:
82
 
83
  ### Github Repository
84
 
85
- [Animal Sound Classification](https://github.com/rawbeen248/audio_classification_finetuning)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  ### Github Repository
84
 
85
+ [Animal Sound Classification](https://github.com/rawbeen248/audio_classification_finetuning)
86
+
87
+
88
+ ### To try it locally
89
+
90
+ ```python
91
+ import librosa
92
+ import torch
93
+ from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor
94
+
95
+ # Load the fine-tuned model and feature extractor
96
+ model_name = "ardneebwar/wav2vec2-animal-sounds-finetuned-hubert-finetuned-animals"
97
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
98
+ model = HubertForSequenceClassification.from_pretrained(model_name)
99
+
100
+ # Prepare the device
101
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102
+ model.to(device)
103
+ model.eval() # Set the model to evaluation mode
104
+
105
+ # Function to predict the class of an audio file
106
+ def predict_audio_class(audio_file, feature_extractor, model, device):
107
+ # Load and preprocess the audio file
108
+ speech, sr = librosa.load(audio_file, sr=16000)
109
+ input_values = feature_extractor(speech, return_tensors="pt", sampling_rate=16000).input_values
110
+ input_values = input_values.to(device)
111
+
112
+ # Predict
113
+ with torch.no_grad():
114
+ logits = model(input_values).logits
115
+
116
+ # Get the predicted class ID
117
+ predicted_id = torch.argmax(logits, dim=-1)
118
+ # Convert the predicted ID to the class name
119
+ predicted_class = model.config.id2label[predicted_id.item()]
120
+
121
+ return predicted_class
122
+
123
+ # Replace 'path_to_audio_file.wav' with the actual path to your audio file
124
+ audio_file_path = "path_to_audio_file.wav"
125
+ predicted_class = predict_audio_class(audio_file_path, feature_extractor, model, device)
126
+ print(f"Predicted class: {predicted_class}")
127
+
128
+ ```