Spaces:
Sleeping
Sleeping
import logging | |
from abc import ABC, abstractmethod | |
from typing import Any, Dict, Union, TypeVar, Generic | |
from typing_extensions import TypedDict | |
from datetime import datetime | |
from termcolor import colored | |
from models.llms import ( | |
OllamaModel, | |
OpenAIModel, | |
GroqModel, | |
GeminiModel, | |
ClaudeModel, | |
VllmModel, | |
MistralModel | |
) | |
# Configure the root logger to emit INFO and above. Note this runs as a side
# effect of importing the module.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)

# Generic type parameter for an agent's state: any str-keyed dict
# (a TypedDict subclass also satisfies this bound).
StateT = TypeVar("StateT", bound=Dict[str, Any])
class BaseAgent(ABC, Generic[StateT]):
    """Base class for LLM-backed agents that read and update a shared state dict.

    Subclasses customise behaviour by overriding the hook methods
    (``get_prompt``, ``get_guided_json``, ``get_conv_history``,
    ``process_response``, ``get_user_input``, ``use_tool``). In this base class
    every hook is a no-op returning ``None``; ``invoke`` tolerates ``None``
    results so a partially implemented subclass does not crash.

    Args:
        model: Model name forwarded to the backend client.
        server: Backend identifier; one of ``'openai'``, ``'ollama'``,
            ``'vllm'``, ``'groq'``, ``'claude'``, ``'mistral'``, ``'gemini'``.
        temperature: Sampling temperature forwarded to the backend client.
        model_endpoint: Forwarded as ``model_endpoint``; only used by ``'vllm'``.
        stop: Forwarded as ``stop``; only used by ``'vllm'``.
        location: Stored as ``self.location``; not used by this base class.
        hyrbid: Stored as ``self.hybrid``; not used by this base class. The
            parameter name is misspelled but kept for backward compatibility.

    Raises:
        ValueError: If ``server`` is not a supported backend.
    """

    def __init__(self, model: str = None, server: str = None, temperature: float = 0,
                 model_endpoint: str = None, stop: str = None, location: str = "us",
                 hyrbid: bool = False):
        self.model = model
        self.server = server
        self.temperature = temperature
        self.model_endpoint = model_endpoint
        self.stop = stop
        self.location = location
        self.hybrid = hyrbid
        # Build the client last so a subclass overriding get_llm() can rely on
        # every attribute above (including location/hybrid) already being set.
        self.llm = self.get_llm()

    def get_llm(self, json_model: bool = False):
        """Return a backend client instance for ``self.server``.

        Args:
            json_model: Forwarded to the client as ``json_response``.

        Raises:
            ValueError: If ``self.server`` is not a supported backend.
        """
        common = {
            "model": self.model,
            "temperature": self.temperature,
            "json_response": json_model,
        }
        if self.server == 'vllm':
            # vLLM is the only backend that takes an explicit endpoint and stop.
            return VllmModel(**common, model_endpoint=self.model_endpoint, stop=self.stop)
        backends = {
            'openai': OpenAIModel,
            'ollama': OllamaModel,
            'groq': GroqModel,
            'claude': ClaudeModel,
            'mistral': MistralModel,
            'gemini': GeminiModel,
        }
        model_cls = backends.get(self.server)
        if model_cls is None:
            raise ValueError(f"Unsupported server: {self.server}")
        return model_cls(**common)

    def get_prompt(self, state: StateT = None) -> str:
        """Hook: return the system prompt for ``state`` (base: ``None``)."""
        pass

    def get_guided_json(self, state: StateT = None) -> Dict[str, Any]:
        """Hook: return the guided-JSON schema passed to vLLM (base: ``None``)."""
        pass

    def update_state(self, key: str, value: Union[str, dict], state: StateT = None) -> StateT:
        """Set ``state[key] = value`` in place and return ``state``.

        If ``state`` is ``None`` a new dict is created, so an agent invoked
        without an initial state still records its updates instead of raising
        ``TypeError``.
        """
        if state is None:
            state = {}
        state[key] = value
        return state

    def process_response(self, response: Any, user_input: str = None,
                         state: StateT = None) -> Dict[str, Union[str, dict]]:
        """Hook: map the raw LLM response to state updates (base: ``None``)."""
        pass

    def get_conv_history(self, state: StateT = None) -> str:
        """Hook: return conversation history text for ``state`` (base: ``None``)."""
        pass

    def get_user_input(self) -> str:
        """Hook: obtain user input when ``human_in_loop`` is set (base: ``None``)."""
        pass

    def use_tool(self) -> Any:
        """Hook: run an agent-specific tool (base: ``None``)."""
        pass

    def invoke(self, state: StateT = None, human_in_loop: bool = False,
               user_input: str = None, final_answer: str = None) -> StateT:
        """Run one LLM turn and merge the parsed response into ``state``.

        Builds the prompt and history from ``state``, optionally prints
        ``final_answer`` and reads user input, sends a system + user message
        pair to ``self.llm``, then applies every key/value returned by
        ``process_response`` via ``update_state``.

        Args:
            state: Current agent state (may be ``None``).
            human_in_loop: If true, ``user_input`` is replaced by
                ``get_user_input()``.
            user_input: Text appended to the end of the user message.
            final_answer: If non-empty, printed in green and repeated ten times
                at the start of the user message.

        Returns:
            The state returned by the last ``update_state`` call, or ``state``
            unchanged when there were no updates.
        """
        prompt = self.get_prompt(state)
        conversation_history = self.get_conv_history(state)

        if final_answer:
            print(colored(f"\n\n{final_answer}\n\n", "green"))

        if human_in_loop:
            user_input = self.get_user_input()

        # Guard the repetition: without it a missing final_answer would inject
        # the literal text "None" ten times into the user message.
        emphasis = f"\n{final_answer}\n" * 10 if final_answer else ""
        system_content = f"{self._as_text(prompt)}\n Today's date is {datetime.now()}"
        user_content = (
            f"{emphasis}{self._as_text(conversation_history)}\n{self._as_text(user_input)}"
        )
        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content},
        ]

        if self.server == 'vllm':
            guided_json = self.get_guided_json(state)
            response = self.llm.invoke(messages, guided_json)
        else:
            response = self.llm.invoke(messages)

        # The base process_response() returns None; treat that as "no updates"
        # rather than failing on None.items().
        updates = self.process_response(response, user_input, state) or {}
        for key, value in updates.items():
            state = self.update_state(key, value, state)
        return state

    @staticmethod
    def _as_text(value: Any) -> str:
        """Render ``value`` for a prompt, mapping ``None`` to an empty string."""
        return "" if value is None else f"{value}"