Spaces:
Running
Running
import json | |
import os | |
from base_handler import BaseHandler | |
from constant import DEFAULT_SYSTEM_PROMPT, GORILLA_TO_OPENAPI | |
from model_style import ModelStyle | |
from utils import ( | |
convert_to_function_call, | |
convert_to_tool, | |
default_decode_ast_prompting, | |
default_decode_execute_prompting, | |
format_execution_results_prompting, | |
func_doc_language_specific_pre_processing, | |
system_prompt_pre_processing_chat_model, | |
convert_system_prompt_into_user_prompt, | |
combine_consecutive_user_prompts, | |
) | |
from openai import OpenAI | |
class OpenAIHandler(BaseHandler): | |
def __init__(self, model_name, temperature) -> None:
    """Create a handler that talks to the OpenAI API.

    Args:
        model_name: Model identifier. Names containing ``"FC"`` select the
            function-calling code paths elsewhere in this class.
        temperature: Sampling temperature, forwarded to the base handler.
    """
    super().__init__(model_name, temperature)
    # Assigned after the base initializer runs, so this value is the one kept.
    self.model_style = ModelStyle.OpenAI
    # os.getenv yields None when the variable is unset; it is passed through as-is.
    openai_api_key = os.getenv("OPENAI_API_KEY")
    self.client = OpenAI(api_key=openai_api_key)
def decode_ast(self, result, language="Python"):
    """Decode model output into a list of ``{function_name: arguments}`` dicts.

    Prompting models (no ``"FC"`` in ``self.model_name``) are handled by
    ``default_decode_ast_prompting``, which also receives ``language``.

    For function-calling models, ``result`` is iterated as a sequence of
    dicts. The first key of each is taken as the function name and its value
    is parsed with ``json.loads``. This matches the
    ``{name: arguments}`` entries built by ``_parse_query_response_FC``.

    Raises:
        IndexError: if an entry is an empty dict.
        json.JSONDecodeError: if an argument string is not valid JSON.
    """
    if "FC" not in self.model_name:
        return default_decode_ast_prompting(result, language)

    decoded_calls = []
    for call in result:
        # Only the first key is consulted; entries normally hold exactly one call.
        func_name = list(call)[0]
        decoded_calls.append({func_name: json.loads(call[func_name])})
    return decoded_calls
def decode_execute(self, result):
    """Convert model output into the form used for execution.

    Function-calling models (``"FC"`` in ``self.model_name``) go through
    ``convert_to_function_call``. All other models go through
    ``default_decode_execute_prompting``.
    """
    uses_function_calling = "FC" in self.model_name
    if uses_function_calling:
        return convert_to_function_call(result)
    return default_decode_execute_prompting(result)
#### FC methods #### | |
def _query_FC(self, inference_data: dict): | |
message: list[dict] = inference_data["message"] | |
tools = inference_data["tools"] | |
inference_data["inference_input_log"] = {"message": repr(message), "tools": tools} | |
if len(tools) > 0: | |
api_response = self.client.chat.completions.create( | |
messages=message, | |
model=self.model_name.replace("-FC", ""), | |
temperature=self.temperature, | |
tools=tools, | |
) | |
else: | |
api_response = self.client.chat.completions.create( | |
messages=message, | |
model=self.model_name.replace("-FC", ""), | |
temperature=self.temperature, | |
) | |
return api_response | |
def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict: | |
inference_data["message"] = [] | |
return inference_data | |
def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
    """Convert a test entry's function docs into a tool payload.

    The test category is everything before the last ``"_"`` in
    ``test_entry["id"]``. It is passed, together with
    ``test_entry["function"]``, to ``func_doc_language_specific_pre_processing``.
    The result is then converted with ``convert_to_tool`` and stored under
    ``inference_data["tools"]``. Returns ``inference_data``.
    """
    raw_functions: list = test_entry["function"]
    category: str = test_entry["id"].rsplit("_", 1)[0]
    processed = func_doc_language_specific_pre_processing(raw_functions, category)
    inference_data["tools"] = convert_to_tool(processed, GORILLA_TO_OPENAPI, self.model_style)
    return inference_data
def _parse_query_response_FC(self, api_response: any) -> dict: | |
try: | |
model_responses = [ | |
{func_call.function.name: func_call.function.arguments} | |
for func_call in api_response.choices[0].message.tool_calls | |
] | |
tool_call_ids = [ | |
func_call.id for func_call in api_response.choices[0].message.tool_calls | |
] | |
except: | |
model_responses = api_response.choices[0].message.content | |
tool_call_ids = [] | |
model_responses_message_for_chat_history = api_response.choices[0].message | |
return { | |
"model_responses": model_responses, | |
"model_responses_message_for_chat_history": model_responses_message_for_chat_history, | |
"tool_call_ids": tool_call_ids, | |
"input_token": api_response.usage.prompt_tokens, | |
"output_token": api_response.usage.completion_tokens, | |
} | |
def add_first_turn_message_FC(
    self, inference_data: dict, first_turn_message: list[dict]
) -> dict:
    """Append the opening turn's messages to the chat history.

    ``inference_data["message"]`` is extended in place (the list object is
    kept), and ``inference_data`` is returned.
    """
    inference_data["message"] += first_turn_message
    return inference_data
def _add_next_turn_user_message_FC( | |
self, inference_data: dict, user_message: list[dict] | |
) -> dict: | |
inference_data["message"].extend(user_message) | |
return inference_data | |
def _add_assistant_message_FC( | |
self, inference_data: dict, model_response_data: dict | |
) -> dict: | |
inference_data["message"].append( | |
model_response_data["model_responses_message_for_chat_history"] | |
) | |
return inference_data | |
def _add_execution_results_FC( | |
self, | |
inference_data: dict, | |
execution_results: list[str], | |
model_response_data: dict, | |
) -> dict: | |
# Add the execution results to the current round result, one at a time | |
for execution_result, tool_call_id in zip( | |
execution_results, model_response_data["tool_call_ids"] | |
): | |
tool_message = { | |
"role": "tool", | |
"content": execution_result, | |
"tool_call_id": tool_call_id, | |
} | |
inference_data["message"].append(tool_message) | |
return inference_data | |