# Music multimodal assistant — Gradio app.
# (Removed Hugging Face Spaces page artifacts that leaked into this file:
#  "Spaces: / Running / Running".)
import gradio as gr
from PIL import Image
from io import BytesIO
import openai
import os
from dotenv import load_dotenv
from image_processor import ImageProcessor
from evaluation_processor import EvaluationProcessor
from zhipuai import ZhipuAI
from collections import deque

# Load environment variables from a local .env file, if present.
load_dotenv()

# Initialize OpenAI client (legacy pre-1.0 module-level configuration).
openai.api_key = os.getenv("OPENAI_API_KEY")
engine = "gpt-4o-mini"

# Initialize image and evaluation processors.
# SECURITY: this API key was hard-coded in source. Prefer the
# PROCESSOR_API_KEY environment variable; the fallback keeps existing
# deployments working, but the leaked key should be rotated.
api_key = os.getenv("PROCESSOR_API_KEY", 'ddc85b14-bd83-4757-9bc4-8a11194da536')
image_processor = ImageProcessor(api_key)
evaluation_processor = EvaluationProcessor(api_key)

# Conversation memory: a deque keeps only the 5 most recent turns.
memory = deque(maxlen=5)

# Caches of the most recent multimodal processing results so that
# re-submitting the same file skips reprocessing.
prev_image_result = None
prev_audio_result = None
prev_video_result = None
prev_image_files = None
prev_audio_file = None
prev_video_file = None
def process_input(text=None, images=None, audio=None, video=None):
    """Assemble a prompt from the provided modalities, query the LLM, and
    return ``(response_text, result_path)``.

    Image/audio/video processing results are cached in module globals so
    that resubmitting the same file reuses the previous result.
    """
    global prev_image_result, prev_audio_result, prev_video_result
    global prev_image_files, prev_audio_file, prev_video_file
    print("Starting process_input")
    system_prompt = (
        "1.你是一个音乐专家,只能回答音乐知识..."
    )
    # History-aware message list: system prompt, then past user/assistant turns.
    messages = [{"role": "system", "content": system_prompt}]
    for past in memory:
        messages.append({"role": "user", "content": past["prompt"]})
        messages.append({"role": "assistant", "content": past["response"]})

    prompt = ""
    # Text input
    if text:
        print("Processing text input")
        prompt += f"\nText input: {text}"

    result_path = None

    # Image input
    if images:
        if prev_image_files and set(images) == set(prev_image_files):
            print("Using previous image result")
            prompt += prev_image_result
        else:
            print("Processing images")
            image_result = process_images(images)
            prompt += image_result
            # BUGFIX: cache only the image-processing result. The original
            # cached the whole accumulated prompt (including the text input),
            # polluting later turns that reused the cache.
            prev_image_result = image_result
            prev_image_files = images
    elif prev_image_result:
        print("Using previous image result")
        prompt += prev_image_result

    # Audio input
    # NOTE(review): with gr.File(type="filepath") Gradio passes a str path,
    # which has no .name attribute — confirm against the Gradio version used.
    if audio:
        if prev_audio_file and audio.name == prev_audio_file.name:
            print("Using previous audio result")
            prompt += prev_audio_result
        else:
            print("Processing audio")
            result, title = process_audio(audio)
            prompt += result
            # ROBUSTNESS: process_audio returns title=None on failure.
            result_path = title.get('result_path', '') if title else ''
            prev_audio_result = result
            prev_audio_file = audio
    elif prev_audio_result:
        print("Using previous audio result")
        prompt += prev_audio_result

    # Video input
    if video:
        if prev_video_file and video.name == prev_video_file.name:
            print("Using previous video result")
            prompt += prev_video_result
        else:
            print("Processing video")
            result, title = process_video(video)
            prompt += result
            # ROBUSTNESS: process_video returns title=None on failure.
            result_path = title.get('result_path', '') if title else ''
            prev_video_result = result
            prev_video_file = video
    elif prev_video_result:
        print("Using previous video result")
        prompt += prev_video_result

    # Record this turn (question + model answer) in the history.
    current_conversation = {"prompt": prompt, "response": ""}
    # BUGFIX: get_zhipuai_response returns a single string; the original
    # two-target unpacking either raised ValueError or split a 2-char string,
    # and it clobbered the result_path computed from audio/video above.
    response = get_zhipuai_response(messages, prompt)
    current_conversation["response"] = response
    memory.append(current_conversation)
    return response, result_path
def process_images(images):
    """Re-encode each uploaded image as PNG bytes and hand the batch to the
    ImageProcessor; return a Chinese prompt fragment with the sheet-music
    content, or an error string."""
    png_payloads = []
    for item in images:
        buf = BytesIO()
        Image.open(item.name).save(buf, format="PNG")
        png_payloads.append(buf.getvalue())
    try:
        processed = image_processor.process_images(png_payloads)
        return f"\n乐谱的内容如下,请你根据曲子的曲风回答问题: {processed}"
    except Exception as e:
        return f"Error processing image: {e}"
def process_audio(audio):
    """Run the evaluation pipeline on an uploaded audio file and build a
    Chinese review prompt from the scores.

    Returns (prompt, title) on success, or (error_string, None) on failure.
    """
    try:
        eval_result, song_title = evaluation_processor.process_evaluation(
            audio.name, is_video=False
        )
        review_prompt = f'''如果有曲名{song_title},请你根据这首歌的名字作者,并且1. 请你从
"eva_all":综合得分
"eva_completion":完整性
"eva_note":按键
"eva_stability":稳定性
"eva_tempo_sync":节奏
几个方面评价一下下面这首曲子演奏的结果, 不用提及键的英文,只使用中文,曲子为 {eval_result}'''
        return review_prompt, song_title
    except Exception as e:
        return f"Error processing audio: {e}", None
def process_video(video):
    """Run the evaluation pipeline on an uploaded video file and build a
    Chinese review prompt from the scores.

    Returns (prompt, title) on success, or (error_string, None) on failure.
    """
    try:
        eval_result, song_title = evaluation_processor.process_evaluation(
            video.name, is_video=True
        )
        review_prompt = f'''如果有曲名{song_title},请你根据这首歌的名字作者,并且1.请你从
"eva_all":综合得分
"eva_completion":完整性
"eva_note":按键
"eva_stability":稳定性
"eva_tempo_sync":节奏
几个方面评价一下下面这首曲子演奏的结果, 不用提及键的英文,只使用中文,曲子为 {eval_result}'''
        return review_prompt, song_title
    except Exception as e:
        return f"Error processing video: {e}", None
def get_gpt_response(messages, prompt):
    """Stream a chat completion from OpenAI (legacy pre-1.0 SDK), yielding the
    accumulated response text after each content delta.

    NOTE: mutates `messages` in place by appending the new user turn.
    """
    messages.append({"role": "user", "content": prompt})
    accumulated = ""
    try:
        stream = openai.ChatCompletion.create(
            model=engine,
            messages=messages,
            temperature=0.2,
            max_tokens=4096,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            stream=True,  # stream deltas instead of one final message
        )
        for part in stream:
            delta = part['choices'][0]['delta']
            if 'content' in delta:
                accumulated += delta['content']
                yield accumulated  # incremental, cumulative text
    except Exception as e:
        yield f"Error: {e}"
def get_zhipuai_response_stream(messages, prompt):
    """Stream a chat completion from ZhipuAI (glm-4-flash), yielding the
    accumulated response text after each chunk.

    NOTE: mutates `messages` in place by appending the new user turn.
    """
    print("Inside get_zhipuai_response")
    # SECURITY: API key is hard-coded in source; move it to an environment
    # variable and rotate the leaked credential.
    client = ZhipuAI(api_key="423ca4c1f712621a4a1740bb6008673b.81aM7DNo2Ssn8FPA")
    messages.append({"role": "user", "content": prompt})
    response_text = ""
    try:
        response = client.chat.completions.create(
            model="glm-4-flash",
            messages=messages,
            stream=True  # Enable streaming
        )
        print("Response received from ZhipuAI")
        for chunk in response:
            print(f"Chunk received: {chunk}")  # Log each chunk
            delta = chunk.choices[0].delta.content
            # BUGFIX: accumulate deltas — the original overwrote response_text
            # with each chunk, so callers only ever saw the latest fragment.
            # delta may be None on role/finish chunks; skip those.
            if delta:
                response_text += delta
                yield response_text  # Yield the cumulative text incrementally
    except Exception as e:
        print(f"Error in get_zhipuai_response_stream: {e}")
        yield f"Error: {e}"
def get_zhipuai_response(messages, prompt):
    """Send the conversation plus the new user prompt to ZhipuAI (glm-4-flash)
    and return the full reply text, or an ``"Error: ..."`` string on failure.

    NOTE: mutates `messages` in place by appending the new user turn.
    """
    print("Inside get_zhipuai_response")
    # SECURITY: API key is hard-coded in source; move it to an environment
    # variable and rotate the leaked credential.
    client = ZhipuAI(api_key="423ca4c1f712621a4a1740bb6008673b.81aM7DNo2Ssn8FPA")
    messages.append({"role": "user", "content": prompt})
    print("Messages prepared:", messages)
    try:
        print("Calling ZhipuAI API...")
        response = client.chat.completions.create(
            model="glm-4-flash",
            messages=messages,
            stream=False  # single complete response, not a stream
        )
        print("Response received from ZhipuAI")
        # IDIOM: return the content directly (the unused response_text
        # initializer was removed).
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error in get_zhipuai_response: {e}")
        return f"Error: {e}"
# Create the Gradio interface: one text box plus file inputs for sheet-music
# images, audio, and video; outputs are the model reply and an HTML panel.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="我是音乐多模态大模型,您可以上传需要分析的曲谱,音频和视频", lines=2),
        gr.File(label="Input Images", file_count="multiple", type="filepath"),
        gr.File(label="Input Audio, mp3", type="filepath"),
        gr.File(label="Input Video, mp4", type="filepath")
    ],
    outputs=[
        gr.Textbox(label="Output Text", interactive=True),
        gr.HTML(label="Webpage")
    ],
    live=False,
)

# IDIOM: guard the launch so importing this module (e.g. from tests or other
# scripts) does not start the web server.
if __name__ == "__main__":
    iface.launch()