"""Gradio app: a multimodal music assistant.

Accepts free text, score images, audio (mp3) and video (mp4), runs them
through project-local processors, and asks an LLM (ZhipuAI glm-4-flash)
to answer / evaluate. Keeps a rolling 5-round conversation memory and
caches the last processed file of each modality so re-submitting the same
file does not trigger a re-analysis.
"""

import gradio as gr
from PIL import Image
from io import BytesIO
import openai
import os
from dotenv import load_dotenv
from image_processor import ImageProcessor
from evaluation_processor import EvaluationProcessor
from zhipuai import ZhipuAI
from collections import deque

# Load environment variables from a .env file, if present.
load_dotenv()

# OpenAI configuration (used by the legacy streaming helper below;
# NOTE(review): `openai.ChatCompletion.create` requires openai < 1.0).
openai.api_key = os.getenv("OPENAI_API_KEY")
engine = "gpt-4o-mini"

# Credentials for the image / evaluation processors.
# SECURITY: the hard-coded key is kept only as a backward-compatible
# fallback — prefer setting PROCESSOR_API_KEY in the environment.
api_key = os.getenv("PROCESSOR_API_KEY", 'ddc85b14-bd83-4757-9bc4-8a11194da536')
image_processor = ImageProcessor(api_key)
evaluation_processor = EvaluationProcessor(api_key)

# Conversation memory: the 5 most recent {"prompt", "response"} rounds.
memory = deque(maxlen=5)

# Per-modality caches of the last analysis result and the file(s) that
# produced it, so identical re-submissions reuse the cached text.
prev_image_result = None
prev_audio_result = None
prev_video_result = None
prev_image_files = None
prev_audio_file = None
prev_video_file = None


def process_input(text=None, images=None, audio=None, video=None):
    """Build a prompt from all provided inputs and query the LLM.

    Args:
        text: Optional free-text question from the user.
        images: Optional list of score-image files (Gradio file objects).
        audio: Optional audio file (Gradio file object with ``.name``).
        video: Optional video file (Gradio file object with ``.name``).

    Returns:
        Tuple ``(response_text, result_path)`` where ``result_path`` comes
        from the audio/video evaluation metadata (or ``None``).
    """
    global prev_image_result, prev_audio_result, prev_video_result
    global prev_image_files, prev_audio_file, prev_video_file

    print("Starting process_input")
    system_prompt = (
        "1.你是一个音乐专家,只能回答音乐知识..."
    )

    # Seed messages with the system prompt plus stored conversation history.
    messages = [{"role": "system", "content": system_prompt}]
    for past in memory:
        messages.append({"role": "user", "content": past["prompt"]})
        messages.append({"role": "assistant", "content": past["response"]})

    prompt = ""

    # Free-text input.
    if text:
        print("Processing text input")
        prompt += f"\nText input: {text}"

    result_path = None

    # Score images: reuse the cached analysis if the same set was submitted.
    if images:
        if prev_image_files and set(images) == set(prev_image_files):
            print("Using previous image result")
            prompt += prev_image_result
        else:
            print("Processing images")
            # BUGFIX: cache only the image-analysis fragment, not the whole
            # accumulated prompt (the original stored `prompt`, which could
            # include the text input and would corrupt later reuse).
            image_result = process_images(images)
            prompt += image_result
            prev_image_result = image_result
            prev_image_files = images
    elif prev_image_result:
        print("Using previous image result")
        prompt += prev_image_result

    # Audio: same caching scheme, keyed on the file name.
    if audio:
        if prev_audio_file and audio.name == prev_audio_file.name:
            print("Using previous audio result")
            prompt += prev_audio_result
        else:
            print("Processing audio")
            result, title = process_audio(audio)
            prompt += result
            # BUGFIX: `title` is None when processing failed; the original
            # unconditional `title.get(...)` raised AttributeError.
            if title:
                result_path = title.get('result_path', '')
            prev_audio_result = result
            prev_audio_file = audio
    elif prev_audio_result:
        print("Using previous audio result")
        prompt += prev_audio_result

    # Video: same caching scheme, keyed on the file name.
    if video:
        if prev_video_file and video.name == prev_video_file.name:
            print("Using previous video result")
            prompt += prev_video_result
        else:
            print("Processing video")
            result, title = process_video(video)
            prompt += result
            if title:  # see audio BUGFIX above
                result_path = title.get('result_path', '')
            prev_video_result = result
            prev_video_file = video
    elif prev_video_result:
        print("Using previous video result")
        prompt += prev_video_result

    # Record this round (question + model answer) into memory.
    current_conversation = {"prompt": prompt, "response": ""}
    # BUGFIX: get_zhipuai_response returns a single string; the original
    # `response, result_path = ...` tuple-unpack raised ValueError and
    # clobbered the result_path computed from audio/video above.
    response = get_zhipuai_response(messages, prompt)
    current_conversation["response"] = response
    memory.append(current_conversation)
    return response, result_path


def process_images(images):
    """Convert score images to PNG bytes and run them through the processor.

    Returns a Chinese-language prompt fragment containing the analysis, or
    an error message string on failure.
    """
    image_bytes_list = []
    for image in images:
        img = Image.open(image.name)
        image_bytes = BytesIO()
        img.save(image_bytes, format="PNG")
        image_bytes.seek(0)
        image_bytes_list.append(image_bytes.getvalue())
    try:
        processed_image_result = image_processor.process_images(image_bytes_list)
        return f"\n乐谱的内容如下,请你根据曲子的曲风回答问题: {processed_image_result}"
    except Exception as e:
        return f"Error processing image: {e}"


def process_audio(audio):
    """Evaluate an audio performance.

    Returns:
        Tuple ``(prompt_fragment, title_dict)``; on failure the fragment is
        an error message and ``title_dict`` is ``None``.
    """
    audio_path = audio.name
    try:
        result, title = evaluation_processor.process_evaluation(audio_path, is_video=False)
        prompt = (
            f'''如果有曲名{title},请你根据这首歌的名字作者,并且'''
            f'''1. 请你从 "eva_all":综合得分 "eva_completion":完整性 "eva_note":按键 "eva_stability":稳定性 "eva_tempo_sync":节奏 几个方面评价一下下面这首曲子演奏的结果, 不用提及键的英文,只使用中文,曲子为 {result}'''
        )
        return prompt, title
    except Exception as e:
        return f"Error processing audio: {e}", None


def process_video(video):
    """Evaluate a video performance; same contract as :func:`process_audio`."""
    video_path = video.name
    try:
        result, title = evaluation_processor.process_evaluation(video_path, is_video=True)
        prompt = (
            f'''如果有曲名{title},请你根据这首歌的名字作者,并且'''
            f'''1.请你从 "eva_all":综合得分 "eva_completion":完整性 "eva_note":按键 "eva_stability":稳定性 "eva_tempo_sync":节奏 几个方面评价一下下面这首曲子演奏的结果, 不用提及键的英文,只使用中文,曲子为 {result}'''
        )
        return prompt, title
    except Exception as e:
        return f"Error processing video: {e}", None


def get_gpt_response(messages, prompt):
    """Stream a response from OpenAI, yielding the accumulated text.

    NOTE(review): uses the legacy ``openai.ChatCompletion`` API (openai < 1.0);
    currently unused by the Gradio flow, which calls ZhipuAI instead.
    """
    messages.append({"role": "user", "content": prompt})
    response_text = ""
    try:
        for chunk in openai.ChatCompletion.create(
            model=engine,
            messages=messages,
            temperature=0.2,
            max_tokens=4096,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            stream=True  # Enable streaming
        ):
            if 'content' in chunk['choices'][0]['delta']:
                response_text += chunk['choices'][0]['delta']['content']
                yield response_text  # Yield the accumulated response so far
    except Exception as e:
        yield f"Error: {e}"


def get_zhipuai_response_stream(messages, prompt):
    """Stream a response from ZhipuAI, yielding the accumulated text."""
    print("Inside get_zhipuai_response")
    # SECURITY: hard-coded key kept only as a backward-compatible fallback.
    client = ZhipuAI(api_key=os.getenv(
        "ZHIPUAI_API_KEY", "423ca4c1f712621a4a1740bb6008673b.81aM7DNo2Ssn8FPA"))
    messages.append({"role": "user", "content": prompt})
    response_text = ""
    try:
        response = client.chat.completions.create(
            model="glm-4-flash",
            messages=messages,
            stream=True  # Enable streaming
        )
        print("Response received from ZhipuAI")
        print(response)
        for chunk in response:
            print(f"Chunk received: {chunk}")  # Log each chunk
            delta = chunk.choices[0].delta.content
            # BUGFIX: accumulate deltas; the original overwrote response_text
            # with each chunk, so every yield held only the latest fragment.
            if delta:
                response_text += delta
            print(response_text)
            yield response_text  # Yield response incrementally
    except Exception as e:
        print(f"Error in get_zhipuai_response_stream: {e}")
        yield f"Error: {e}"


def get_zhipuai_response(messages, prompt):
    """Query ZhipuAI without streaming and return the full response text.

    Returns an ``"Error: ..."`` string instead of raising on failure.
    """
    print("Inside get_zhipuai_response")  # Confirming entry into the function
    # SECURITY: hard-coded key kept only as a backward-compatible fallback.
    client = ZhipuAI(api_key=os.getenv(
        "ZHIPUAI_API_KEY", "423ca4c1f712621a4a1740bb6008673b.81aM7DNo2Ssn8FPA"))
    messages.append({"role": "user", "content": prompt})
    print("Messages prepared:", messages)  # Log messages
    try:
        print("Calling ZhipuAI API...")  # Log before API call
        response = client.chat.completions.create(
            model="glm-4-flash",
            messages=messages,
            stream=False  # Disable streaming for this call
        )
        print("Response received from ZhipuAI")  # Log response retrieval
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error in get_zhipuai_response: {e}")  # Informative error message
        return f"Error: {e}"


# Create the Gradio interface.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="我是音乐多模态大模型,您可以上传需要分析的曲谱,音频和视频", lines=2),
        gr.File(label="Input Images", file_count="multiple", type="filepath"),
        gr.File(label="Input Audio, mp3", type="filepath"),
        gr.File(label="Input Video, mp4", type="filepath")
    ],
    outputs=[
        gr.Textbox(label="Output Text", interactive=True),
        gr.HTML(label="Webpage")
    ],
    live=False,
)

# Launch the Gradio application.
iface.launch()