"""Streamlit front end for the Groq Mixture-of-Agents (MOA) demo.

The script keeps the active MOA configuration and the chat history in
``st.session_state``, lets the user edit the configuration from the sidebar,
and streams the agent's answer into a chat interface.
"""

import copy
import json
from typing import Any, Iterable, Optional

import streamlit as st
from streamlit_ace import st_ace

from moa.agent import MOAgent
from moa.agent.moa import ResponseChunk

# Default configuration
default_config = {
    "main_model": "llama-3.1-70b-versatile",
    "cycles": 3,
    "layer_agent_config": {},
}

layer_agent_config_def = {
    "layer_agent_1": {
        "system_prompt": "Think through your response step by step. {helper_response}",
        "model_name": "llama3-8b-8192",
    },
    "layer_agent_2": {
        "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
        "model_name": "gemma-7b-it",
        "temperature": 0.7,
    },
    "layer_agent_3": {
        "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
        "model_name": "llama3-8b-8192",
    },
}

# Recommended configuration
rec_config = {
    "main_model": "llama-3.2-90b-text-preview",
    "cycles": 2,
    "layer_agent_config": {},
}

layer_agent_config_rec = {
    "layer_agent_1": {
        "system_prompt": "Think through your response step by step. {helper_response}",
        "model_name": "llama3-8b-8192",
        "temperature": 0.1,
    },
    "layer_agent_2": {
        "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
        "model_name": "llama3-8b-8192",
        "temperature": 0.2,
    },
    "layer_agent_3": {
        "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
        "model_name": "llama3-8b-8192",
        "temperature": 0.4,
    },
    "layer_agent_4": {
        "system_prompt": "You are an expert planner agent. Create a plan for how to answer the human's query. {helper_response}",
        "model_name": "mixtral-8x7b-32768",
        "temperature": 0.5,
    },
}


def stream_response(messages: Iterable[ResponseChunk]):
    """Render intermediate layer outputs and yield the main agent's deltas.

    Chunks whose ``response_type`` is ``'intermediate'`` are buffered per
    ``metadata['layer']``. When any other chunk arrives, everything buffered
    so far is rendered (one column with a collapsed expander per buffered
    delta), the buffer is cleared, and that chunk's ``delta`` is yielded.

    Args:
        messages: Stream of response chunks produced by the MOA agent.

    Yields:
        The ``delta`` of every non-intermediate chunk, in arrival order.
    """
    layer_outputs: dict = {}
    for message in messages:
        if message['response_type'] == 'intermediate':
            layer = message['metadata']['layer']
            layer_outputs.setdefault(layer, []).append(message['delta'])
            continue

        # Display accumulated layer outputs before the main output resumes.
        for layer, outputs in layer_outputs.items():
            st.write(f"Layer {layer}")
            cols = st.columns(len(outputs))
            for i, (col, output) in enumerate(zip(cols, outputs)):
                with col:
                    st.expander(label=f"Agent {i+1}", expanded=False).write(output)

        # Clear layer outputs for the next iteration
        layer_outputs = {}

        # Yield the main agent's output
        yield message['delta']


def set_moa_agent(
    main_model: str = default_config['main_model'],
    cycles: int = default_config['cycles'],
    layer_agent_config: Optional[dict[str, dict[str, Any]]] = None,
    main_model_temperature: float = 0.1,
    override: bool = False,
):
    """Initialise or replace the MOA settings and agent in session state.

    For each of ``main_model``, ``cycles``, ``layer_agent_config`` and
    ``main_temp`` the session-state value is written when ``override`` is true
    or the key is missing. The agent itself is (re)built under the same rule.

    The agent is built *before* any session-state key is written, so if
    ``MOAgent.from_config`` raises, the previously stored configuration and
    agent are left untouched and stay consistent with each other.

    Args:
        main_model: Model name for the main (aggregating) agent.
        cycles: Number of layer cycles.
        layer_agent_config: Mapping of layer-agent name to its settings.
            ``None`` means a fresh deep copy of ``layer_agent_config_def``.
        main_model_temperature: Temperature passed to the main agent.
        override: When true, replace existing session values and rebuild
            the agent.
    """
    # A fresh copy per call avoids sharing one default dict across calls.
    if layer_agent_config is None:
        layer_agent_config = copy.deepcopy(layer_agent_config_def)

    state = st.session_state
    requested = {
        "main_model": main_model,
        "cycles": cycles,
        "layer_agent_config": layer_agent_config,
        "main_temp": main_model_temperature,
    }
    # Only these keys will be written; existing values win unless overriding.
    pending = {
        key: value
        for key, value in requested.items()
        if override or key not in state
    }

    if override or "moa_agent" not in state:
        def effective(key: str):
            return pending[key] if key in pending else state[key]

        # The agent gets its own copy so it cannot mutate the stored config.
        agent = MOAgent.from_config(
            main_model=effective("main_model"),
            cycles=effective("cycles"),
            layer_agent_config=copy.deepcopy(effective("layer_agent_config")),
            temperature=effective("main_temp"),
        )
        state["moa_agent"] = agent

    for key, value in pending.items():
        state[key] = value


st.set_page_config(
    page_title="Mixture-Of-Agents Powered by Groq",
    page_icon='/app/static/favicon.ico',
    menu_items={
        'About': "## Groq Mixture-Of-Agents \n Powered by [Groq](https://groq.com)"
    },
    layout="wide",
)

valid_model_names = [
    'llama-3.2-90b-text-preview',
    'llama3-70b-8192',
    'llama3-8b-8192',
    'llama-3.1-70b-versatile',
    'llama-3.1-8b-instant',
    'llama3-groq-70b-8192-tool-use-preview',
    'llama3-groq-8b-8192-tool-use-preview',
    'gemma-7b-it',
    'gemma2-9b-it',
    'mixtral-8x7b-32768',
]

st.write("---")

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []

set_moa_agent()

# Sidebar for configuration
with st.sidebar:
    st.title("MOA Configuration")
    with st.form("Agent Configuration", border=False):
        if st.form_submit_button("Use Recommended Config"):
            try:
                set_moa_agent(
                    main_model=rec_config['main_model'],
                    cycles=rec_config['cycles'],
                    layer_agent_config=layer_agent_config_rec,
                    override=True,
                )
                st.session_state.messages = []
                st.success("Configuration updated successfully!")
            except json.JSONDecodeError:
                st.error("Invalid JSON in Layer Agent Configuration. Please check your input.")
            except Exception as e:
                st.error(f"Error updating configuration: {str(e)}")

        # Main model selection
        new_main_model = st.selectbox(
            "Select Main Model",
            options=valid_model_names,
            index=valid_model_names.index(st.session_state.main_model),
        )

        # Cycles input
        new_cycles = st.number_input(
            "Number of Layers",
            min_value=1,
            max_value=10,
            value=st.session_state.cycles,
        )

        # Main Model Temperature: seeded from session state, like the other
        # inputs, so the field shows the active value instead of always 0.1.
        main_temperature = st.number_input(
            label="Main Model Temperature",
            value=float(st.session_state.main_temp),
            min_value=0.0,
            max_value=1.0,
            step=0.1,
        )

        # Layer agent configuration
        tooltip = "Agents in the layer agent configuration run in parallel _per cycle_. Each layer agent supports all initialization parameters of [Langchain's ChatGroq](https://api.python.langchain.com/en/latest/chat_models/langchain_groq.chat_models.ChatGroq.html) class as valid dictionary fields."
        st.markdown("Layer Agent Config", help=tooltip)
        new_layer_agent_config = st_ace(
            value=json.dumps(st.session_state.layer_agent_config, indent=2),
            language='json',
            placeholder="Layer Agent Configuration (JSON)",
            show_gutter=False,
            wrap=True,
            auto_update=True,
        )

        if st.form_submit_button("Update Configuration"):
            try:
                new_layer_config = json.loads(new_layer_agent_config)
                # Valid JSON is not enough: the config must be an object.
                if not isinstance(new_layer_config, dict):
                    raise ValueError(
                        "Layer Agent Configuration must be a JSON object."
                    )
                set_moa_agent(
                    main_model=new_main_model,
                    cycles=new_cycles,
                    layer_agent_config=new_layer_config,
                    main_model_temperature=main_temperature,
                    override=True,
                )
                st.session_state.messages = []
                st.success("Configuration updated successfully!")
            except json.JSONDecodeError:
                st.error("Invalid JSON in Layer Agent Configuration. Please check your input.")
            except Exception as e:
                st.error(f"Error updating configuration: {str(e)}")

    st.markdown("---")
    st.markdown("""
    ### Credits
    - MOA: [Together AI](https://www.together.ai/blog/together-moa)
    - LLMs: [Groq](https://groq.com/)
    - Paper: [arXiv:2406.04692](https://arxiv.org/abs/2406.04692)
    - GitHub repo: [skapadia3214/groq-moa](https://github.com/skapadia3214/groq-moa)
    """)

# Main app layout
st.header("Mixture of Agents", anchor=False)
st.write("A demo of the Mixture of Agents architecture proposed by [Together AI](https://www.together.ai/blog/together-moa), Powered by [Groq](https://groq.com/) LLMs.")
st.image("/app/static/moa_groq.svg", caption="Mixture of Agents Workflow", width=1000)

# Display current configuration
with st.expander("Current MOA Configuration", expanded=False):
    st.markdown(f"**Main Model**: ``{st.session_state.main_model}``")
    st.markdown(f"**Main Model Temperature**: ``{st.session_state.main_temp:.1f}``")
    st.markdown(f"**Layers**: ``{st.session_state.cycles}``")
    st.markdown("**Layer Agents Config**:")
    # Read-only view; the return value is not needed.
    st_ace(
        value=json.dumps(st.session_state.layer_agent_config, indent=2),
        language='json',
        placeholder="Layer Agent Configuration (JSON)",
        show_gutter=False,
        wrap=True,
        readonly=True,
        auto_update=True,
    )

# Chat interface
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if query := st.chat_input("Ask a question"):
    st.session_state.messages.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.write(query)

    moa_agent: MOAgent = st.session_state.moa_agent
    with st.chat_message("assistant"):
        ast_mess = stream_response(moa_agent.chat(query, output_format='json'))
        response = st.write_stream(ast_mess)

    st.session_state.messages.append({"role": "assistant", "content": response})