import torch
import torchaudio
from torch import nn
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, pipeline

# Preprocessing the data
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

max_duration = 2.0  # seconds

device = "cuda" if torch.cuda.is_available() else "cpu"

softmax = nn.Softmax(dim=-1)

label2id, id2label = dict(), dict()
labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
num_labels = 10
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

def get_pipeline(model_name):
    if model_name.split('-')[-1].strip() != 'ibo':
        return None
    return pipeline(task="audio-classification", model=model_name)

def load_model(model_checkpoint):
    # if model_checkpoint.split('-')[-1].strip() != 'ibo':  # This is for DEBUGGING
    #     return None, None
    # Construct the model and move it to the chosen device.
    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
    ).to(device)
    return model

language_dict = {
    "Igbo": 'ibo',
    "Oshiwambo": 'kua',
    "Yoruba": 'yor',
    "Oromo": 'gax',
    "Shona": 'sna',
    "Rundi": 'run',
    "Choose language": 'none',
    "MULTILINGUAL": 'all',
}

# Load each per-language classifier once at start-up, keyed by language code.
AUDIO_CLASSIFICATION_MODELS = {
    'ibo': load_model('chrisjay/afrospeech-wav2vec-ibo'),
    'kua': load_model('chrisjay/afrospeech-wav2vec-kua'),
    'sna': load_model('chrisjay/afrospeech-wav2vec-sna'),
    'yor': load_model('chrisjay/afrospeech-wav2vec-yor'),
    'gax': load_model('chrisjay/afrospeech-wav2vec-gax'),
    'run': load_model('chrisjay/afrospeech-wav2vec-run'),
    'all': load_model('chrisjay/afrospeech-wav2vec-all-6'),
}

def cut_if_necessary(signal, num_samples):
    # Truncate the signal if it is longer than num_samples.
    if signal.shape[1] > num_samples:
        signal = signal[:, :num_samples]
    return signal

def right_pad_if_necessary(signal, num_samples):
    # Zero-pad the signal on the right if it is shorter than num_samples.
    length_signal = signal.shape[1]
    if length_signal < num_samples:
        num_missing_samples = num_samples - length_signal
        last_dim_padding = (0, num_missing_samples)
        signal = torch.nn.functional.pad(signal, last_dim_padding)
    return signal

def resample_if_necessary(signal, sr, target_sample_rate, device):
    # Resample to the target rate; keep the resampler on the signal's device so
    # CPU-loaded audio does not hit a CPU/GPU mismatch during preprocessing.
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(signal.device)
        signal = resampler(signal)
    return signal

def mix_down_if_necessary(signal):
    # Average the channels so multi-channel recordings become mono.
    if signal.shape[0] > 1:
        signal = torch.mean(signal, dim=0, keepdim=True)
    return signal

def preprocess_audio(waveform, sample_rate, feature_extractor):
    # Resample to 16 kHz, mix down to mono, and fix the length to 16000 samples
    # (one second) before extracting model inputs.
    waveform = resample_if_necessary(waveform, sample_rate, 16000, device)
    waveform = mix_down_if_necessary(waveform)
    waveform = cut_if_necessary(waveform, 16000)
    waveform = right_pad_if_necessary(waveform, 16000)
    transformed = feature_extractor(
        waveform,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=16000,
        truncation=True,
    )
    return transformed
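
# Illustrative sketch (not part of the original app): running preprocess_audio on a
# synthetic one-second 440 Hz tone, just to show the expected shapes. The tone and
# the variable names below are assumptions for demonstration only.
#
# import math
# t = torch.arange(16000) / 16000.0
# demo_waveform = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)   # shape (1, 16000)
# demo_features = preprocess_audio(demo_waveform, 16000, feature_extractor)
# demo_input = torch.from_numpy(demo_features.input_values[0])    # ready for the model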

def make_inference(drop_down, audio):
    waveform, sample_rate = torchaudio.load(audio)
    preprocessed_audio = preprocess_audio(waveform, sample_rate, feature_extractor)
    language_code_chosen = language_dict[drop_down]
    model = AUDIO_CLASSIFICATION_MODELS[language_code_chosen]
    model.eval()
    torch_preprocessed_audio = torch.from_numpy(preprocessed_audio.input_values[0]).to(device)
    # Make a prediction and sort the class probabilities in descending order.
    with torch.no_grad():
        prediction = softmax(model(torch_preprocessed_audio).logits)
    sorted_prediction = torch.sort(prediction, descending=True)
    confidences = {}
    for s, v in zip(
        sorted_prediction.indices.cpu().numpy().tolist()[0],
        sorted_prediction.values.cpu().numpy().tolist()[0],
    ):
        confidences[s] = v
    return confidences
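
# Illustrative usage (an assumption, not part of the original file): calling
# make_inference directly with one of the keys from language_dict and a path to a
# recorded digit. "sample_digit.wav" is a hypothetical placeholder path.
#
# confidences = make_inference("Igbo", "sample_digit.wav")
# print(confidences)   # maps class index -> probability, highest first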