import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import time
import shutil
import argparse

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from glob import glob

from audio_models import EnglishEmotionModel


def classify(audio, model_choice, preprocess, weight):
    return en_model.predict(audio, model_choice, preprocess, weight)


def handle_feedback(audio, model_prediction, unsatisfied, true_label, savedir='./user_feedback'):
    os.makedirs(savedir, exist_ok=True)
    if unsatisfied:
        timestamp = int(time.time())
        audio_path = os.path.join(savedir, f"{timestamp}.wav")
        # Save the audio that the user gave feedback on
        if isinstance(audio, str):  # a temporary file path
            shutil.copy(audio, audio_path)
        elif isinstance(audio, tuple):  # raw audio data as (sample_rate, samples)
            sr, data = audio
            # librosa.output.write_wav was removed in librosa >= 0.8; write via soundfile instead
            sf.write(audio_path, data, sr)
        else:
            raise ValueError("Invalid audio input")
        # Save the model prediction and true label as a Python dict, pickled via torch.save
        feedback = {
            "audio_path": audio_path,
            "model_prediction": model_prediction,
            "true_label": true_label
        }
        feedback_path = os.path.join(savedir, f"{timestamp}.pkl")
        torch.save(feedback, feedback_path)
        return f"Feedback submitted: True Label = {true_label}, Model Prediction = {model_prediction}"
    return "Thank you for using our SER demo!"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--duration', type=int, default=10, help='duration of audio')
    parser.add_argument('--sr', type=int, default=16000, help='sampling rate of audio')
    parser.add_argument('--device', type=str, default='cuda', help='device index to run model')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    en_model = EnglishEmotionModel(duration=args.duration, sr=args.sr, device=torch.device(args.device))

    english_audio_paths = glob('audio_files/english/*.wav')
    english_audio_paths.sort()
    english_audio_paths = {f"English Audio {idx + 1}": path for idx, path in enumerate(english_audio_paths)}

    # Function to refresh the audio choices
    def update_audio_options():
        return gr.update(choices=list(english_audio_paths.keys()), value="English Audio 1")

    # Function to update the audio player
    def update_audio_file(audio_selection):
        return english_audio_paths[audio_selection]

    with gr.Blocks() as demo:
        # Build the tab interface
        # with gr.Tab("Demo (Built-In Audio)"):
        #     gr.Markdown("""## Automatic Emotion Recognition Demo \n
        #     This is a demo for audio emotion recognition.
        #     Note that the model is still under active development. Please feel free to report any issues. \n
        #     The Chinese model is based on HuBERT and the English model is based on Wav2Vec2.""")
        #     with gr.Row():
        #         with gr.Column():
        #             # Dropdown for selecting audio; defaults to English Audio 1
        #             audio_dropdown = gr.Dropdown(list(english_audio_paths.keys()), label="Select Audio", value="English Audio 1", interactive=True)
        #             # Audio player; plays English Audio 1 by default
        #             audio_player = gr.Audio(value=english_audio_paths["English Audio 1"], interactive=False)
        #             slider = gr.Slider(label='Context Weight', minimum=0, maximum=1, step=0.01, value=0.6)
        #         with gr.Column():
        #             # Display the emotion classification result
        #             emotion_label = gr.Label(label="Emotion Prediction")
        #             dim_label = gr.Plot(label="Emotion Dimension")
        #             transcripts = gr.Textbox(label="Transcription", type='text', lines=5, max_lines=20, placeholder="Transcription")
        #     # Button that refreshes the emotion classification result when clicked
        #     classify_button = gr.Button("Classify Emotion")
        #     audio_dropdown.change(
        #         fn=update_audio_file,
        #         inputs=audio_dropdown,
        #         outputs=audio_player
        #     )
        #     # Update the emotion classification result on button click
        #     classify_button.click(base_classify, inputs=[audio_player, slider], outputs=[emotion_label, dim_label, transcripts])

        with gr.Tab("Speech Emotion Recognition Demo"):
            gr.Markdown("""## Interactive SER Demo \n
            Please upload audio via file path or microphone. If you are recording audio via microphone, please make sure that the audio is clear. \n
            The performance could be affected by environmental noise. \n
            If you are recording in a noisy environment, please enable the noise reduction option. Note that this will lead to a slight deterioration in performance.\n""")
            with gr.Row():
                with gr.Column():
                    audio = gr.Audio(sources=['microphone', 'upload'], type='filepath')
                    text = gr.Textbox(label="Transcription", type='text', lines=5, max_lines=20, placeholder="Transcription")
                    model_choice = gr.Dropdown(choices=['中文', 'English'], label='语言 / Language', value='中文')
                    with gr.Accordion("Advanced Settings", open=False):
                        preprocess = gr.Checkbox(label='Noise Reduction (Do not tick the box unless the environment is noisy)', value=False)
                        weight_slider = gr.Slider(label='Context Weight', minimum=0, maximum=1, step=0.01, value=0.6)
                    demo_button = gr.Button("Analyze Emotion")
                with gr.Column():
                    emotion_pred = gr.Label(label="Emotion Prediction")
                    dim_pred = gr.Plot(label="Emotion Dimension")
                    with gr.Accordion("Feedback", open=False) as feedback_section:
                        gr.Markdown("### User Feedback")
                        unsatisfied_checkbox = gr.Checkbox(label="Are you unsatisfied with the result?", value=False)
                        true_label_dropdown = gr.Dropdown(
                            label="Select the correct label",
                            choices=["angry", "disgust", "fearful", "happy", "neutral", "sad", "surprised"],
                        )
                        submit_feedback_button = gr.Button("Submit Feedback")
                        feedback_result = gr.Textbox(label="Feedback Result", interactive=False)
            demo_button.click(classify, inputs=[audio, model_choice, preprocess, weight_slider], outputs=[emotion_pred, dim_pred, text])
            submit_feedback_button.click(handle_feedback, inputs=[audio, emotion_pred, unsatisfied_checkbox, true_label_dropdown], outputs=[feedback_result])

    demo.launch(share=True)
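
# To launch the demo (assuming this script is saved as, e.g., app.py):
#   python app.py --duration 10 --sr 16000 --device cuda
#
# A minimal sketch (not part of the demo itself) for inspecting the feedback
# records written by handle_feedback, assuming the default ./user_feedback directory:
#
#   import torch
#   from glob import glob
#   for pkl_path in sorted(glob('./user_feedback/*.pkl')):
#       record = torch.load(pkl_path)
#       print(record['audio_path'], record['true_label'], record['model_prediction'])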