# gender_detection/gender_prediction.py
# Committed by Salman11223 in 7d66980 ("Create gender_prediction.py").
import os
import tqdm
import torch
import torchaudio
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, Wav2Vec2Processor
from torch.nn import functional as F
class CustomDataset(torch.utils.data.Dataset):
    """Dataset that loads audio files as mono waveforms at a fixed sampling rate.

    Each item is ``{"input_values": np.ndarray, "attention_mask": None}``. Here
    ``input_values`` is the downmixed, resampled waveform, truncated to at most
    ``max_audio_len`` seconds. Clips shorter than that are returned unchanged;
    padding is left to the batch collate function.

    Args:
        dataset: Sequence of audio file paths (or names relative to ``basedir``).
        basedir: Optional directory joined in front of every entry of ``dataset``.
        sampling_rate: Target sampling rate in Hz; audio is resampled to it.
        max_audio_len: Maximum clip duration in seconds. It may be fractional.
            ``None`` disables truncation.
    """

    def __init__(self, dataset, basedir=None, sampling_rate=16000, max_audio_len=5):
        self.dataset = dataset
        self.basedir = basedir
        self.sampling_rate = sampling_rate
        self.max_audio_len = max_audio_len

    def __len__(self):
        return len(self.dataset)

    def _cutorpad(self, audio):
        """Truncate ``audio`` to at most ``max_audio_len`` seconds of samples.

        Despite the name, no padding happens here: shorter clips are returned
        as-is. When ``max_audio_len`` is ``None`` the clip is never truncated.
        """
        if self.max_audio_len is None:
            return audio
        # int() so fractional durations (e.g. 2.5 s) still produce a valid slice index.
        effective_length = int(self.sampling_rate * self.max_audio_len)
        if len(audio) > effective_length:
            audio = audio[:effective_length]
        return audio

    def __getitem__(self, index):
        filepath = self.dataset[index]
        if self.basedir is not None:
            filepath = os.path.join(self.basedir, filepath)
        speech_array, sr = torchaudio.load(filepath)
        # Downmix multi-channel audio by averaging channels (keeps a leading channel axis).
        if speech_array.shape[0] > 1:
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)
        if sr != self.sampling_rate:
            transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
            speech_array = transform(speech_array)
        # Drop the singleton channel axis and hand a NumPy array to the processor.
        speech_array = speech_array.squeeze().numpy()
        speech_array = self._cutorpad(speech_array)
        return {"input_values": speech_array, "attention_mask": None}
class CollateFunc:
    """Collate function that featurizes each clip, then pads the batch into tensors.

    Args:
        processor: Feature extractor. It is called as
            ``processor(audio, sampling_rate=...)`` and must return an object
            with ``.input_values``. It must also provide ``pad(...)``.
        max_length: Forwarded unchanged to ``processor.pad``.
        padding: Padding strategy forwarded to ``processor.pad``.
        pad_to_multiple_of: Forwarded unchanged to ``processor.pad``.
        sampling_rate: Sampling rate passed to the processor for every clip.
    """

    def __init__(self, processor, max_length=None, padding=True, pad_to_multiple_of=None, sampling_rate=16000):
        self.processor = processor
        self.max_length = max_length
        self.padding = padding
        self.pad_to_multiple_of = pad_to_multiple_of
        self.sampling_rate = sampling_rate

    def _featurize(self, waveform):
        """Run the processor on one waveform; return ``{"input_values": array}``."""
        values = self.processor(waveform, sampling_rate=self.sampling_rate).input_values
        # Squeeze away singleton axes (presumably the size-1 batch axis the processor adds).
        return {"input_values": np.squeeze(values)}

    def __call__(self, batch):
        """Featurize every item's ``"input_values"`` and pad them into PyTorch tensors."""
        features = [self._featurize(item["input_values"]) for item in batch]
        return self.processor.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
def predict(test_dataloader, model, device):
    """Run ``model`` over every batch of ``test_dataloader`` and collect class indices.

    Args:
        test_dataloader: Iterable of dict-like batches. Each has an
            ``"input_values"`` tensor and, optionally, an ``"attention_mask"``
            tensor (present when the collator's processor returns one).
        model: Classification model whose forward output exposes ``.logits``.
        device: Device that the model and each batch are moved to.

    Returns:
        List of predicted class indices (Python ints), one per clip, in
        dataloader order.
    """
    model.to(device)
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in tqdm.tqdm(test_dataloader):
            input_values = batch["input_values"].to(device)
            # Forward the padding mask when the collator produced one, so the model
            # can ignore padded positions in a variable-length batch. Models are
            # called exactly as before when no mask is present.
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                logits = model(input_values, attention_mask=attention_mask.to(device)).logits
            else:
                logits = model(input_values).logits
            # Softmax is monotonic, so argmax over raw logits picks the same class.
            preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    return preds
def get_gender(model_name_or_path, audio_paths, device, batch_size=16, num_workers=10):
    """Predict one gender label for a set of audio clips by majority vote.

    Each clip in ``audio_paths`` is classified independently with the 2-class
    audio-classification checkpoint at ``model_name_or_path``. The most frequent
    predicted class is then mapped to ``"female"`` or ``"male"``.

    Args:
        model_name_or_path: Model id or local path accepted by ``from_pretrained``
            for both the feature extractor and the classification model.
        audio_paths: Sequence of audio file paths to classify.
        device: Device used for inference (e.g. ``"cpu"`` or ``"cuda"``).
        batch_size: Number of clips per inference batch.
        num_workers: DataLoader worker processes used to load audio.

    Returns:
        ``"female"`` or ``"male"``.

    Raises:
        ValueError: If ``audio_paths`` is empty.
    """
    from collections import Counter

    # Fail fast before loading the model. Otherwise an empty input would only fail
    # later, with an opaque "max() arg is an empty sequence".
    if len(audio_paths) == 0:
        raise ValueError("audio_paths must contain at least one audio file")

    num_labels = 2
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
    model = AutoModelForAudioClassification.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        num_labels=num_labels,
    )
    test_dataset = CustomDataset(audio_paths)
    data_collator = CollateFunc(
        processor=feature_extractor,
        padding=True,
        sampling_rate=16000,
    )
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
        shuffle=False,  # keep predictions aligned with audio_paths order
        num_workers=num_workers,
    )
    preds = predict(test_dataloader=test_dataloader, model=model, device=device)

    # NOTE(review): hard-coded mapping; assumes the checkpoint was trained with
    # 0 = female, 1 = male. Confirm against model.config.id2label.
    label_mapping = {0: "female", 1: "male"}
    counts = Counter(int(p) for p in preds)
    # Break ties toward the smaller class index so the result is deterministic.
    most_common_label = max(sorted(counts), key=counts.__getitem__)
    return label_mapping[most_common_label]