dromerosm commited on
Commit
dc9a654
·
verified ·
1 Parent(s): f7e5ad1

Update moa/agent/moa.py

Browse files
Files changed (1) hide show
  1. moa/agent/moa.py +5 -4
moa/agent/moa.py CHANGED
@@ -17,9 +17,10 @@ load_dotenv()
17
  valid_model_names = Literal[
18
  'llama3-70b-8192',
19
  'llama3-8b-8192',
20
- 'gemma-7b-it',
21
  'gemma2-9b-it',
22
- 'mixtral-8x7b-32768'
 
23
  ]
24
 
25
  class ResponseChunk(TypedDict):
@@ -97,7 +98,7 @@ class MOAgent:
97
  if not layer_agent_config:
98
  layer_agent_config = {
99
  'layer_agent_1' : {'system_prompt': SYSTEM_PROMPT, 'model_name': 'llama3-8b-8192'},
100
- 'layer_agent_2' : {'system_prompt': SYSTEM_PROMPT, 'model_name': 'gemma-7b-it'},
101
  'layer_agent_3' : {'system_prompt': SYSTEM_PROMPT, 'model_name': 'mixtral-8x7b-32768'}
102
  }
103
 
@@ -105,7 +106,7 @@ class MOAgent:
105
  for key, value in layer_agent_config.items():
106
  chain = MOAgent._create_agent_from_system_prompt(
107
  system_prompt=value.pop("system_prompt", SYSTEM_PROMPT),
108
- model_name=value.pop("model_name", 'llama3-8b-8192'),
109
  **value
110
  )
111
  parallel_chain_map[key] = RunnablePassthrough() | chain
 
17
  valid_model_names = Literal[
18
  'llama3-70b-8192',
19
  'llama3-8b-8192',
20
+ 'llama-3.2-3b-preview',
21
  'gemma2-9b-it',
22
+ 'mixtral-8x7b-32768',
23
+ 'llama-3.2-1b-preview'
24
  ]
25
 
26
  class ResponseChunk(TypedDict):
 
98
  if not layer_agent_config:
99
  layer_agent_config = {
100
  'layer_agent_1' : {'system_prompt': SYSTEM_PROMPT, 'model_name': 'llama3-8b-8192'},
101
+ 'layer_agent_2' : {'system_prompt': SYSTEM_PROMPT, 'model_name': 'llama-3.2-3b-preview'},
102
  'layer_agent_3' : {'system_prompt': SYSTEM_PROMPT, 'model_name': 'mixtral-8x7b-32768'}
103
  }
104
 
 
106
  for key, value in layer_agent_config.items():
107
  chain = MOAgent._create_agent_from_system_prompt(
108
  system_prompt=value.pop("system_prompt", SYSTEM_PROMPT),
109
+ model_name=value.pop("model_name", 'llama-3.2-3b-preview'),
110
  **value
111
  )
112
  parallel_chain_map[key] = RunnablePassthrough() | chain