# Music_LMMs / app_queue.py
import gradio as gr
from PIL import Image
from io import BytesIO
import openai
import os
from dotenv import load_dotenv
from image_processor import ImageProcessor
from evaluation_processor import EvaluationProcessor
from zhipuai import ZhipuAI
from collections import deque
# Load environment variables
load_dotenv()
# Configure the OpenAI API key (module-level style used by the legacy SDK helper below)
openai.api_key = os.getenv("OPENAI_API_KEY")
engine = "gpt-4o-mini"
# Initialize image and evaluation processors
api_key = 'ddc85b14-bd83-4757-9bc4-8a11194da536'
image_processor = ImageProcessor(api_key)
evaluation_processor = EvaluationProcessor(api_key)
# Initialize memory with a deque (double-ended queue) to store up to 5 rounds
memory = deque(maxlen=5)
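# With maxlen=5 the deque keeps a rolling window over the dialogue: each element
# stores one round as {"prompt": ..., "response": ...}, and once five rounds are
# held, appending a sixth silently evicts the oldest, so the model only ever
# sees the last five exchanges.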
prev_image_result = None
prev_audio_result = None
prev_video_result = None
prev_image_files = None
prev_audio_file = None
prev_video_file = None
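# The prev_* variables above cache the most recent processing results and the
# files they came from, so re-submitting the same image set, audio file, or
# video file does not trigger another (potentially slow) call to the
# processors; process_input compares the incoming files against these caches
# before deciding whether to reprocess.
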
def process_input(text=None, images=None, audio=None, video=None):
    global prev_image_result, prev_audio_result, prev_video_result
    global prev_image_files, prev_audio_file, prev_video_file

    print("Starting process_input")

    # System prompt (kept in Chinese; it tells the model it is a music expert
    # and may only answer music-related questions).
    system_prompt = (
        "1.你是一个音乐专家,只能回答音乐知识..."
    )

    # Messages sent to the model, including the conversation history
    messages = [{"role": "system", "content": system_prompt}]

    # Append past rounds from memory to the message list
    for past in memory:
        messages.append({"role": "user", "content": past["prompt"]})
        messages.append({"role": "assistant", "content": past["response"]})

    prompt = ""

    # Handle text input
    if text:
        print("Processing text input")
        prompt += f"\nText input: {text}"

    result_path = None

    # Handle image input
    if images:
        if prev_image_files and set(images) == set(prev_image_files):
            print("Using previous image result")
            prompt += prev_image_result
        else:
            print("Processing images")
            prompt += process_images(images)
            prev_image_result = prompt   # cache the prompt text built for these images
            prev_image_files = images    # cache the image files
    elif prev_image_result:
        print("Using previous image result")
        prompt += prev_image_result

    # Handle audio input
    if audio:
        if prev_audio_file and audio.name == prev_audio_file.name:
            print("Using previous audio result")
            prompt += prev_audio_result
        else:
            print("Processing audio")
            result, title = process_audio(audio)
            prompt += result
            # Guard against the error case where process_audio returns (message, None)
            result_path = title.get('result_path', '') if title else None
            prev_audio_result = result   # cache the audio evaluation prompt
            prev_audio_file = audio      # cache the audio file
    elif prev_audio_result:
        print("Using previous audio result")
        prompt += prev_audio_result

    # Handle video input
    if video:
        if prev_video_file and video.name == prev_video_file.name:
            print("Using previous video result")
            prompt += prev_video_result
        else:
            print("Processing video")
            result, title = process_video(video)
            prompt += result
            # Guard against the error case where process_video returns (message, None)
            result_path = title.get('result_path', '') if title else None
            prev_video_result = result   # cache the video evaluation prompt
            prev_video_file = video      # cache the video file
    elif prev_video_result:
        print("Using previous video result")
        prompt += prev_video_result

    # Store the current round in memory (the prompt and, below, the model's reply)
    current_conversation = {"prompt": prompt, "response": ""}
    # get_zhipuai_response returns a single string, so keep the result_path that
    # was set while processing the audio/video above.
    response = get_zhipuai_response(messages, prompt)
    current_conversation["response"] = response   # record the model's reply
    memory.append(current_conversation)           # save this round to the history

    return response, result_path

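# Illustrative call (sketch only; assumes the processors and the hardcoded
# ZhipuAI key are valid, and that uploaded files expose a .name attribute as in
# the handlers above):
#   reply, page = process_input(text="What key is this piece in?")
#   print(reply)   # model answer; `page` is the evaluation result_path, or None
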
def process_images(images):
    # Re-encode every uploaded image as PNG bytes before handing it to the
    # image processor.
    image_bytes_list = []
    for image in images:
        img = Image.open(image.name)
        image_bytes = BytesIO()
        img.save(image_bytes, format="PNG")
        image_bytes.seek(0)
        image_bytes_list.append(image_bytes.getvalue())

    try:
        processed_image_result = image_processor.process_images(image_bytes_list)
        # Returned prompt is kept in Chinese; it reads roughly: "The sheet music
        # content is as follows, please answer based on the style of the piece: ..."
        return f"\n乐谱的内容如下,请你根据曲子的曲风回答问题: {processed_image_result}"
    except Exception as e:
        return f"Error processing image: {e}"

def process_audio(audio):
    audio_path = audio.name
    try:
        result, title = evaluation_processor.process_evaluation(audio_path, is_video=False)
        # Evaluation prompt, kept in Chinese. It asks the model: if a title was
        # detected, comment on the piece and its author, then rate the
        # performance on overall score (eva_all), completeness (eva_completion),
        # note accuracy (eva_note), stability (eva_stability) and tempo sync
        # (eva_tempo_sync), answering in Chinese only.
        prompt = (
            f'''如果有曲名{title},请你根据这首歌的名字作者,并且'''
            f'''1. 请你从
            "eva_all":综合得分
            "eva_completion":完整性
            "eva_note":按键
            "eva_stability":稳定性
            "eva_tempo_sync":节奏
            几个方面评价一下下面这首曲子演奏的结果, 不用提及键的英文,只使用中文,曲子为 {result}'''
        )
        return prompt, title
    except Exception as e:
        return f"Error processing audio: {e}", None

def process_video(video):
    video_path = video.name
    try:
        result, title = evaluation_processor.process_evaluation(video_path, is_video=True)
        # Same Chinese evaluation prompt as in process_audio.
        prompt = (
            f'''如果有曲名{title},请你根据这首歌的名字作者,并且'''
            f'''1.请你从
            "eva_all":综合得分
            "eva_completion":完整性
            "eva_note":按键
            "eva_stability":稳定性
            "eva_tempo_sync":节奏
            几个方面评价一下下面这首曲子演奏的结果, 不用提及键的英文,只使用中文,曲子为 {result}'''
        )
        return prompt, title
    except Exception as e:
        return f"Error processing video: {e}", None

def get_gpt_response(messages, prompt):
    # Streaming completion via the legacy (pre-1.0) openai package
    # (openai.ChatCompletion was removed in openai>=1.0). This helper is defined
    # but not called by process_input, which uses ZhipuAI instead.
    messages.append({"role": "user", "content": prompt})
    response_text = ""
    try:
        for chunk in openai.ChatCompletion.create(
            model=engine,
            messages=messages,
            temperature=0.2,
            max_tokens=4096,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            stream=True  # Enable streaming
        ):
            if 'content' in chunk['choices'][0]['delta']:
                response_text += chunk['choices'][0]['delta']['content']
                yield response_text  # Yield the accumulated response incrementally
    except Exception as e:
        yield f"Error: {e}"

def get_zhipuai_response_stream(messages, prompt):
    # Streaming ZhipuAI completion. Defined but not used by process_input,
    # which calls the non-streaming variant below.
    print("Inside get_zhipuai_response_stream")
    client = ZhipuAI(api_key="423ca4c1f712621a4a1740bb6008673b.81aM7DNo2Ssn8FPA")
    messages.append({"role": "user", "content": prompt})
    response_text = ""
    try:
        response = client.chat.completions.create(
            model="glm-4-flash",
            messages=messages,
            stream=True  # Enable streaming
        )
        print("Response received from ZhipuAI")
        for chunk in response:
            print(f"Chunk received: {chunk}")  # Log each chunk
            delta = chunk.choices[0].delta.content
            if delta:
                response_text += delta  # accumulate deltas instead of overwriting
                print(response_text)
                yield response_text  # Yield the accumulated response incrementally
    except Exception as e:
        print(f"Error in get_zhipuai_response_stream: {e}")
        yield f"Error: {e}"

def get_zhipuai_response(messages, prompt):
    # Non-streaming ZhipuAI completion; this is the call used by process_input.
    print("Inside get_zhipuai_response")  # Confirm entry into the function
    client = ZhipuAI(api_key="423ca4c1f712621a4a1740bb6008673b.81aM7DNo2Ssn8FPA")
    messages.append({"role": "user", "content": prompt})
    print("Messages prepared:", messages)  # Log the messages
    response_text = ""
    try:
        print("Calling ZhipuAI API...")  # Log before the API call
        response = client.chat.completions.create(
            model="glm-4-flash",
            messages=messages,
            stream=False  # Streaming disabled; return the full reply at once
        )
        print("Response received from ZhipuAI")  # Log response retrieval
        response_text = response.choices[0].message.content
        return response_text  # Return the entire response
    except Exception as e:
        print(f"Error in get_zhipuai_response: {e}")
        return f"Error: {e}"

# Create the Gradio interface. The Textbox placeholder (Chinese) reads roughly:
# "I am a multimodal music model; you can upload sheet music, audio, and video to analyze."
iface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="我是音乐多模态大模型,您可以上传需要分析的曲谱,音频和视频", lines=2),
        gr.File(label="Input Images", file_count="multiple", type="filepath"),
        gr.File(label="Input Audio, mp3", type="filepath"),
        gr.File(label="Input Video, mp4", type="filepath")
    ],
    outputs=[
        gr.Textbox(label="Output Text", interactive=True),  # full model reply (returned at once, not streamed)
        gr.HTML(label="Webpage")  # rendered from the result_path returned by process_input
    ],
    live=False,
)
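# The four inputs map positionally to process_input(text, images, audio, video),
# and the two outputs receive its return value (response, result_path): the
# Textbox shows the model reply and the HTML component renders result_path.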
# Launch Gradio application
iface.launch()