Explanation on using the ONNX model
- opened
Hi! Thanks for the lovely model :)
I'm trying to run the ONNX version (I tried both the 6-layer and the 24-layer versions), but I get the same error with both. (I am able to run the PyTorch-based prediction that you provided.)
# Load the model
# Paths to the 6-layer and 24-layer ONNX files; only the 24-layer path is
# passed to the inference session below.
six_l_onnx_path = "./ONNX-6 Layers/model.onnx"
twenty_four_l_onnx_path = "./ONNX-24 Layers/model.onnx"
ort_session = ort.InferenceSession(twenty_four_l_onnx_path)
# Processor loaded by model name; process_func_onnx uses it to prepare the
# input signal before it is handed to the ONNX session.
model_name = 'audeering/wav2vec2-large-robust-24-ft-age-gender'
processor = Wav2Vec2Processor.from_pretrained(model_name)
def process_func_onnx(x: np.ndarray, sampling_rate: int) -> np.ndarray:
    """Run the processor on ``x`` and feed the result to the ONNX session.

    Args:
        x: Signal samples.
        sampling_rate: Sampling rate forwarded to the processor.

    Returns:
        The session outputs stacked horizontally with ``np.hstack``.
    """
    features = processor(x, sampling_rate=sampling_rate)
    # Take the first entry of 'input_values' and shape it as a single row.
    batch = features['input_values'][0].reshape(1, -1)
    input_name = ort_session.get_inputs()[0].name
    # Name of the session's first output, handed to run() as a plain str.
    # This is the call that the traceback quoted after this snippet
    # complains about.
    output_name = ort_session.get_outputs()[0].name
    ort_outputs = ort_session.run(output_name, {input_name: batch})
    return np.hstack(ort_outputs)
def test_stuff():
    """Run ``process_func_onnx`` on every ``.wav`` file in the sample folder.

    Returns:
        dict mapping each matched file path to the array returned by
        ``process_func_onnx`` for that file.
    """
    # Function-local imports, executed once instead of on every iteration.
    import glob

    from pydub import AudioSegment

    folder = './data/audio_samples'
    results = {}
    for file in glob.glob(f'{folder}/*.wav'):
        # Resample to 16 kHz and downmix to mono. pydub interleaves the
        # channels of multi-channel audio in get_array_of_samples(), so a
        # stereo file would otherwise be passed on as one scrambled signal.
        segment = AudioSegment.from_wav(file)
        segment = segment.set_frame_rate(16000).set_channels(1)
        signal = np.array(segment.get_array_of_samples(), dtype=np.float32)
        results[file] = process_func_onnx(signal, segment.frame_rate)
    return results
The error I'm seeing from onnx is:
TypeError: run(): incompatible function arguments. The following argument types are supported:
1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]
Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x10a178230>, 'hidden_states', {'signal': array([[ 0.6791527 , -0.40761387, 1.4003403 , ..., -1.1661903 ,
-0.94601196, -0.813905 ]], dtype=float32)}, None
Any ideas?
Hi Yovelcohen,
thanks for your interest in the model.
`output_name` must be a list, but anyway, you probably don't want the `hidden_states`
but the logits for age and gender:
import numpy as np
import onnxruntime as ort
# Load the 6-layer ONNX model into an inference session
six_l_onnx_path = "./ONNX-6 Layers/model.onnx"
ort_session = ort.InferenceSession(six_l_onnx_path)
def process_func_onnx(x: np.ndarray, sampling_rate: int) -> np.ndarray:
    """Run the ONNX session on a signal and return the age/gender logits.

    Args:
        x: Signal samples; reshaped to a single row before inference.
        sampling_rate: Not used by this function.

    Returns:
        The ``logits_age`` and ``logits_gender`` outputs stacked
        horizontally with ``np.hstack``.
    """
    batch = x.reshape(1, -1)
    input_name = ort_session.get_inputs()[0].name
    # run() takes a list of output names; request the two logit outputs by
    # name rather than the session's first output.
    output_name = ['logits_age', 'logits_gender']
    ort_outputs = ort_session.run(output_name, {input_name: batch})
    return np.hstack(ort_outputs)
def test_stuff():
    """Run ``process_func_onnx`` on every ``.wav`` file in the sample folder.

    Returns:
        dict mapping each matched file path to the array returned by
        ``process_func_onnx`` for that file.
    """
    # Function-local imports, executed once instead of on every iteration.
    import glob

    from pydub import AudioSegment

    folder = './data/audio_samples'
    results = {}
    for file in glob.glob(f'{folder}/*.wav'):
        # Resample to 16 kHz and downmix to mono. pydub interleaves the
        # channels of multi-channel audio in get_array_of_samples(), so a
        # stereo file would otherwise be passed on as one scrambled signal.
        segment = AudioSegment.from_wav(file)
        segment = segment.set_frame_rate(16000).set_channels(1)
        signal = np.array(segment.get_array_of_samples(), dtype=np.float32)
        results[file] = process_func_onnx(signal, segment.frame_rate)
    return results
if __name__ == "__main__":
    # NOTE(review): the guard has no body in the post as captured; running
    # test_stuff() and printing its result is the presumed intent - confirm.
    print(test_stuff())
More information on how to interpret the outputs is found here: https://github.com/audeering/w2v2-age-gender-how-to/blob/master/notebook.ipynb
Note: As shown in the code, the `Wav2Vec2Processor`
is also not required when using the ONNX model.
Hope this helps!
changed discussion status to