"""
LangChain Mixture-of-Agents (MOA) agent built on Groq-hosted chat models.
"""

from typing import Generator, Dict, Optional, Literal, TypedDict, List

from dotenv import load_dotenv

from langchain_groq import ChatGroq
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableSerializable
from langchain_core.output_parsers import StrOutputParser

from .prompts import SYSTEM_PROMPT, REFERENCE_SYSTEM_PROMPT

load_dotenv()

valid_model_names = Literal[
    'llama3-70b-8192',
    'llama3-8b-8192',
    'gemma-7b-it',
    'gemma2-9b-it',
    'mixtral-8x7b-32768'
]


class ResponseChunk(TypedDict):
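    """A streamed chunk: an intermediate layer response or a piece of the final output."""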
    delta: str
    response_type: Literal['intermediate', 'output']
    metadata: Dict  # TypedDict fields cannot declare default values (PEP 589)


class MOAgent:
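    """Mixture-of-Agents (MOA) orchestrator: layer agents answer the input in
    parallel and a main agent aggregates their responses into the final answer."""
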
    def __init__(
        self,
        main_agent: RunnableSerializable[Dict, str],
        layer_agent: RunnableSerializable[Dict, Dict],
        reference_system_prompt: Optional[str] = None,
        cycles: Optional[int] = None,
        chat_memory: Optional[ConversationBufferMemory] = None
    ) -> None:
|
        self.reference_system_prompt = reference_system_prompt or REFERENCE_SYSTEM_PROMPT
        self.main_agent = main_agent
        self.layer_agent = layer_agent
        self.cycles = cycles or 1
        self.chat_memory = chat_memory or ConversationBufferMemory(
            memory_key="messages",
            return_messages=True
        )

    @staticmethod
    def concat_response(
        inputs: Dict[str, str],
        reference_system_prompt: Optional[str] = None
    ):
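        """Concatenate the layer agents' responses and format them into the
        reference system prompt that the main agent receives as helper_response."""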
        reference_system_prompt = reference_system_prompt or REFERENCE_SYSTEM_PROMPT

        responses = ""
        res_list = []
        for i, out in enumerate(inputs.values()):
            responses += f"{i}. {out}\n"
            res_list.append(out)

        formatted_prompt = reference_system_prompt.format(responses=responses)
        return {
            'formatted_response': formatted_prompt,
            'responses': res_list
        }

    @classmethod
    def from_config(
        cls,
        main_model: Optional[valid_model_names] = 'llama3-70b-8192',
        system_prompt: Optional[str] = None,
        cycles: int = 1,
        layer_agent_config: Optional[Dict] = None,
        reference_system_prompt: Optional[str] = None,
        **main_model_kwargs
    ):
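        """Build a MOAgent from model names and prompts instead of prebuilt runnables."""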
        reference_system_prompt = reference_system_prompt or REFERENCE_SYSTEM_PROMPT
        system_prompt = system_prompt or SYSTEM_PROMPT
        layer_agent = MOAgent._configure_layer_agent(layer_agent_config)
        main_agent = MOAgent._create_agent_from_system_prompt(
            system_prompt=system_prompt,
            model_name=main_model,
            **main_model_kwargs
        )
        return cls(
            main_agent=main_agent,
            layer_agent=layer_agent,
            reference_system_prompt=reference_system_prompt,
            cycles=cycles
        )

    @staticmethod
    def _configure_layer_agent(
        layer_agent_config: Optional[Dict] = None
    ) -> RunnableSerializable[Dict, Dict]:
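        """Build the parallel layer: one chain per configured agent, merged by concat_response."""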
        if not layer_agent_config:
            layer_agent_config = {
                'layer_agent_1': {'system_prompt': SYSTEM_PROMPT, 'model_name': 'llama3-8b-8192'},
                'layer_agent_2': {'system_prompt': SYSTEM_PROMPT, 'model_name': 'gemma-7b-it'},
                'layer_agent_3': {'system_prompt': SYSTEM_PROMPT, 'model_name': 'mixtral-8x7b-32768'}
            }

        parallel_chain_map = dict()
        for key, value in layer_agent_config.items():
            chain = MOAgent._create_agent_from_system_prompt(
                system_prompt=value.pop("system_prompt", SYSTEM_PROMPT),
                model_name=value.pop("model_name", 'llama3-8b-8192'),
                **value
            )
            parallel_chain_map[key] = RunnablePassthrough() | chain

        # The dict is coerced into a RunnableParallel, so all layer agents run on the same input.
        chain = parallel_chain_map | RunnableLambda(MOAgent.concat_response)
        return chain

    @staticmethod
    def _create_agent_from_system_prompt(
        system_prompt: str = SYSTEM_PROMPT,
        model_name: str = "llama3-8b-8192",
        **llm_kwargs
    ) -> RunnableSerializable[Dict, str]:
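        """Create a single prompt | ChatGroq | parser chain for one agent."""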
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages", optional=True),
            ("human", "{input}")
        ])

        # The system prompt must expose a {helper_response} placeholder for the aggregated layer output.
        assert 'helper_response' in prompt.input_variables
        llm = ChatGroq(model=model_name, **llm_kwargs)

        chain = prompt | llm | StrOutputParser()
        return chain

    def chat(
        self,
        input: str,
        messages: Optional[List[BaseMessage]] = None,
        cycles: Optional[int] = None,
        save: bool = True,
        output_format: Literal['string', 'json'] = 'string'
    ) -> Generator[str | ResponseChunk, None, None]:
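        """Run the MOA loop and stream the final answer.

        Each cycle invokes the layer agents and feeds their combined output to the
        main agent as helper_response. With output_format='json', intermediate layer
        responses are also yielded as ResponseChunk dicts.
        """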
        cycles = cycles or self.cycles
        llm_inp = {
            'input': input,
            'messages': messages or self.chat_memory.load_memory_variables({})['messages'],
            'helper_response': ""
        }
        for cyc in range(cycles):
            layer_output = self.layer_agent.invoke(llm_inp)
            l_frm_resp = layer_output['formatted_response']
            l_resps = layer_output['responses']

            # The combined layer output becomes helper_response for the next cycle (and for the main agent).
            llm_inp = {
                'input': input,
                'messages': self.chat_memory.load_memory_variables({})['messages'],
                'helper_response': l_frm_resp
            }

            if output_format == 'json':
                for l_out in l_resps:
                    yield ResponseChunk(
                        delta=l_out,
                        response_type='intermediate',
                        metadata={'layer': cyc + 1}
                    )

        # Stream the main agent's synthesized answer.
        stream = self.main_agent.stream(llm_inp)
        response = ""
        for chunk in stream:
            if output_format == 'json':
                yield ResponseChunk(
                    delta=chunk,
                    response_type='output',
                    metadata={}
                )
            else:
                yield chunk
            response += chunk

        if save:
            self.chat_memory.save_context({'input': input}, {'output': response})
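

# Minimal usage sketch (assumption: GROQ_API_KEY is available in the environment,
# e.g. via the .env file loaded by load_dotenv() above).
if __name__ == '__main__':
    agent = MOAgent.from_config(main_model='llama3-70b-8192', cycles=2)
    for token in agent.chat("Summarize how a mixture-of-agents pipeline works."):
        print(token, end="", flush=True)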