import gradio as gr
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio
import torchaudio.transforms as T
import logging
# input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
logger = logging.getLogger("whisper-jax-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)
inputs = [
    gr.components.Audio(type="filepath", label="Add music audio file"),
    gr.components.Audio(source="microphone", type="filepath", label="Or record with the microphone"),
]
outputs = [gr.components.Textbox()]
# outputs = [gr.components.Textbox(), transcription_df]
title = "Output the tags of a (music) audio clip"
description = "An example of using MERT-v0-public for music tagging; this demo currently returns the shape of the extracted representations."
article = ""
audio_examples = [
# ["input/example-1.wav"],
# ["input/example-2.wav"],
]
# Load the model
model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
# Load the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # inference only: keep dropout disabled
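
# Optional sanity check (a sketch, not required by the app): log the processor's
# expected sampling rate and the model's hidden size once at startup, so the
# shapes reported by convert_audio below are easier to interpret.
logger.info("expected sampling rate: %s", processor.sampling_rate)
logger.info("model hidden size: %s", getattr(model.config, "hidden_size", "unknown"))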
def convert_audio(inputs, microphone):
    # Prefer the microphone recording when one is provided
    if microphone is not None:
        inputs = microphone
    waveform, sample_rate = torchaudio.load(inputs)
    resample_rate = processor.sampling_rate
    # Make sure the sample rate matches what the processor expects
    if resample_rate != sample_rate:
        logger.info(f'setting rate from {sample_rate} to {resample_rate}')
        resampler = T.Resample(sample_rate, resample_rate)
        waveform = resampler(waveform)
    waveform = waveform.view(-1,)  # make it (n_samples,)
    model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
    model_inputs.to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs, output_hidden_states=True)
    # Take a look at the output shape: there are 13 layers of representation.
    # Each layer performs differently on different downstream tasks, so the
    # layer(s) to use should be chosen empirically.
    all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
    # print(all_layer_hidden_states.shape)  # [13 layers, time_steps, 768 feature_dim]
    # logger.warning(all_layer_hidden_states.shape)
    return device + " :" + str(all_layer_hidden_states.shape)
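
# A minimal sketch (not wired into the Gradio app) of how the 13 layers of
# hidden states could be reduced to a fixed-size feature vector for a
# downstream tagging head. The pooling strategy here (mean over time, then
# mean over layers) is an assumption for illustration; a learnable weighted
# sum over layers is another common choice.
def aggregate_layer_features(all_layer_hidden_states: torch.Tensor) -> torch.Tensor:
    # [13 layers, time_steps, 768] -> [13 layers, 768]: average over time
    time_reduced = all_layer_hidden_states.mean(dim=-2)
    # [13 layers, 768] -> [768]: average over layers
    return time_reduced.mean(dim=0)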
# iface = gr.Interface(fn=convert_audio, inputs="audio", outputs="text")
# iface.launch()
audio_chunked = gr.Interface(
    fn=convert_audio,
    inputs=inputs,
    outputs=outputs,
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=audio_examples,
)
demo = gr.Blocks()
with demo:
    gr.TabbedInterface([audio_chunked], ["Audio File"])
# demo.queue(concurrency_count=1, max_size=5)
demo.launch(show_api=False)