Spaces:
Runtime error
Runtime error
import os
from collections import Counter

import numpy as np
import torch
import torchaudio
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, Wav2Vec2Processor
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads audio files as mono 1-D arrays.

    Each item is read with ``torchaudio.load`` and down-mixed to mono by
    averaging channels. It is resampled to ``sampling_rate`` when the file's
    rate differs, then truncated to at most ``max_audio_len`` seconds.
    Shorter clips are NOT padded here; padding to a common length is left to
    the batch collate function.

    Args:
        dataset: Sequence of audio file paths (used as-is, or joined onto
            ``basedir`` when one is given).
        basedir: Optional directory prepended to every entry of ``dataset``.
        sampling_rate: Target sampling rate in Hz.
        max_audio_len: Maximum clip duration in seconds (may be fractional);
            ``None`` disables truncation.
    """

    def __init__(self, dataset, basedir=None, sampling_rate=16000, max_audio_len=5):
        self.dataset = dataset
        self.basedir = basedir
        self.sampling_rate = sampling_rate
        self.max_audio_len = max_audio_len

    def __len__(self):
        return len(self.dataset)

    def _cutorpad(self, audio):
        """Truncate ``audio`` to ``max_audio_len`` seconds; never pads.

        The sample count is cast to ``int`` so fractional durations
        (e.g. 2.5 s) still give a valid slice bound.
        """
        if self.max_audio_len is None:
            return audio
        effective_length = int(self.sampling_rate * self.max_audio_len)
        if len(audio) > effective_length:
            audio = audio[:effective_length]
        return audio

    def __getitem__(self, index):
        """Load, down-mix, resample and truncate the clip at ``index``.

        Returns:
            dict: ``{"input_values": 1-D numpy array, "attention_mask": None}``.
        """
        if self.basedir is None:
            filepath = self.dataset[index]
        else:
            filepath = os.path.join(self.basedir, self.dataset[index])
        # torchaudio.load returns a (channels, frames) tensor and its sample rate.
        speech_array, sr = torchaudio.load(filepath)
        if speech_array.shape[0] > 1:
            # Average channels to mono, keeping the leading channel axis.
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)
        if sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
            speech_array = resampler(speech_array)
        # Select the single channel explicitly: a bare .squeeze() would turn a
        # one-sample clip into a 0-d array, which has no len().
        speech_array = speech_array[0].numpy()
        speech_array = self._cutorpad(speech_array)
        return {"input_values": speech_array, "attention_mask": None}
class CollateFunc:
    """Batch collator: run a feature extractor on each clip, then pad the batch.

    Args:
        processor: Hugging Face feature extractor/processor. It is called on one
            clip at a time, and its ``pad`` method assembles the batch.
        max_length: Forwarded to ``processor.pad``.
        padding: Padding strategy forwarded to ``processor.pad``.
        pad_to_multiple_of: Forwarded to ``processor.pad``.
        sampling_rate: Sampling rate (Hz) passed to ``processor`` with each clip.
    """

    def __init__(self, processor, max_length=None, padding=True, pad_to_multiple_of=None, sampling_rate=16000):
        self.padding = padding
        self.processor = processor
        self.max_length = max_length
        self.sampling_rate = sampling_rate
        self.pad_to_multiple_of = pad_to_multiple_of

    def _extract(self, audio):
        """Run the processor on one clip and strip only its leading batch axis."""
        features = self.processor(audio, sampling_rate=self.sampling_rate).input_values
        # The processor returns a batch of one. Index it instead of np.squeeze,
        # which would also drop any other size-1 axis (e.g. a one-sample clip).
        return np.asarray(features[0])

    def __call__(self, batch):
        """Extract features for every item in ``batch`` and pad them together.

        Args:
            batch: List of dataset items, each a dict holding an
                ``"input_values"`` audio array.

        Returns:
            The result of ``processor.pad(..., return_tensors="pt")``.
        """
        input_features = [
            {"input_values": self._extract(item["input_values"])} for item in batch
        ]
        return self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
def predict(test_dataloader, model, device):
    """Run ``model`` over every batch and return the argmax class per item.

    Args:
        test_dataloader: Iterable of batches. Each batch is a mapping with an
            ``"input_values"`` tensor and, optionally, an ``"attention_mask"``.
        model: Classification model whose output exposes ``.logits``.
        device: Device that the model and inputs are moved to.

    Returns:
        list[int]: Predicted class index for every item, in dataloader order.
    """
    model.to(device)
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in tqdm.tqdm(test_dataloader):
            input_values = batch["input_values"].to(device)
            # Forward the padding mask when the feature extractor produced one,
            # so the zero-padding added by the collator is ignored. Extractors
            # configured with return_attention_mask=False omit the key, and some
            # models do not accept the argument, so it is only passed when present.
            attention_mask = batch.get("attention_mask")
            if attention_mask is None:
                logits = model(input_values).logits
            else:
                logits = model(input_values, attention_mask=attention_mask.to(device)).logits
            # argmax of the logits equals argmax of their softmax, so softmax is skipped.
            preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    return preds
def get_gender(model_name_or_path, audio_paths, device, batch_size=16, num_workers=10):
    """Predict a single gender label for a set of audio clips by majority vote.

    Args:
        model_name_or_path: Hugging Face model id or local path of a
            two-class audio-classification checkpoint.
        audio_paths: Sequence of audio file paths to classify.
        device: Device to run inference on.
        batch_size: Number of clips per inference batch.
        num_workers: Upper bound on DataLoader worker processes. It is capped
            at the number of clips so small requests don't spawn idle workers.

    Returns:
        str: ``"female"`` or ``"male"`` — the most frequent per-clip
        prediction. Ties go to the label that occurs first.

    Raises:
        ValueError: If ``audio_paths`` is empty.
    """
    # Fail fast, before the (expensive) model download/load.
    if len(audio_paths) == 0:
        raise ValueError("audio_paths must contain at least one audio file")

    num_labels = 2
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
    model = AutoModelForAudioClassification.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        num_labels=num_labels,
    )
    # Resample clips to the rate the feature extractor expects; fall back to
    # 16 kHz when the extractor does not declare one.
    sampling_rate = getattr(feature_extractor, "sampling_rate", None) or 16000

    test_dataset = CustomDataset(audio_paths, sampling_rate=sampling_rate)
    data_collator = CollateFunc(
        processor=feature_extractor,
        padding=True,
        sampling_rate=sampling_rate,
    )
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
        shuffle=False,
        num_workers=min(num_workers, len(test_dataset)),
    )
    preds = predict(test_dataloader=test_dataloader, model=model, device=device)

    # NOTE(review): assumes the checkpoint's class 0 is female and 1 is male —
    # confirm against model.config.id2label.
    label_mapping = {0: "female", 1: "male"}
    # Counter.most_common breaks ties by first occurrence, so the vote is
    # deterministic (unlike max() over a set, whose iteration order is arbitrary).
    most_common_label = Counter(preds).most_common(1)[0][0]
    return label_mapping[int(most_common_label)]