Wonderplex commited on
Commit
c3a4051
·
1 Parent(s): c423c55

Feature/select agent env (#45)

Browse files

* changed ui to include scenario and agent info

* ui layout correct; need to fix logics

* half-way through; need to fix record reading and agent pair filtering logic

* fixed deletion of app.py

* debugging gradio change

* before debug

* finished UI features

* added 5 times retry

* finished merging

Files changed (4) hide show
  1. app.py +48 -27
  2. requirements.txt +1 -0
  3. sotopia_pi_generate.py +3 -3
  4. utils.py +1 -1
app.py CHANGED
@@ -12,7 +12,7 @@ with open("openai_api.key", "r") as f:
12
  os.environ["OPENAI_API_KEY"] = f.read().strip()
13
 
14
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
15
- DEFAULT_MODEL_SELECTION = "cmu-lti/sotopia-pi-mistral-7b-BC_SR" # "mistralai/Mistral-7B-Instruct-v0.1"
16
  TEMPERATURE = 0.7
17
  TOP_P = 1
18
  MAX_TOKENS = 1024
@@ -100,6 +100,7 @@ def create_bot_agent_dropdown(environment_id, user_agent_id):
100
  environment, user_agent = environment_dict[environment_id], agent_dict[user_agent_id]
101
 
102
  bot_agent_list = []
 
103
  for neighbor_id in relationship_dict[environment.relationship][user_agent.agent_id]:
104
  bot_agent_list.append((agent_dict[neighbor_id].name, neighbor_id))
105
 
@@ -109,46 +110,62 @@ def create_environment_info(environment_dropdown):
109
  _, environment_dict, _, _ = get_sotopia_profiles()
110
  environment = environment_dict[environment_dropdown]
111
  text = environment.scenario
112
- return gr.Textbox(label="Scenario Information", lines=4, value=text)
113
 
114
- def create_user_info(environment_dropdown, user_agent_dropdown):
115
- _, environment_dict, agent_dict, _ = get_sotopia_profiles()
116
- environment, user_agent = environment_dict[environment_dropdown], agent_dict[user_agent_dropdown]
117
- text = f"{user_agent.background} {user_agent.personality} \n {environment.agent_goals[0]}"
118
  return gr.Textbox(label="User Agent Profile", lines=4, value=text)
119
 
120
- def create_bot_info(environment_dropdown, bot_agent_dropdown):
121
- _, environment_dict, agent_dict, _ = get_sotopia_profiles()
122
- environment, bot_agent = environment_dict[environment_dropdown], agent_dict[bot_agent_dropdown]
123
- text = f"{bot_agent.background} {bot_agent.personality} \n {environment.agent_goals[1]}"
 
124
  return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
125
 
 
 
 
 
 
 
 
 
 
 
126
  def sotopia_info_accordion(accordion_visible=True):
 
127
 
128
- with gr.Accordion("Sotopia Information", open=accordion_visible):
129
- with gr.Column():
130
- model_name_dropdown = gr.Dropdown(
131
- choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo"],
132
- value="cmu-lti/sotopia-pi-mistral-7b-BC_SR",
133
- interactive=True,
134
- label="Model Selection"
135
- )
136
  with gr.Row():
137
- environments, _, _, _ = get_sotopia_profiles()
138
  environment_dropdown = gr.Dropdown(
139
  choices=environments,
140
  label="Scenario Selection",
141
  value=environments[0][1] if environments else None,
142
  interactive=True,
143
  )
144
- print(environment_dropdown.value)
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  user_agent_dropdown = create_user_agent_dropdown(environment_dropdown.value)
146
  bot_agent_dropdown = create_bot_agent_dropdown(environment_dropdown.value, user_agent_dropdown.value)
147
 
148
  with gr.Row():
149
- scenario_info_display = create_environment_info(environment_dropdown.value)
150
- user_agent_info_display = create_user_info(environment_dropdown.value, user_agent_dropdown.value)
151
- bot_agent_info_display = create_bot_info(environment_dropdown.value, bot_agent_dropdown.value)
152
 
153
  # Update user dropdown when scenario changes
154
  environment_dropdown.change(fn=create_user_agent_dropdown, inputs=[environment_dropdown], outputs=[user_agent_dropdown])
@@ -157,9 +174,13 @@ def sotopia_info_accordion(accordion_visible=True):
157
  # Update scenario information when scenario changes
158
  environment_dropdown.change(fn=create_environment_info, inputs=[environment_dropdown], outputs=[scenario_info_display])
159
  # Update user agent profile when user changes
160
- user_agent_dropdown.change(fn=create_user_info, inputs=[environment_dropdown, user_agent_dropdown], outputs=[user_agent_info_display])
161
  # Update bot agent profile when bot changes
162
- bot_agent_dropdown.change(fn=create_bot_info, inputs=[environment_dropdown, bot_agent_dropdown], outputs=[bot_agent_info_display])
 
 
 
 
163
 
164
  return model_name_dropdown, environment_dropdown, user_agent_dropdown, bot_agent_dropdown
165
 
@@ -192,12 +213,12 @@ def chat_tab():
192
  user_agent = agent_dict[user_agent_dropdown]
193
  bot_agent = agent_dict[bot_agent_dropdown]
194
 
195
- import pdb; pdb.set_trace()
196
  context = get_context_prompt(bot_agent, user_agent, environment)
197
  dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
198
  prompt_history = f"{context}\n\n{dialogue_history}"
199
  agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
200
- import pdb; pdb.set_trace()
201
  return agent_action.to_natural_language()
202
 
203
  with gr.Column():
 
12
  os.environ["OPENAI_API_KEY"] = f.read().strip()
13
 
14
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
15
+ DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo" # "mistralai/Mistral-7B-Instruct-v0.1"
16
  TEMPERATURE = 0.7
17
  TOP_P = 1
18
  MAX_TOKENS = 1024
 
100
  environment, user_agent = environment_dict[environment_id], agent_dict[user_agent_id]
101
 
102
  bot_agent_list = []
103
+ # import pdb; pdb.set_trace()
104
  for neighbor_id in relationship_dict[environment.relationship][user_agent.agent_id]:
105
  bot_agent_list.append((agent_dict[neighbor_id].name, neighbor_id))
106
 
 
110
  _, environment_dict, _, _ = get_sotopia_profiles()
111
  environment = environment_dict[environment_dropdown]
112
  text = environment.scenario
113
+ return gr.Textbox(label="Scenario", lines=1, value=text)
114
 
115
+ def create_user_info(user_agent_dropdown):
116
+ _, _, agent_dict, _ = get_sotopia_profiles()
117
+ user_agent = agent_dict[user_agent_dropdown]
118
+ text = f"{user_agent.background} {user_agent.personality}"
119
  return gr.Textbox(label="User Agent Profile", lines=4, value=text)
120
 
121
+ def create_bot_info(bot_agent_dropdown):
122
+ _, _, agent_dict, _ = get_sotopia_profiles()
123
+ # import pdb; pdb.set_trace()
124
+ bot_agent = agent_dict[bot_agent_dropdown]
125
+ text = f"{bot_agent.background} {bot_agent.personality}"
126
  return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
127
 
128
+ def create_user_goal(environment_dropdown):
129
+ _, environment_dict, _, _ = get_sotopia_profiles()
130
+ text = environment_dict[environment_dropdown].agent_goals[0]
131
+ return gr.Textbox(label="User Agent Goal", lines=4, value=text)
132
+
133
+ def create_bot_goal(environment_dropdown):
134
+ _, environment_dict, _, _ = get_sotopia_profiles()
135
+ text = environment_dict[environment_dropdown].agent_goals[1]
136
+ return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
137
+
138
  def sotopia_info_accordion(accordion_visible=True):
139
+ environments, _, _, _ = get_sotopia_profiles()
140
 
141
+ with gr.Accordion("Environment Configuration", open=accordion_visible):
 
 
 
 
 
 
 
142
  with gr.Row():
 
143
  environment_dropdown = gr.Dropdown(
144
  choices=environments,
145
  label="Scenario Selection",
146
  value=environments[0][1] if environments else None,
147
  interactive=True,
148
  )
149
+ model_name_dropdown = gr.Dropdown(
150
+ choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo", "gpt-4-turbo"],
151
+ value=DEFAULT_MODEL_SELECTION,
152
+ interactive=True,
153
+ label="Model Selection"
154
+ )
155
+
156
+ scenario_info_display = create_environment_info(environment_dropdown.value)
157
+
158
+ with gr.Row():
159
+ bot_goal_display = create_bot_goal(environment_dropdown.value)
160
+ user_goal_display = create_user_goal(environment_dropdown.value)
161
+
162
+ with gr.Row():
163
  user_agent_dropdown = create_user_agent_dropdown(environment_dropdown.value)
164
  bot_agent_dropdown = create_bot_agent_dropdown(environment_dropdown.value, user_agent_dropdown.value)
165
 
166
  with gr.Row():
167
+ user_agent_info_display = create_user_info(user_agent_dropdown.value)
168
+ bot_agent_info_display = create_bot_info(bot_agent_dropdown.value)
 
169
 
170
  # Update user dropdown when scenario changes
171
  environment_dropdown.change(fn=create_user_agent_dropdown, inputs=[environment_dropdown], outputs=[user_agent_dropdown])
 
174
  # Update scenario information when scenario changes
175
  environment_dropdown.change(fn=create_environment_info, inputs=[environment_dropdown], outputs=[scenario_info_display])
176
  # Update user agent profile when user changes
177
+ user_agent_dropdown.change(fn=create_user_info, inputs=[user_agent_dropdown], outputs=[user_agent_info_display])
178
  # Update bot agent profile when bot changes
179
+ bot_agent_dropdown.change(fn=create_bot_info, inputs=[bot_agent_dropdown], outputs=[bot_agent_info_display])
180
+ # Update user goal when scenario changes
181
+ environment_dropdown.change(fn=create_user_goal, inputs=[environment_dropdown], outputs=[user_goal_display])
182
+ # Update bot goal when scenario changes
183
+ environment_dropdown.change(fn=create_bot_goal, inputs=[environment_dropdown], outputs=[bot_goal_display])
184
 
185
  return model_name_dropdown, environment_dropdown, user_agent_dropdown, bot_agent_dropdown
186
 
 
213
  user_agent = agent_dict[user_agent_dropdown]
214
  bot_agent = agent_dict[bot_agent_dropdown]
215
 
216
+ # import pdb; pdb.set_trace()
217
  context = get_context_prompt(bot_agent, user_agent, environment)
218
  dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
219
  prompt_history = f"{context}\n\n{dialogue_history}"
220
  agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
221
+ # import pdb; pdb.set_trace()
222
  return agent_action.to_natural_language()
223
 
224
  with gr.Column():
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  gradio
2
  transformers
3
  torch
 
1
+ sotopia
2
  gradio
3
  transformers
4
  torch
sotopia_pi_generate.py CHANGED
@@ -113,7 +113,7 @@ def obtain_chain_hf(
113
  model, tokenizer = prepare_model(model_name)
114
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
115
  hf = HuggingFacePipeline(pipeline=pipe)
116
- import pdb; pdb.set_trace()
117
  chain = LLMChain(llm=hf, prompt=chat_prompt_template)
118
  return chain
119
 
@@ -124,7 +124,7 @@ def generate(
124
  output_parser: BaseOutputParser[OutputType],
125
  temperature: float = 0.7,
126
  ) -> tuple[OutputType, str]:
127
- import pdb; pdb.set_trace()
128
  input_variables = re.findall(r"{(.*?)}", template)
129
  assert (
130
  set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
@@ -136,7 +136,7 @@ def generate(
136
  if "format_instructions" not in input_values:
137
  input_values["format_instructions"] = output_parser.get_format_instructions()
138
  result = chain.predict([], **input_values)
139
- import pdb; pdb.set_trace()
140
  try:
141
  parsed_result = output_parser.parse(result)
142
  except KeyboardInterrupt:
 
113
  model, tokenizer = prepare_model(model_name)
114
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
115
  hf = HuggingFacePipeline(pipeline=pipe)
116
+ # import pdb; pdb.set_trace()
117
  chain = LLMChain(llm=hf, prompt=chat_prompt_template)
118
  return chain
119
 
 
124
  output_parser: BaseOutputParser[OutputType],
125
  temperature: float = 0.7,
126
  ) -> tuple[OutputType, str]:
127
+ # import pdb; pdb.set_trace()
128
  input_variables = re.findall(r"{(.*?)}", template)
129
  assert (
130
  set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
 
136
  if "format_instructions" not in input_values:
137
  input_values["format_instructions"] = output_parser.get_format_instructions()
138
  result = chain.predict([], **input_values)
139
+ # import pdb; pdb.set_trace()
140
  try:
141
  parsed_result = output_parser.parse(result)
142
  except KeyboardInterrupt:
utils.py CHANGED
@@ -74,7 +74,7 @@ def truncate_dialogue_history_to_length(dia_his, surpass_num, tokenizer):
74
 
75
 
76
  def format_bot_message(bot_message) -> str:
77
- # import pdb; pdb.set_trace()
78
  start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
79
  if end_idx == -1:
80
  bot_message += "'}"
 
74
 
75
 
76
  def format_bot_message(bot_message) -> str:
77
+ # # import pdb; pdb.set_trace()
78
  start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
79
  if end_idx == -1:
80
  bot_message += "'}"