Spaces:
Runtime error
Runtime error
Wonderplex
commited on
Commit
·
c3a4051
1
Parent(s):
c423c55
Feature/select agent env (#45)
Browse files* changed ui to include scenario and agent info
* ui layout correct; need to fix logics
* half-way through; need to fix record reading and agent pair filtering logic
* fixed deletion of app.py
* debugging gradio change
* before debug
* finished UI features
* added 5 times retry
* finished merging
- app.py +48 -27
- requirements.txt +1 -0
- sotopia_pi_generate.py +3 -3
- utils.py +1 -1
app.py
CHANGED
@@ -12,7 +12,7 @@ with open("openai_api.key", "r") as f:
|
|
12 |
os.environ["OPENAI_API_KEY"] = f.read().strip()
|
13 |
|
14 |
DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
|
15 |
-
DEFAULT_MODEL_SELECTION = "
|
16 |
TEMPERATURE = 0.7
|
17 |
TOP_P = 1
|
18 |
MAX_TOKENS = 1024
|
@@ -100,6 +100,7 @@ def create_bot_agent_dropdown(environment_id, user_agent_id):
|
|
100 |
environment, user_agent = environment_dict[environment_id], agent_dict[user_agent_id]
|
101 |
|
102 |
bot_agent_list = []
|
|
|
103 |
for neighbor_id in relationship_dict[environment.relationship][user_agent.agent_id]:
|
104 |
bot_agent_list.append((agent_dict[neighbor_id].name, neighbor_id))
|
105 |
|
@@ -109,46 +110,62 @@ def create_environment_info(environment_dropdown):
|
|
109 |
_, environment_dict, _, _ = get_sotopia_profiles()
|
110 |
environment = environment_dict[environment_dropdown]
|
111 |
text = environment.scenario
|
112 |
-
return gr.Textbox(label="Scenario
|
113 |
|
114 |
-
def create_user_info(
|
115 |
-
_,
|
116 |
-
|
117 |
-
text = f"{user_agent.background} {user_agent.personality}
|
118 |
return gr.Textbox(label="User Agent Profile", lines=4, value=text)
|
119 |
|
120 |
-
def create_bot_info(
|
121 |
-
_,
|
122 |
-
|
123 |
-
|
|
|
124 |
return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
def sotopia_info_accordion(accordion_visible=True):
|
|
|
127 |
|
128 |
-
with gr.Accordion("
|
129 |
-
with gr.Column():
|
130 |
-
model_name_dropdown = gr.Dropdown(
|
131 |
-
choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo"],
|
132 |
-
value="cmu-lti/sotopia-pi-mistral-7b-BC_SR",
|
133 |
-
interactive=True,
|
134 |
-
label="Model Selection"
|
135 |
-
)
|
136 |
with gr.Row():
|
137 |
-
environments, _, _, _ = get_sotopia_profiles()
|
138 |
environment_dropdown = gr.Dropdown(
|
139 |
choices=environments,
|
140 |
label="Scenario Selection",
|
141 |
value=environments[0][1] if environments else None,
|
142 |
interactive=True,
|
143 |
)
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
user_agent_dropdown = create_user_agent_dropdown(environment_dropdown.value)
|
146 |
bot_agent_dropdown = create_bot_agent_dropdown(environment_dropdown.value, user_agent_dropdown.value)
|
147 |
|
148 |
with gr.Row():
|
149 |
-
|
150 |
-
|
151 |
-
bot_agent_info_display = create_bot_info(environment_dropdown.value, bot_agent_dropdown.value)
|
152 |
|
153 |
# Update user dropdown when scenario changes
|
154 |
environment_dropdown.change(fn=create_user_agent_dropdown, inputs=[environment_dropdown], outputs=[user_agent_dropdown])
|
@@ -157,9 +174,13 @@ def sotopia_info_accordion(accordion_visible=True):
|
|
157 |
# Update scenario information when scenario changes
|
158 |
environment_dropdown.change(fn=create_environment_info, inputs=[environment_dropdown], outputs=[scenario_info_display])
|
159 |
# Update user agent profile when user changes
|
160 |
-
user_agent_dropdown.change(fn=create_user_info, inputs=[
|
161 |
# Update bot agent profile when bot changes
|
162 |
-
bot_agent_dropdown.change(fn=create_bot_info, inputs=[
|
|
|
|
|
|
|
|
|
163 |
|
164 |
return model_name_dropdown, environment_dropdown, user_agent_dropdown, bot_agent_dropdown
|
165 |
|
@@ -192,12 +213,12 @@ def chat_tab():
|
|
192 |
user_agent = agent_dict[user_agent_dropdown]
|
193 |
bot_agent = agent_dict[bot_agent_dropdown]
|
194 |
|
195 |
-
import pdb; pdb.set_trace()
|
196 |
context = get_context_prompt(bot_agent, user_agent, environment)
|
197 |
dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
|
198 |
prompt_history = f"{context}\n\n{dialogue_history}"
|
199 |
agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
|
200 |
-
import pdb; pdb.set_trace()
|
201 |
return agent_action.to_natural_language()
|
202 |
|
203 |
with gr.Column():
|
|
|
12 |
os.environ["OPENAI_API_KEY"] = f.read().strip()
|
13 |
|
14 |
DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
|
15 |
+
DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo" # "mistralai/Mistral-7B-Instruct-v0.1"
|
16 |
TEMPERATURE = 0.7
|
17 |
TOP_P = 1
|
18 |
MAX_TOKENS = 1024
|
|
|
100 |
environment, user_agent = environment_dict[environment_id], agent_dict[user_agent_id]
|
101 |
|
102 |
bot_agent_list = []
|
103 |
+
# import pdb; pdb.set_trace()
|
104 |
for neighbor_id in relationship_dict[environment.relationship][user_agent.agent_id]:
|
105 |
bot_agent_list.append((agent_dict[neighbor_id].name, neighbor_id))
|
106 |
|
|
|
110 |
_, environment_dict, _, _ = get_sotopia_profiles()
|
111 |
environment = environment_dict[environment_dropdown]
|
112 |
text = environment.scenario
|
113 |
+
return gr.Textbox(label="Scenario", lines=1, value=text)
|
114 |
|
115 |
+
def create_user_info(user_agent_dropdown):
|
116 |
+
_, _, agent_dict, _ = get_sotopia_profiles()
|
117 |
+
user_agent = agent_dict[user_agent_dropdown]
|
118 |
+
text = f"{user_agent.background} {user_agent.personality}"
|
119 |
return gr.Textbox(label="User Agent Profile", lines=4, value=text)
|
120 |
|
121 |
+
def create_bot_info(bot_agent_dropdown):
|
122 |
+
_, _, agent_dict, _ = get_sotopia_profiles()
|
123 |
+
# import pdb; pdb.set_trace()
|
124 |
+
bot_agent = agent_dict[bot_agent_dropdown]
|
125 |
+
text = f"{bot_agent.background} {bot_agent.personality}"
|
126 |
return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
|
127 |
|
128 |
+
def create_user_goal(environment_dropdown):
|
129 |
+
_, environment_dict, _, _ = get_sotopia_profiles()
|
130 |
+
text = environment_dict[environment_dropdown].agent_goals[0]
|
131 |
+
return gr.Textbox(label="User Agent Goal", lines=4, value=text)
|
132 |
+
|
133 |
+
def create_bot_goal(environment_dropdown):
|
134 |
+
_, environment_dict, _, _ = get_sotopia_profiles()
|
135 |
+
text = environment_dict[environment_dropdown].agent_goals[1]
|
136 |
+
return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
|
137 |
+
|
138 |
def sotopia_info_accordion(accordion_visible=True):
|
139 |
+
environments, _, _, _ = get_sotopia_profiles()
|
140 |
|
141 |
+
with gr.Accordion("Environment Configuration", open=accordion_visible):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
with gr.Row():
|
|
|
143 |
environment_dropdown = gr.Dropdown(
|
144 |
choices=environments,
|
145 |
label="Scenario Selection",
|
146 |
value=environments[0][1] if environments else None,
|
147 |
interactive=True,
|
148 |
)
|
149 |
+
model_name_dropdown = gr.Dropdown(
|
150 |
+
choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo", "gpt-4-turbo"],
|
151 |
+
value=DEFAULT_MODEL_SELECTION,
|
152 |
+
interactive=True,
|
153 |
+
label="Model Selection"
|
154 |
+
)
|
155 |
+
|
156 |
+
scenario_info_display = create_environment_info(environment_dropdown.value)
|
157 |
+
|
158 |
+
with gr.Row():
|
159 |
+
bot_goal_display = create_bot_goal(environment_dropdown.value)
|
160 |
+
user_goal_display = create_user_goal(environment_dropdown.value)
|
161 |
+
|
162 |
+
with gr.Row():
|
163 |
user_agent_dropdown = create_user_agent_dropdown(environment_dropdown.value)
|
164 |
bot_agent_dropdown = create_bot_agent_dropdown(environment_dropdown.value, user_agent_dropdown.value)
|
165 |
|
166 |
with gr.Row():
|
167 |
+
user_agent_info_display = create_user_info(user_agent_dropdown.value)
|
168 |
+
bot_agent_info_display = create_bot_info(bot_agent_dropdown.value)
|
|
|
169 |
|
170 |
# Update user dropdown when scenario changes
|
171 |
environment_dropdown.change(fn=create_user_agent_dropdown, inputs=[environment_dropdown], outputs=[user_agent_dropdown])
|
|
|
174 |
# Update scenario information when scenario changes
|
175 |
environment_dropdown.change(fn=create_environment_info, inputs=[environment_dropdown], outputs=[scenario_info_display])
|
176 |
# Update user agent profile when user changes
|
177 |
+
user_agent_dropdown.change(fn=create_user_info, inputs=[user_agent_dropdown], outputs=[user_agent_info_display])
|
178 |
# Update bot agent profile when bot changes
|
179 |
+
bot_agent_dropdown.change(fn=create_bot_info, inputs=[bot_agent_dropdown], outputs=[bot_agent_info_display])
|
180 |
+
# Update user goal when scenario changes
|
181 |
+
environment_dropdown.change(fn=create_user_goal, inputs=[environment_dropdown], outputs=[user_goal_display])
|
182 |
+
# Update bot goal when scenario changes
|
183 |
+
environment_dropdown.change(fn=create_bot_goal, inputs=[environment_dropdown], outputs=[bot_goal_display])
|
184 |
|
185 |
return model_name_dropdown, environment_dropdown, user_agent_dropdown, bot_agent_dropdown
|
186 |
|
|
|
213 |
user_agent = agent_dict[user_agent_dropdown]
|
214 |
bot_agent = agent_dict[bot_agent_dropdown]
|
215 |
|
216 |
+
# import pdb; pdb.set_trace()
|
217 |
context = get_context_prompt(bot_agent, user_agent, environment)
|
218 |
dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
|
219 |
prompt_history = f"{context}\n\n{dialogue_history}"
|
220 |
agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
|
221 |
+
# import pdb; pdb.set_trace()
|
222 |
return agent_action.to_natural_language()
|
223 |
|
224 |
with gr.Column():
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
gradio
|
2 |
transformers
|
3 |
torch
|
|
|
1 |
+
sotopia
|
2 |
gradio
|
3 |
transformers
|
4 |
torch
|
sotopia_pi_generate.py
CHANGED
@@ -113,7 +113,7 @@ def obtain_chain_hf(
|
|
113 |
model, tokenizer = prepare_model(model_name)
|
114 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
|
115 |
hf = HuggingFacePipeline(pipeline=pipe)
|
116 |
-
import pdb; pdb.set_trace()
|
117 |
chain = LLMChain(llm=hf, prompt=chat_prompt_template)
|
118 |
return chain
|
119 |
|
@@ -124,7 +124,7 @@ def generate(
|
|
124 |
output_parser: BaseOutputParser[OutputType],
|
125 |
temperature: float = 0.7,
|
126 |
) -> tuple[OutputType, str]:
|
127 |
-
import pdb; pdb.set_trace()
|
128 |
input_variables = re.findall(r"{(.*?)}", template)
|
129 |
assert (
|
130 |
set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
|
@@ -136,7 +136,7 @@ def generate(
|
|
136 |
if "format_instructions" not in input_values:
|
137 |
input_values["format_instructions"] = output_parser.get_format_instructions()
|
138 |
result = chain.predict([], **input_values)
|
139 |
-
import pdb; pdb.set_trace()
|
140 |
try:
|
141 |
parsed_result = output_parser.parse(result)
|
142 |
except KeyboardInterrupt:
|
|
|
113 |
model, tokenizer = prepare_model(model_name)
|
114 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
|
115 |
hf = HuggingFacePipeline(pipeline=pipe)
|
116 |
+
# import pdb; pdb.set_trace()
|
117 |
chain = LLMChain(llm=hf, prompt=chat_prompt_template)
|
118 |
return chain
|
119 |
|
|
|
124 |
output_parser: BaseOutputParser[OutputType],
|
125 |
temperature: float = 0.7,
|
126 |
) -> tuple[OutputType, str]:
|
127 |
+
# import pdb; pdb.set_trace()
|
128 |
input_variables = re.findall(r"{(.*?)}", template)
|
129 |
assert (
|
130 |
set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
|
|
|
136 |
if "format_instructions" not in input_values:
|
137 |
input_values["format_instructions"] = output_parser.get_format_instructions()
|
138 |
result = chain.predict([], **input_values)
|
139 |
+
# import pdb; pdb.set_trace()
|
140 |
try:
|
141 |
parsed_result = output_parser.parse(result)
|
142 |
except KeyboardInterrupt:
|
utils.py
CHANGED
@@ -74,7 +74,7 @@ def truncate_dialogue_history_to_length(dia_his, surpass_num, tokenizer):
|
|
74 |
|
75 |
|
76 |
def format_bot_message(bot_message) -> str:
|
77 |
-
# import pdb; pdb.set_trace()
|
78 |
start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
|
79 |
if end_idx == -1:
|
80 |
bot_message += "'}"
|
|
|
74 |
|
75 |
|
76 |
def format_bot_message(bot_message) -> str:
|
77 |
+
# # import pdb; pdb.set_trace()
|
78 |
start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
|
79 |
if end_idx == -1:
|
80 |
bot_message += "'}"
|