# SuperExpert/agents/base_agent.py

import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, Generic, Optional, TypeVar, Union

from termcolor import colored
from models.llms import (
OllamaModel,
OpenAIModel,
GroqModel,
GeminiModel,
ClaudeModel,
VllmModel,
MistralModel
)
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define a TypeVar for the state
StateT = TypeVar('StateT', bound=Dict[str, Any])
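# Concrete agents can bind StateT to a more specific dict type so the state
# flowing through get_prompt, process_response, and update_state is type-checked.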
class BaseAgent(ABC, Generic[StateT]):
    def __init__(self, model: Optional[str] = None, server: Optional[str] = None,
                 temperature: float = 0, model_endpoint: Optional[str] = None,
                 stop: Optional[str] = None, location: str = "us", hybrid: bool = False):
        self.model = model
        self.server = server
        self.temperature = temperature
        self.model_endpoint = model_endpoint
        self.stop = stop
        self.location = location
        self.hybrid = hybrid
        # Instantiate the backing LLM once all connection settings are stored.
        self.llm = self.get_llm()
def get_llm(self, json_model: bool = False):
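        # Map the configured server name onto its concrete LLM wrapper; json_model
        # requests JSON-formatted responses where the backend supports them.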
if self.server == 'openai':
return OpenAIModel(model=self.model, temperature=self.temperature, json_response=json_model)
elif self.server == 'ollama':
return OllamaModel(model=self.model, temperature=self.temperature, json_response=json_model)
elif self.server == 'vllm':
return VllmModel(model=self.model, temperature=self.temperature, json_response=json_model,
model_endpoint=self.model_endpoint, stop=self.stop)
elif self.server == 'groq':
return GroqModel(model=self.model, temperature=self.temperature, json_response=json_model)
elif self.server == 'claude':
return ClaudeModel(temperature=self.temperature, model=self.model, json_response=json_model)
elif self.server == 'mistral':
return MistralModel(temperature=self.temperature, model=self.model, json_response=json_model)
elif self.server == 'gemini':
return GeminiModel(temperature=self.temperature, model=self.model, json_response=json_model)
else:
raise ValueError(f"Unsupported server: {self.server}")
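
    # Subclass hooks: each concrete agent supplies its own prompt, guided-JSON
    # schema, conversation history, user input, and tool-use behaviour.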
@abstractmethod
def get_prompt(self, state: StateT = None) -> str:
pass
@abstractmethod
    def get_guided_json(self, state: StateT = None) -> Dict[str, Any]:
pass
def update_state(self, key: str, value: Union[str, dict], state: StateT = None) -> StateT:
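        # Mutates the given state in place and returns it so invoke() can chain updates.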
state[key] = value
return state
@abstractmethod
def process_response(self, response: Any, user_input: str = None, state: StateT = None) -> Dict[str, Union[str, dict]]:
pass
@abstractmethod
def get_conv_history(self, state: StateT = None) -> str:
pass
@abstractmethod
def get_user_input(self) -> str:
pass
@abstractmethod
def use_tool(self) -> Any:
pass
    def invoke(self, state: StateT = None, human_in_loop: bool = False,
               user_input: Optional[str] = None, final_answer: Optional[str] = None) -> StateT:
        # One agent step: build the prompt and conversation history, call the
        # LLM, then fold the processed response back into the state.
        prompt = self.get_prompt(state)
        conversation_history = self.get_conv_history(state)

        if final_answer:
            print(colored(f"\n\n{final_answer}\n\n", "green"))

        if human_in_loop:
            user_input = self.get_user_input()

        # Repeat the final answer to keep it salient in the prompt, and skip it
        # entirely (rather than injecting the literal string "None") when unset.
        final_answer_block = f"\n{final_answer}\n" * 10 if final_answer else ""
        messages = [
            {"role": "system", "content": f"{prompt}\nToday's date is {datetime.now()}"},
            {"role": "user", "content": f"{final_answer_block}{conversation_history}\n{user_input or ''}"}
        ]

        if self.server == 'vllm':
            # The vLLM wrapper also takes a guided-JSON schema for constrained decoding.
            guided_json = self.get_guided_json(state)
            response = self.llm.invoke(messages, guided_json)
        else:
            response = self.llm.invoke(messages)

        updates = self.process_response(response, user_input, state)
        for key, value in updates.items():
            state = self.update_state(key, value, state)
        return state
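

# Illustrative usage: a minimal sketch of a concrete subclass. EchoAgent, its
# prompt text, the "gpt-4o" model name, and the 'response'/'history' state keys
# are assumptions for this example, not part of the SuperExpert codebase.
class EchoAgent(BaseAgent[Dict[str, Any]]):
    def get_prompt(self, state=None):
        return "You are a concise assistant."

    def get_guided_json(self, state=None):
        # Only consulted when server == 'vllm' (guided-decoding schema).
        return {"type": "object", "properties": {"answer": {"type": "string"}}}

    def get_conv_history(self, state=None):
        return state.get("history", "") if state else ""

    def get_user_input(self):
        return input("> ")

    def use_tool(self):
        return None

    def process_response(self, response, user_input=None, state=None):
        return {"response": response}


if __name__ == "__main__":
    # Requires valid OpenAI credentials; the model name is an assumption.
    agent = EchoAgent(model="gpt-4o", server="openai")
    result = agent.invoke(state={}, user_input="Hello")
    print(result.get("response"))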