import json
import os
from typing import Any

from openai import OpenAI

from base_handler import BaseHandler
from constant import DEFAULT_SYSTEM_PROMPT, GORILLA_TO_OPENAPI
from model_style import ModelStyle
from utils import (
    convert_to_function_call,
    convert_to_tool,
    default_decode_ast_prompting,
    default_decode_execute_prompting,
    format_execution_results_prompting,
    func_doc_language_specific_pre_processing,
    system_prompt_pre_processing_chat_model,
    convert_system_prompt_into_user_prompt,
    combine_consecutive_user_prompts,
)


class OpenAIHandler(BaseHandler):
    """Handler that queries OpenAI chat-completion models.

    Model names containing ``"FC"`` are treated as native function-calling
    (tool-call) models; the ``-FC`` suffix is stripped from the name before it
    is sent to the API.
    """

    def __init__(self, model_name, temperature) -> None:
        super().__init__(model_name, temperature)
        self.model_style = ModelStyle.OpenAI
        # The key comes from the environment; os.getenv returns None if unset.
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def decode_ast(self, result, language="Python"):
        """Decode a model result into a list of ``{function_name: params}`` dicts.

        For prompting (non-``FC``) models the raw result is handed to
        ``default_decode_ast_prompting``. For ``FC`` models, ``result`` is
        expected to be the list built by ``_parse_query_response_FC``: single-key
        dicts mapping a function name to its JSON-encoded argument string.
        """
        if "FC" not in self.model_name:
            return default_decode_ast_prompting(result, language)

        decoded_output = []
        for invoked_function in result:
            # Each entry holds one call; its only key is the function name.
            name = list(invoked_function.keys())[0]
            params = json.loads(invoked_function[name])
            decoded_output.append({name: params})
        return decoded_output

    def decode_execute(self, result):
        """Decode a model result for execution-based evaluation.

        Delegates to ``default_decode_execute_prompting`` for prompting models
        and to ``convert_to_function_call`` for ``FC`` models.
        """
        if "FC" not in self.model_name:
            return default_decode_execute_prompting(result)
        return convert_to_function_call(result)

    #### FC methods ####

    def _query_FC(self, inference_data: dict):
        """Send the accumulated messages (and tools, if any) to the chat API.

        Also stores a loggable snapshot of the request in
        ``inference_data["inference_input_log"]``.

        Returns:
            The raw object returned by ``client.chat.completions.create``.
        """
        message: list[dict] = inference_data["message"]
        tools = inference_data["tools"]
        inference_data["inference_input_log"] = {"message": repr(message), "tools": tools}

        request_kwargs = {
            "messages": message,
            "model": self.model_name.replace("-FC", ""),
            "temperature": self.temperature,
        }
        # Leave the `tools` parameter out entirely when there are no tools,
        # rather than sending an empty list.
        if tools:
            request_kwargs["tools"] = tools

        return self.client.chat.completions.create(**request_kwargs)

    def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
        """Start a fresh, empty message history for this test entry."""
        inference_data["message"] = []
        return inference_data

    def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
        """Convert the test entry's function docs into OpenAI tool definitions.

        The test category is the entry id with its last ``_``-separated segment
        removed; it selects the language-specific pre-processing applied before
        conversion. The result is stored in ``inference_data["tools"]``.
        """
        functions: list = test_entry["function"]
        test_category: str = test_entry["id"].rsplit("_", 1)[0]

        functions = func_doc_language_specific_pre_processing(functions, test_category)
        tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)

        inference_data["tools"] = tools
        return inference_data

    def _parse_query_response_FC(self, api_response: Any) -> dict:
        """Extract tool calls (or plain text) and token usage from a response.

        Returns a dict with:
            model_responses: a list of ``{function_name: json_arguments}`` dicts
                when the model emitted tool calls, otherwise the message's text
                content.
            model_responses_message_for_chat_history: the raw assistant message
                object, later appended to the history by
                ``_add_assistant_message_FC``.
            tool_call_ids: ids aligned with ``model_responses`` (empty for a
                text reply).
            input_token / output_token: prompt and completion token counts.
        """
        assistant_message = api_response.choices[0].message
        tool_calls = assistant_message.tool_calls

        # `tool_calls` is None when the model returned no tool calls (e.g. a
        # plain-text answer). Branch on that explicitly instead of a bare
        # `except:`, which would also hide unrelated errors and even
        # KeyboardInterrupt/SystemExit.
        if tool_calls is not None:
            model_responses = [
                {tool_call.function.name: tool_call.function.arguments}
                for tool_call in tool_calls
            ]
            tool_call_ids = [tool_call.id for tool_call in tool_calls]
        else:
            model_responses = assistant_message.content
            tool_call_ids = []

        return {
            "model_responses": model_responses,
            "model_responses_message_for_chat_history": assistant_message,
            "tool_call_ids": tool_call_ids,
            "input_token": api_response.usage.prompt_tokens,
            "output_token": api_response.usage.completion_tokens,
        }

    def add_first_turn_message_FC(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        """Append the first-turn messages to the history."""
        inference_data["message"].extend(first_turn_message)
        return inference_data

    def _add_next_turn_user_message_FC(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        """Append a subsequent turn's user messages to the history."""
        inference_data["message"].extend(user_message)
        return inference_data

    def _add_assistant_message_FC(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        """Append the model's raw assistant message to the history.

        The appended item is the response's message object as produced by
        ``_parse_query_response_FC``, not a plain dict.
        """
        inference_data["message"].append(
            model_response_data["model_responses_message_for_chat_history"]
        )
        return inference_data

    def _add_execution_results_FC(
        self,
        inference_data: dict,
        execution_results: list[str],
        model_response_data: dict,
    ) -> dict:
        """Append one ``"tool"`` message per execution result to the history.

        Results are paired positionally with ``model_response_data["tool_call_ids"]``.
        """
        # zip() stops at the shorter sequence, so any surplus results or ids are
        # dropped without error.
        for execution_result, tool_call_id in zip(
            execution_results, model_response_data["tool_call_ids"]
        ):
            tool_message = {
                "role": "tool",
                "content": execution_result,
                "tool_call_id": tool_call_id,
            }
            inference_data["message"].append(tool_message)
        return inference_data