Spaces:
Running
Running
import json | |
from typing import List | |
from langchain.pydantic_v1 import BaseModel, Field | |
from langchain.schema import BaseRetriever, Document | |
from langchain.tools import Tool | |
from backend.chat_bot.json_decoder import CustomJSONEncoder | |
# Argument schema for the retriever tool: a single string field, ``query``,
# which the tool hands to the retriever.
class RetrieverInput(BaseModel):
    # Required (``...``) free-text query.
    query: str = Field(..., description="query to look up in retriever")
def create_retriever_tool(
    retriever: BaseRetriever,
    tool_name: str,
    description: str,
) -> Tool:
    """Create a tool to do retrieval of documents.

    On both the synchronous (``func``) and asynchronous (``coroutine``) paths,
    the retrieved documents are returned as a JSON string. The string holds a
    list of each document's ``dict()`` representation, encoded with
    ``CustomJSONEncoder``. An agent therefore sees the same output format
    however the tool is invoked.

    Args:
        retriever: The retriever to use for the retrieval.
        tool_name: The name for the tool. This will be passed to the language
            model, so should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to the
            language model, so should be descriptive.

    Returns:
        Tool class to pass to an agent.
    """

    def _serialize(docs: List[Document]) -> str:
        """Encode a list of documents as a JSON array of their dict forms."""
        return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)

    def wrap(func):
        """Wrap a sync retrieval callable so it returns serialized JSON."""

        def wrapped_retrieve(*args, **kwargs) -> str:
            docs: List[Document] = func(*args, **kwargs)
            return _serialize(docs)

        return wrapped_retrieve

    def wrap_async(coro_func):
        """Wrap an async retrieval callable so it returns serialized JSON."""

        async def wrapped_aretrieve(*args, **kwargs) -> str:
            docs: List[Document] = await coro_func(*args, **kwargs)
            return _serialize(docs)

        return wrapped_aretrieve

    return Tool(
        name=tool_name,
        description=description,
        func=wrap(retriever.get_relevant_documents),
        # The async path must be wrapped too; otherwise it would return raw
        # Document objects while the sync path returns a JSON string.
        coroutine=wrap_async(retriever.aget_relevant_documents),
        args_schema=RetrieverInput,
    )