# Source: dual_window / claude.py
# Huanzhi (Hans) Mao — commit 4d1746c ("init")
import json
import os
from anthropic import Anthropic
from anthropic.types import TextBlock, ToolUseBlock
from base_handler import BaseHandler
from constant import GORILLA_TO_OPENAPI
from model_style import ModelStyle
from utils import (
ast_parse,
combine_consecutive_user_prompts,
convert_system_prompt_into_user_prompt,
convert_to_function_call,
convert_to_tool,
extract_system_prompt,
format_execution_results_prompting,
func_doc_language_specific_pre_processing,
system_prompt_pre_processing_chat_model,
)
class ClaudeHandler(BaseHandler):
    """Handler for Anthropic Claude models.

    The mode is chosen from the model name:

    * Prompting mode (``"FC"`` not in ``model_name``): the model answers in
      text, which is parsed as a list of Python-style function calls.
    * Function-calling mode (``"FC"`` in ``model_name``): the tool-use API is
      used, and tool calls are read from the response's content blocks.
    """

    def __init__(self, model_name, temperature) -> None:
        super().__init__(model_name, temperature)
        self.model_style = ModelStyle.Anthropic
        # os.getenv returns None when ANTHROPIC_API_KEY is unset.
        self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

    @staticmethod
    def _bracket_function_calls(func: str) -> str:
        """Normalize a prompting-mode response so it parses as a list literal.

        Drops a single leading space, then adds ``[`` and/or ``]`` at the ends
        if they are missing. An empty response becomes ``"[]"`` instead of
        raising ``IndexError``.
        """
        if func.startswith(" "):
            func = func[1:]
        if not func.startswith("["):
            func = "[" + func
        if not func.endswith("]"):
            func = func + "]"
        return func

    def decode_ast(self, result, language="Python"):
        """Decode a model response into a list of function-call dicts.

        Args:
            result: In prompting mode, the raw text response. In FC mode, the
                list of ``{function_name: json_arguments_string}`` dicts built
                by ``_parse_query_response_FC``.
            language: Passed to ``ast_parse`` in prompting mode.

        Returns:
            In FC mode, a list of ``{function_name: params_dict}``. In prompting
            mode, whatever ``ast_parse`` returns (presumably the same shape —
            ``decode_execute`` iterates it that way).
        """
        if "FC" not in self.model_name:
            return ast_parse(self._bracket_function_calls(result), language)

        # FC mode: arguments were serialized with json.dumps when parsing the response.
        decoded_output = []
        for invoked_function in result:
            name = list(invoked_function.keys())[0]
            params = json.loads(invoked_function[name])
            decoded_output.append({name: params})
        return decoded_output

    def decode_execute(self, result):
        """Convert a model response into executable call strings like ``f(a=1,b='x')``.

        In prompting mode the text is parsed with ``ast_parse`` and each call is
        rendered with ``repr`` of its argument values. In FC mode the
        conversion is delegated to ``convert_to_function_call``.
        """
        if "FC" not in self.model_name:
            decode_output = ast_parse(self._bracket_function_calls(result))
            execution_list = []
            for function_call in decode_output:
                for func_name, params in function_call.items():
                    args = ",".join(f"{k}={v!r}" for k, v in params.items())
                    execution_list.append(f"{func_name}({args})")
            return execution_list

        return convert_to_function_call(result)

    #### FC methods ####

    def _query_FC(self, inference_data: dict):
        """Send the accumulated conversation and tools to the Claude API.

        Records a log of the input in ``inference_data["inference_input_log"]``.
        When caching is enabled, it marks the first content block of the last
        two user messages as ephemeral cache breakpoints and clears the flag
        from earlier user messages. Returns the raw API response.
        """
        inference_data["inference_input_log"] = {
            "message": repr(inference_data["message"]),
            "tools": inference_data["tools"],
        }
        messages = inference_data["message"]

        if inference_data["caching_enabled"]:
            # NOTE(review): Anthropic limits the number of cache breakpoints per
            # request; _compile_tools may add up to two more on the tools.
            user_turns_seen = 0
            for message in reversed(messages):
                if message["role"] != "user":
                    continue
                content = message["content"]
                # Guard: a user message with no content blocks (e.g. zero
                # execution results) has nothing to tag.
                if content:
                    if user_turns_seen < 2:
                        content[0]["cache_control"] = {"type": "ephemeral"}
                    else:
                        content[0].pop("cache_control", None)
                user_turns_seen += 1

        return self.client.beta.prompt_caching.messages.create(
            # removesuffix, not strip: strip("-FC") removes any '-', 'F', 'C'
            # characters from both ends instead of the literal suffix.
            model=self.model_name.removesuffix("-FC"),
            # Model names containing "claude-3-5" get 8192 max output tokens;
            # all others get 4096.
            max_tokens=(8192 if "claude-3-5" in self.model_name else 4096),
            tools=inference_data["tools"],
            messages=messages,
        )

    def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
        """Rewrite the entry's prompts for Claude and initialize per-entry state.

        For every round, system prompts are converted into user prompts
        (``_query_FC`` sends no ``system`` argument), and consecutive user
        prompts are merged. ``test_entry["question"]`` is updated in place.
        """
        questions = test_entry["question"]
        for round_idx, round_messages in enumerate(questions):
            round_messages = convert_system_prompt_into_user_prompt(round_messages)
            questions[round_idx] = combine_consecutive_user_prompts(round_messages)

        inference_data["message"] = []

        # Caching is enabled for every test category except on Claude 3 Sonnet
        # model names ("claude-3-5-sonnet" does not match this substring).
        # NOTE(review): an earlier comment said caching was meant only for the
        # multi_turn category, but the code never checked the category —
        # confirm which behavior is intended.
        inference_data["caching_enabled"] = "claude-3-sonnet" not in self.model_name
        return inference_data

    def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
        """Convert the entry's function docs into Claude tools and store them.

        When caching is enabled, the last tool is flagged as a cache breakpoint.
        If tools were already compiled (tools added mid-conversation, e.g. the
        miss_func category), the tool at the old list's end is flagged too, so
        the previously cached prefix can still be hit.
        """
        functions: list = test_entry["function"]
        test_category: str = test_entry["id"].rsplit("_", 1)[0]
        functions = func_doc_language_specific_pre_processing(functions, test_category)
        tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)

        # Guard: an empty tool list has no element to flag.
        if inference_data["caching_enabled"] and tools:
            if "tools" in inference_data:
                existing_tool_len = len(inference_data["tools"])
                # Guard against indexing past the new list (or wrapping with -1).
                if 0 < existing_tool_len <= len(tools):
                    tools[existing_tool_len - 1]["cache_control"] = {"type": "ephemeral"}
            tools[-1]["cache_control"] = {"type": "ephemeral"}

        inference_data["tools"] = tools
        return inference_data

    def _parse_query_response_FC(self, api_response: any) -> dict:
        """Extract model outputs, tool-call ids, and token usage from an API response.

        ``model_responses`` is the list of ``{tool_name: json_arguments}`` dicts
        when the response contains tool calls, otherwise the list of text
        outputs. The raw content blocks are kept for the chat history.
        """
        text_outputs = []
        tool_call_outputs = []
        tool_call_ids = []
        for content in api_response.content:
            if isinstance(content, TextBlock):
                text_outputs.append(content.text)
            elif isinstance(content, ToolUseBlock):
                tool_call_outputs.append({content.name: json.dumps(content.input)})
                tool_call_ids.append(content.id)

        return {
            "model_responses": tool_call_outputs if tool_call_outputs else text_outputs,
            "model_responses_message_for_chat_history": api_response.content,
            "tool_call_ids": tool_call_ids,
            "input_token": api_response.usage.input_tokens,
            "output_token": api_response.usage.output_tokens,
        }

    @staticmethod
    def _wrap_text_content(messages: list[dict]) -> None:
        """Replace each message's string ``content`` with a single text content block, in place."""
        for message in messages:
            message["content"] = [{"type": "text", "text": message["content"]}]

    def add_first_turn_message_FC(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        """Append the first turn's messages, wrapped as text content blocks."""
        self._wrap_text_content(first_turn_message)
        inference_data["message"].extend(first_turn_message)
        return inference_data

    def _add_next_turn_user_message_FC(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        """Append a later turn's user messages, wrapped as text content blocks."""
        self._wrap_text_content(user_message)
        inference_data["message"].extend(user_message)
        return inference_data

    def _add_assistant_message_FC(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        """Append the assistant's raw response content blocks to the conversation."""
        inference_data["message"].append(
            {
                "role": "assistant",
                "content": model_response_data["model_responses_message_for_chat_history"],
            }
        )
        return inference_data

    def _add_execution_results_FC(
        self,
        inference_data: dict,
        execution_results: list[str],
        model_response_data: dict,
    ) -> dict:
        """Append tool execution results as one user message of ``tool_result`` blocks.

        Claude doesn't use a tool role; tool output is sent with the user role.
        Results are paired with tool-call ids by position; ``zip`` stops at the
        shorter of the two lists.
        """
        tool_message = {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "content": execution_result,
                    "tool_use_id": tool_call_id,
                }
                for execution_result, tool_call_id in zip(
                    execution_results, model_response_data["tool_call_ids"]
                )
            ],
        }
        inference_data["message"].append(tool_message)
        return inference_data