dual_window / openai_helper.py
Huanzhi (Hans) Mao
init
4d1746c
raw
history blame
5.15 kB
import json
import os
from typing import Any

from openai import OpenAI

from base_handler import BaseHandler
from constant import DEFAULT_SYSTEM_PROMPT, GORILLA_TO_OPENAPI
from model_style import ModelStyle
from utils import (
    convert_to_function_call,
    convert_to_tool,
    default_decode_ast_prompting,
    default_decode_execute_prompting,
    format_execution_results_prompting,
    func_doc_language_specific_pre_processing,
    system_prompt_pre_processing_chat_model,
    convert_system_prompt_into_user_prompt,
    combine_consecutive_user_prompts,
)


class OpenAIHandler(BaseHandler):
    """Handler that queries the OpenAI chat-completions API.

    Model names containing ``"FC"`` are treated as function-calling models:
    their responses are decoded from the structured tool calls returned by the
    API. Every other model name is decoded with the default prompting decoders.
    """

    def __init__(self, model_name, temperature) -> None:
        """Set up the handler and create the OpenAI client.

        The API key is read from the ``OPENAI_API_KEY`` environment variable.

        Args:
            model_name: Model identifier; a ``"-FC"`` suffix is stripped before
                the name is sent to the API (see ``_query_FC``).
            temperature: Sampling temperature forwarded to the API.
        """
        # NOTE(review): BaseHandler.__init__ presumably stores both arguments as
        # self.model_name / self.temperature (they are read back below) —
        # confirm against BaseHandler.
        super().__init__(model_name, temperature)
        self.model_style = ModelStyle.OpenAI
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def decode_ast(self, result, language="Python"):
        """Decode a model response into a list of ``{function_name: params}``.

        Args:
            result: For FC models, an iterable of mappings whose first entry is
                ``function_name -> JSON-encoded arguments`` (the shape built by
                ``_parse_query_response_FC``). For other models, whatever
                ``default_decode_ast_prompting`` accepts.
            language: Forwarded to ``default_decode_ast_prompting``; unused on
                the FC path.

        Returns:
            For FC models, a list with one ``{name: parsed_arguments}`` dict per
            invoked function; otherwise the prompting decoder's result.

        Raises:
            json.JSONDecodeError: If an argument string is not valid JSON.
        """
        if "FC" not in self.model_name:
            return default_decode_ast_prompting(result, language)

        decoded_output = []
        for invoked_function in result:
            # Only the first entry of each mapping is used.
            name, raw_arguments = next(iter(invoked_function.items()))
            decoded_output.append({name: json.loads(raw_arguments)})
        return decoded_output

    def decode_execute(self, result):
        """Decode a model response into executable function-call form.

        FC models go through ``convert_to_function_call``; all other models go
        through ``default_decode_execute_prompting``.
        """
        if "FC" not in self.model_name:
            return default_decode_execute_prompting(result)
        return convert_to_function_call(result)

    #### FC methods ####

    def _query_FC(self, inference_data: dict):
        """Send the accumulated conversation to the chat-completions endpoint.

        Reads ``inference_data["message"]`` and ``inference_data["tools"]`` and
        records both under ``inference_data["inference_input_log"]`` (the
        messages as their ``repr``) before issuing the request.

        Returns:
            The raw response object returned by the OpenAI client.
        """
        message: list[dict] = inference_data["message"]
        tools = inference_data["tools"]
        inference_data["inference_input_log"] = {
            "message": repr(message),
            "tools": tools,
        }

        request_kwargs = {
            "messages": message,
            "model": self.model_name.replace("-FC", ""),
            "temperature": self.temperature,
        }
        # The `tools` argument is only sent when there is at least one tool;
        # presumably the API rejects an empty list — confirm against the docs.
        if tools:
            request_kwargs["tools"] = tools
        return self.client.chat.completions.create(**request_kwargs)

    def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
        """Start a fresh conversation by resetting the message list.

        ``test_entry`` is accepted for interface compatibility and not used.
        """
        inference_data["message"] = []
        return inference_data

    def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
        """Convert the test entry's function docs into tools for the request.

        The test category is the entry's ``id`` with its final ``_``-separated
        component removed. The converted tools are stored in
        ``inference_data["tools"]``.
        """
        functions: list = test_entry["function"]
        test_category: str = test_entry["id"].rsplit("_", 1)[0]

        functions = func_doc_language_specific_pre_processing(functions, test_category)
        tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)

        inference_data["tools"] = tools
        return inference_data

    def _parse_query_response_FC(self, api_response: Any) -> dict:
        """Extract tool calls (or plain text) and token usage from a response.

        Args:
            api_response: Response object returned by ``_query_FC``.

        Returns:
            A dict with:
            ``model_responses`` — a list of ``{function_name: arguments}``
            dicts when the reply carries tool calls, otherwise the reply's
            text content;
            ``model_responses_message_for_chat_history`` — the raw reply
            message, to be appended to the conversation;
            ``tool_call_ids`` — ids of the tool calls (empty on a text reply);
            ``input_token`` / ``output_token`` — prompt and completion token
            counts from the response's usage record.
        """
        message = api_response.choices[0].message
        try:
            tool_calls = message.tool_calls
            model_responses = [
                {call.function.name: call.function.arguments} for call in tool_calls
            ]
            tool_call_ids = [call.id for call in tool_calls]
        except (TypeError, AttributeError):
            # No usable tool calls (e.g. `tool_calls` is None, so iterating it
            # raises TypeError): fall back to the text content. Only these two
            # error types are caught so that KeyboardInterrupt and unrelated
            # bugs are no longer silently swallowed.
            model_responses = message.content
            tool_call_ids = []

        return {
            "model_responses": model_responses,
            "model_responses_message_for_chat_history": message,
            "tool_call_ids": tool_call_ids,
            "input_token": api_response.usage.prompt_tokens,
            "output_token": api_response.usage.completion_tokens,
        }

    def add_first_turn_message_FC(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        """Append the first turn's messages to the conversation."""
        inference_data["message"].extend(first_turn_message)
        return inference_data

    def _add_next_turn_user_message_FC(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        """Append a subsequent turn's user messages to the conversation."""
        inference_data["message"].extend(user_message)
        return inference_data

    def _add_assistant_message_FC(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        """Append the model's raw reply message to the conversation."""
        inference_data["message"].append(
            model_response_data["model_responses_message_for_chat_history"]
        )
        return inference_data

    def _add_execution_results_FC(
        self,
        inference_data: dict,
        execution_results: list[str],
        model_response_data: dict,
    ) -> dict:
        """Append one ``tool`` message per execution result to the conversation.

        Results are paired positionally with
        ``model_response_data["tool_call_ids"]``; if the two sequences differ
        in length, the surplus items of the longer one are ignored.
        """
        # Add the execution results to the current round result, one at a time
        for execution_result, tool_call_id in zip(
            execution_results, model_response_data["tool_call_ids"]
        ):
            tool_message = {
                "role": "tool",
                "content": execution_result,
                "tool_call_id": tool_call_id,
            }
            inference_data["message"].append(tool_message)

        return inference_data