# afro-speech / inference.py — modules to test the model
import torch
import torchaudio
from torch import nn
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, pipeline

# Feature extractor used to preprocess raw audio for the wav2vec2 models
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

max_duration = 2.0  # seconds
# Run on GPU when available
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Softmax over the class dimension of the logits
softmax = nn.Softmax(dim=-1)
# Map the ten spoken digits to model label ids and back
label2id, id2label = dict(), dict()
labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
num_labels = 10

for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label
def get_pipeline(model_name):
    # Only checkpoints whose name ends in 'ibo' are exposed through this helper
    if model_name.split('-')[-1].strip() != 'ibo':
        return None
    return pipeline(task="audio-classification", model=model_name)
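# Illustrative usage of get_pipeline (not called in this module; the audio
# path "digit_recording.wav" is an assumption, not a file in this repo):
#   igbo_clf = get_pipeline("chrisjay/afrospeech-wav2vec-ibo")
#   if igbo_clf is not None:
#       igbo_clf("digit_recording.wav")  # -> [{"score": ..., "label": ...}, ...]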
def load_model(model_checkpoint):
    #if model_checkpoint.split('-')[-1].strip() != 'ibo':  # This is for DEBUGGING
    #    return None, None

    # Construct the classification model and move it to the target device
    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
    ).to(device)
    return model
# Dropdown labels mapped to language codes
language_dict = {
    "Igbo": 'ibo',
    "Oshiwambo": 'kua',
    "Yoruba": 'yor',
    "Oromo": 'gax',
    "Shona": 'sna',
    "Rundi": 'run',
    "Choose language": 'none',
    "MULTILINGUAL": 'all',
}
# Load one fine-tuned checkpoint per supported language up front
AUDIO_CLASSIFICATION_MODELS = {
    'ibo': load_model('chrisjay/afrospeech-wav2vec-ibo'),
    'kua': load_model('chrisjay/afrospeech-wav2vec-kua'),
    'sna': load_model('chrisjay/afrospeech-wav2vec-sna'),
    'yor': load_model('chrisjay/afrospeech-wav2vec-yor'),
    'gax': load_model('chrisjay/afrospeech-wav2vec-gax'),
    'run': load_model('chrisjay/afrospeech-wav2vec-run'),
    'all': load_model('chrisjay/afrospeech-wav2vec-all-6'),
}
def cut_if_necessary(signal, num_samples):
    # Trim the signal if it is longer than num_samples
    if signal.shape[1] > num_samples:
        signal = signal[:, :num_samples]
    return signal
def right_pad_if_necessary(signal, num_samples):
    # Zero-pad the signal on the right if it is shorter than num_samples
    length_signal = signal.shape[1]
    if length_signal < num_samples:
        num_missing_samples = num_samples - length_signal
        last_dim_padding = (0, num_missing_samples)
        signal = torch.nn.functional.pad(signal, last_dim_padding)
    return signal
def resample_if_necessary(signal, sr, target_sample_rate, device):
    # Resample to the target rate if needed; keep the resampler on the same
    # device as the signal to avoid CPU/GPU mismatches
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(signal.device)
        signal = resampler(signal)
    return signal
def mix_down_if_necessary(signal):
    # Average the channels of a multi-channel recording down to mono
    if signal.shape[0] > 1:
        signal = torch.mean(signal, dim=0, keepdim=True)
    return signal
def preprocess_audio(waveform, sample_rate, feature_extractor):
    # Resample to 16 kHz, mix down to mono, and fix the length to 16000 samples
    # (one second) before running the feature extractor
    waveform = resample_if_necessary(waveform, sample_rate, 16000, device)
    waveform = mix_down_if_necessary(waveform)
    waveform = cut_if_necessary(waveform, 16000)
    waveform = right_pad_if_necessary(waveform, 16000)
    transformed = feature_extractor(
        waveform,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=16000,
        truncation=True,
    )
    return transformed
def make_inference(drop_down, audio):
    # Load the recording and apply the same preprocessing used for training
    waveform, sample_rate = torchaudio.load(audio)
    preprocessed_audio = preprocess_audio(waveform, sample_rate, feature_extractor)

    # Pick the fine-tuned model matching the language chosen in the dropdown
    language_code_chosen = language_dict[drop_down]
    model = AUDIO_CLASSIFICATION_MODELS[language_code_chosen]
    model.eval()

    torch_preprocessed_audio = torch.from_numpy(preprocessed_audio.input_values[0]).to(device)

    # Make the prediction and turn the logits into class probabilities
    with torch.no_grad():
        prediction = softmax(model(torch_preprocessed_audio).logits)
    sorted_prediction = torch.sort(prediction, descending=True)

    # Return {digit_index: confidence}, ordered from most to least likely
    confidences = {}
    for s, v in zip(
        sorted_prediction.indices.cpu().numpy().tolist()[0],
        sorted_prediction.values.cpu().numpy().tolist()[0],
    ):
        confidences[s] = v
    return confidences
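if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: "Igbo" is one of the dropdown
    # labels defined above, while "sample.wav" is an assumed local recording
    # of a spoken digit and is not shipped with this repository.
    print(make_inference("Igbo", "sample.wav"))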