import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel
import audiofile

class ModelHead(nn.Module):
    r"""Classification head."""

    def __init__(self, config, num_labels):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
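
# The same two-layer head (linear -> tanh -> linear, with dropout) is reused
# for both tasks below: the age head has one output, the gender head three.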

class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Speech age and gender classifier."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(
        self,
        input_values,
    ):
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
        return hidden_states, logits_age, logits_gender
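
# forward() returns three tensors: the time-averaged wav2vec2 embedding, the
# raw age prediction (scaled by 100 further below to give an age in years),
# and softmax probabilities over the gender classes female/male/child.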

# load model from hub
device = 0 if torch.cuda.is_available() else "cpu"
model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = AgeGenderModel.from_pretrained(model_name)
model.to(device)  # keep the model on the same device as the input tensors

def process_func(x: np.ndarray, sampling_rate: int) -> list:
    r"""Predict age and gender from a raw audio signal."""
    # run through processor to normalize signal
    # always returns a batch, so we just get the first entry
    # then we put it on the device
    y = processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    # run through model
    with torch.no_grad():
        y = model(y)
        y = torch.hstack([y[1], y[2]])

    # convert to numpy
    y = y.detach().cpu().numpy()

    # convert to a list of score/label dicts
    y = [
        {"score": 100 * y[0][0], "label": "age"},
        {"score": y[0][1], "label": "female"},
        {"score": y[0][2], "label": "male"},
        {"score": y[0][3], "label": "child"},
    ]
    return y
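
# A minimal usage sketch (not executed here), assuming one second of 16 kHz
# audio (the sampling rate this wav2vec2 checkpoint was trained on):
#
#     signal = np.zeros(16000, dtype=np.float32)
#     print(process_func(signal, 16000))
#     # -> [{'score': ..., 'label': 'age'}, {'score': ..., 'label': 'female'},
#     #     {'score': ..., 'label': 'male'}, {'score': ..., 'label': 'child'}]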

def recognize(file):
    if file is None:
        raise gr.Error(
            "No audio file submitted! "
            "Please upload or record an audio file "
            "before submitting your request."
        )
    signal, sampling_rate = audiofile.read(file)
    age_gender = process_func(signal, sampling_rate)
    return age_gender

demo = gr.Blocks()

outputs = gr.Label()
title = "audEERING age and gender recognition"
description = (
    "Recognize age and gender of a microphone recording or audio file. "
    f"Demo uses the checkpoint [{model_name}](https://huggingface.co/{model_name})."
)
allow_flagging = "never"

microphone = gr.Interface(
    fn=recognize,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=outputs,
    title=title,
    description=description,
    allow_flagging=allow_flagging,
)
file = gr.Interface(
    fn=recognize,
    inputs=gr.Audio(sources="upload", type="filepath", label="Audio file"),
    outputs=outputs,
    title=title,
    description=description,
    allow_flagging=allow_flagging,
)

with demo:
    gr.TabbedInterface([microphone, file], ["Microphone", "Audio file"])

demo.queue().launch()
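
# When run locally with `python app.py`, Gradio serves the demo on
# http://localhost:7860 by default; on Hugging Face Spaces the script is
# started automatically.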