File size: 4,397 Bytes
7d6e8c5 be77238 7d6e8c5 be77238 7d6e8c5 be77238 7d6e8c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from typing import List, Literal, Union
import math
from langchain.tools.base import StructuredTool
from langchain.agents import (
Tool,
AgentExecutor,
LLMSingleActionAgent,
AgentOutputParser,
)
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.prompts import StringPromptTemplate
from langchain.llms import HuggingFaceTextGenInference
from langchain.chains import LLMChain
##########################################################
# Step 1: Define the functions you want to articulate. ###
##########################################################
def calculator(
input_a: float,
input_b: float,
operation: Literal["add", "subtract", "multiply", "divide"],
):
"""
Computes a calculation.
Args:
input_a (float) : Required. The first input.
input_b (float) : Required. The second input.
operation (string): The operation. Choices include: add to add two numbers, subtract to subtract two numbers, multiply to multiply two numbers, and divide to divide them.
"""
match operation:
case "add":
return input_a + input_b
case "subtract":
return input_a - input_b
case "multiply":
return input_a * input_b
case "divide":
return input_a / input_b
def cylinder_volume(radius, height):
    """
    Calculate the volume of a cylinder.
    Parameters:
    - radius (float): The radius of the base of the cylinder.
    - height (float): The height of the cylinder.
    Returns:
    - float: The volume of the cylinder.
    """
    # The docstring text above is left untouched: this file's prompt template
    # embeds each tool's description (signature + docstring) in the prompt.

    # Reject negative dimensions up front.
    has_negative_dimension = radius < 0 or height < 0
    if has_negative_dimension:
        raise ValueError("Radius and height must be non-negative.")

    # Area of the circular base, then extrude it along the height.
    base_area = math.pi * (radius**2)
    return base_area * height
#############################################################
# Step 2: Let's define some utils for building the prompt ###
#############################################################
# Prompt skeleton. `{raven_tools}` is filled by RavenPromptTemplate.format with
# the rendered function definitions, and `{input}` with the user's question.
# The literal "<human_end>" marker directly follows the query (presumably the
# model's end-of-user-turn token - confirm against the model card).
RAVEN_PROMPT = """
{raven_tools}
User Query: {input}<human_end>
"""
# Prompt template that renders every available tool as a Python function stub
class RavenPromptTemplate(StringPromptTemplate):
    """Build the Raven prompt from a template string and a list of tools."""

    # Template text; must contain a `{raven_tools}` placeholder
    template: str
    # Tools whose descriptions are rendered into the prompt
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        """
        Render the prompt.

        Each tool description is split once on " - " into a signature and a
        docstring, which are emitted as a `def` line followed by a
        triple-quoted docstring. The joined result is passed to the template
        as `raven_tools` together with the caller's keyword arguments.

        Raises:
            ValueError: if a tool description contains no " - " separator
                (the two-way unpacking fails).
        """
        function_stubs = []
        for tool in self.tools:
            signature, docstring = tool.description.split(" - ", 1)
            function_stubs.append(
                f'\nFunction:\ndef {signature}\n"""\n{docstring}\n"""\n'
            )
        kwargs["raven_tools"] = "".join(function_stubs)

        rendered = self.template.format(**kwargs)
        # Collapse any doubled braces that are still present after formatting.
        return rendered.replace("{{", "{").replace("}}", "}")
class RavenOutputParser(AgentOutputParser):
    """Parse a raw completion into the agent's final answer."""

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        """
        Extract the call text from the model output.

        Returns an AgentFinish whose "output" is the completion with
        surrounding whitespace stripped, every "Call:" marker removed, and the
        result stripped again. The untouched completion is kept as the log.

        Raises:
            OutputParserException: if the completion contains no "Call:".
        """
        call_marker = "Call:"

        # Without the marker there is nothing to return.
        if call_marker not in llm_output:
            raise OutputParserException(f"Could not parse LLM output: `{llm_output}`")

        call_text = llm_output.strip().replace(call_marker, "").strip()
        return AgentFinish(return_values={"output": call_text}, log=llm_output)
##################################################
# Step 3: Build the agent with these utilities ###
##################################################
# URL of the Hugging Face text-generation-inference endpoint serving the model.
inference_server_url = "https://rjmy54al17scvxjr.us-east-1.aws.endpoints.huggingface.cloud"
# Compare by value. `is not` against a string literal tests object identity,
# which is effectively always true (and emits a SyntaxWarning on Python 3.8+),
# so the placeholder check could never fire.
assert (
    inference_server_url != "<YOUR ENDPOINT URL>"
), "Please provide your own HF inference endpoint URL!"
# LLM client for the inference endpoint: sampling disabled, near-zero
# temperature, and at most 400 newly generated tokens per request.
llm = HuggingFaceTextGenInference(
    inference_server_url=inference_server_url,
    temperature=0.001,
    max_new_tokens=400,
    do_sample=False,
)
# Wrap the two plain Python functions as structured tools.
tools = [
    StructuredTool.from_function(calculator),
    StructuredTool.from_function(cylinder_volume),
]
# Prompt built from RAVEN_PROMPT and the tools; "input" is the only variable
# the caller supplies.
raven_prompt = RavenPromptTemplate(
    template=RAVEN_PROMPT, tools=tools, input_variables=["input"]
)
llm_chain = LLMChain(llm=llm, prompt=raven_prompt)
output_parser = RavenOutputParser()
# Single-action agent; generation is cut at the "<bot_end>" stop sequence.
# NOTE(review): `allowed_tools` is given the tool objects here; LangChain's
# custom-agent examples pass a list of tool *names* - confirm this is intended.
agent = LLMSingleActionAgent(
    llm_chain=llm_chain,
    output_parser=output_parser,
    stop=["<bot_end>"],
    allowed_tools=tools,
)
# Executor that runs the agent with the tools, with verbose logging enabled.
agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
# Run two example queries. RavenOutputParser returns the completion with the
# "Call:" marker removed, so `call` is expected to be a Python call expression
# naming one of the functions defined in this file.
call = agent_chain.run(
    "I have a cake that is about 3 centimenters high and 200 centimeters in radius. How much cake do I have?"
)
# SECURITY: eval() executes model-generated text as Python with this module's
# globals. Acceptable only for a trusted local demo; never do this with
# untrusted prompts or models. Parse and dispatch the call explicitly instead.
print(eval(call))
call = agent_chain.run("What is 1+10?")
# SECURITY: same caveat as above - the evaluated string comes from the model.
print(eval(call))
|