File size: 4,397 Bytes
7d6e8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be77238
 
 
7d6e8c5
 
 
 
 
 
 
 
 
 
 
be77238
7d6e8c5
 
be77238
7d6e8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from typing import List, Literal, Union

import math

from langchain.tools.base import StructuredTool
from langchain.agents import (
    Tool,
    AgentExecutor,
    LLMSingleActionAgent,
    AgentOutputParser,
)
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.prompts import StringPromptTemplate
from langchain.llms import HuggingFaceTextGenInference
from langchain.chains import LLMChain


##########################################################
# Step 1: Define the functions you want to articulate. ###
##########################################################


def calculator(
    input_a: float,
    input_b: float,
    operation: Literal["add", "subtract", "multiply", "divide"],
):
    """
    Computes a calculation.

    Args:
    input_a (float) : Required. The first input.
    input_b (float) : Required. The second input.
    operation (string): The operation. Choices include: add to add two numbers, subtract to subtract two numbers, multiply to multiply two numbers, and divide to divide them.
    """
    match operation:
        case "add":
            return input_a + input_b
        case "subtract":
            return input_a - input_b
        case "multiply":
            return input_a * input_b
        case "divide":
            return input_a / input_b


def cylinder_volume(radius, height):
    """
    Calculate the volume of a cylinder.

    Parameters:
    - radius (float): The radius of the base of the cylinder.
    - height (float): The height of the cylinder.

    Returns:
    - float: The volume of the cylinder.
    """
    # The docstring is kept verbatim: it is consumed as the tool description
    # (via StructuredTool.from_function) and ends up in the model prompt.

    # Reject a negative radius or height before doing any arithmetic.
    # Checked in the order radius, then height.
    if any(dimension < 0 for dimension in (radius, height)):
        raise ValueError("Radius and height must be non-negative.")

    # V = (pi * r^2) * h, evaluated in the same order as pi * r**2 * h.
    base_area = math.pi * radius**2
    return base_area * height


#############################################################
# Step 2: Let's define some utils for building the prompt ###
#############################################################


# Prompt skeleton rendered by RavenPromptTemplate.format:
#   {raven_tools} - filled with one "Function: def ..." stub per tool.
#   {input}       - the user's query (the template's only input variable).
# "<human_end>" presumably marks the end of the user turn in Raven's prompt
# format. Generation is stopped at "<bot_end>" by the agent's stop list.
RAVEN_PROMPT = """
{raven_tools}
User Query: {input}<human_end>

"""



# Set up a prompt template
class RavenPromptTemplate(StringPromptTemplate):
    """Prompt template that renders every tool as a Python-style function stub.

    Each tool description is expected to look like "<signature> - <docstring>".
    It is turned into a block of the form::

        Function:
        def <signature>
        \"\"\"
        <docstring>
        \"\"\"

    The joined blocks are substituted for ``{raven_tools}`` in ``template``.
    """

    # The template to use
    template: str
    # The list of tools available
    tools: List[Tool]

    def format(self, **kwargs) -> str:
        """Render ``template`` with the tool stubs plus the caller's kwargs.

        Raises ValueError if a tool description has no " - " separator.
        """
        stubs = []
        for tool in self.tools:
            # Split only on the first " - " so dashes inside the docstring survive.
            signature, docstring = tool.description.split(" - ", 1)
            stubs.append(f'\nFunction:\ndef {signature}\n"""\n{docstring}\n"""\n')

        # The generated stubs always win over any caller-supplied raven_tools.
        fields = {**kwargs, "raven_tools": "".join(stubs)}
        rendered = self.template.format(**fields)
        # NOTE(review): this also collapses literal "{{"/"}}" that appear in the
        # substituted values (tool docs, user input), not just template escapes.
        return rendered.replace("{{", "{").replace("}}", "}")


class RavenOutputParser(AgentOutputParser):
    """Convert a raw Raven completion into a finished agent step.

    The completion must contain a "Call:" marker. The text with every
    "Call:" removed and surrounding whitespace stripped becomes the
    ``"output"`` return value. The agent never emits an AgentAction, so
    the run always ends after one LLM turn.
    """

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        # Guard clause: output without a "Call:" marker cannot be interpreted.
        if "Call:" not in llm_output:
            raise OutputParserException(f"Could not parse LLM output: `{llm_output}`")

        call_expression = llm_output.strip().replace("Call:", "").strip()
        return AgentFinish(
            return_values={"output": call_expression},
            log=llm_output,
        )


##################################################
# Step 3: Build the agent with these utilities ###
##################################################


# HF text-generation-inference endpoint used by the LLM below.
# NOTE(review): this is one specific deployment's URL. Point it at your own
# endpoint before running.
inference_server_url = "https://rjmy54al17scvxjr.us-east-1.aws.endpoints.huggingface.cloud"
# An explicit check instead of `assert`: asserts are stripped under `python -O`.
# The previous `is not` also compared object identity rather than string
# equality, and triggers a SyntaxWarning when used with a literal.
if not inference_server_url or inference_server_url == "<YOUR ENDPOINT URL>":
    raise ValueError("Please provide your own HF inference endpoint URL!")

llm = HuggingFaceTextGenInference(
    inference_server_url=inference_server_url,
    temperature=0.001,
    max_new_tokens=400,
    do_sample=False,
)
tools = [
    StructuredTool.from_function(calculator),
    StructuredTool.from_function(cylinder_volume),
]
raven_prompt = RavenPromptTemplate(
    template=RAVEN_PROMPT, tools=tools, input_variables=["input"]
)
llm_chain = LLMChain(llm=llm, prompt=raven_prompt)
output_parser = RavenOutputParser()
agent = LLMSingleActionAgent(
    llm_chain=llm_chain,
    output_parser=output_parser,
    stop=["<bot_end>"],
    allowed_tools=tools,
)
agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)

# Demo: each run returns the text RavenOutputParser extracted after "Call:".
# That text is expected to be a Python call expression such as
# "calculator(...)", and eval() executes it against this module's functions.
#
# SECURITY NOTE(review): eval() runs arbitrary model-generated code with full
# access to this module and to builtins. Only acceptable for a local demo
# against a trusted endpoint. Parse the call (e.g. with `ast`) and dispatch to
# an allow-list of tools before exposing this to untrusted input.
for query in (
    "I have a cake that is about 3 centimeters high and 200 centimeters in radius. How much cake do I have?",
    "What is 1+10?",
):
    call = agent_chain.run(query)
    print(eval(call))