Wonderplex committed · Commit c423c55 · 1 Parent(s): 0c22348
adopated sotopia logics (#44)
Files changed:
- .gitignore +3 -1
- app.py +22 -76
- sotopia_pi_generate.py +248 -0
- utils.py +20 -6
.gitignore
CHANGED
@@ -1,2 +1,4 @@
 __pycache__/
-.cache/
+.cache/
+openai_api.key
+core
app.py
CHANGED
@@ -1,25 +1,19 @@
 import os
 from collections import defaultdict
-from dataclasses import dataclass
-from uuid import uuid4
 import json
 
 import gradio as gr
-import torch
-import transformers
-from peft import PeftConfig, PeftModel, get_peft_model
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-)
 
-from utils import Environment, Agent,
+from utils import Environment, Agent, get_context_prompt, dialogue_history_prompt
 from functools import cache
+from sotopia_pi_generate import prepare_model, generate_action
+
+with open("openai_api.key", "r") as f:
+    os.environ["OPENAI_API_KEY"] = f.read().strip()
 
 DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
 DEFAULT_MODEL_SELECTION = "cmu-lti/sotopia-pi-mistral-7b-BC_SR" # "mistralai/Mistral-7B-Instruct-v0.1"
-TEMPERATURE = 0.
+TEMPERATURE = 0.7
 TOP_P = 1
 MAX_TOKENS = 1024
 
@@ -27,6 +21,7 @@ ENVIRONMENT_PROFILES = "profiles/environment_profiles.jsonl"
 AGENT_PROFILES = "profiles/agent_profiles.jsonl"
 RELATIONSHIP_PROFILES = "profiles/relationship_profiles.jsonl"
 
+ACTION_TYPES = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
 
 @cache
 def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
@@ -68,35 +63,6 @@ def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILE
 
     return environments, environment_dict, agent_dict, relationship_dict
 
-@cache
-def prepare_model(model_name):
-    compute_type = torch.float16
-
-    if 'cmu-lti/sotopia-pi-mistral-7b-BC_SR'in model_name:
-        model = AutoModelForCausalLM.from_pretrained(
-            "mistralai/Mistral-7B-Instruct-v0.1",
-            cache_dir="./.cache",
-            device_map='cuda',
-            quantization_config=BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=compute_type,
-            )
-        )
-        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
-        model = PeftModel.from_pretrained(model, model_name).to("cuda")
-    elif 'mistralai/Mistral-7B-Instruct-v0.1' in model_name:
-        model = AutoModelForCausalLM.from_pretrained(
-            "mistralai/Mistral-7B-Instruct-v0.1",
-            cache_dir="./.cache",
-            device_map='cuda',
-        )
-        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
-    else:
-        raise RuntimeError(f"Model {model_name} not supported")
-    return model, tokenizer
-
 
 def introduction():
     with gr.Column(scale=2):
@@ -162,7 +128,7 @@ def sotopia_info_accordion(accordion_visible=True):
     with gr.Accordion("Sotopia Information", open=accordion_visible):
         with gr.Column():
             model_name_dropdown = gr.Dropdown(
-                choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "
+                choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo"],
                 value="cmu-lti/sotopia-pi-mistral-7b-BC_SR",
                 interactive=True,
                 label="Model Selection"
@@ -213,50 +179,30 @@ def instructions_accordion(instructions, according_visible=False):
 
 def chat_tab():
     # history are input output pairs
+    _, environment_dict, agent_dict, _ = get_sotopia_profiles()
     def run_chat(
         message,
         history,
-
+        environment_selection,
         user_agent_dropdown,
         bot_agent_dropdown,
         model_selection:str
     ):
-
-
-
-
-        )
-
-
-
-
-
-
-            temperature=TEMPERATURE,
-            top_p=TOP_P,
-            max_length=MAX_TOKENS,
-            pad_token_id=tokenizer.eos_token_id,
-            num_return_sequences=1,
-        )
-        output_tokens = output_tokens[:, input_length:]
-        text_output = tokenizer.decode(
-            output_tokens[0], skip_special_tokens=True
-        )
-        output = ""
-        for _ in range(5):
-            try:
-                output = format_bot_message(text_output)
-                break
-            except Exception as e:
-                print(e)
-                print("Retrying...")
-        return output
+        environment = environment_dict[environment_selection]
+        user_agent = agent_dict[user_agent_dropdown]
+        bot_agent = agent_dict[bot_agent_dropdown]
+
+        import pdb; pdb.set_trace()
+        context = get_context_prompt(bot_agent, user_agent, environment)
+        dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
+        prompt_history = f"{context}\n\n{dialogue_history}"
+        agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
+        import pdb; pdb.set_trace()
+        return agent_action.to_natural_language()
 
-    _, environment_dict, agent_dict, _ = get_sotopia_profiles()
     with gr.Column():
         with gr.Row():
             model_name_dropdown, scenario_dropdown, user_agent_dropdown, bot_agent_dropdown = sotopia_info_accordion()
-            starter_prompt = gr.Textbox(value=get_starter_prompt(agent_dict[user_agent_dropdown.value], agent_dict[bot_agent_dropdown.value], environment_dict[scenario_dropdown.value]), label="Modify the prompt as needed:", visible=False)
 
     with gr.Column():
         with gr.Blocks():
@@ -279,7 +225,7 @@ def chat_tab():
                 rtl=False,
             ),
            additional_inputs=[
-
+                scenario_dropdown,
                user_agent_dropdown,
                bot_agent_dropdown,
                model_name_dropdown,
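Note on the run_chat wiring above: gr.ChatInterface calls its fn with (message, history) followed by the values of additional_inputs, in the order they are listed, which is why scenario_dropdown is added both to additional_inputs and as the new environment_selection parameter. A minimal, self-contained sketch of that contract follows; the component names and return string are illustrative placeholders, not code from the commit.

    import gradio as gr

    def run_chat(message, history, environment_selection, user_agent, bot_agent, model_selection):
        # Gradio appends the additional_inputs values, in order, after (message, history).
        return f"[{model_selection} | {environment_selection}] {bot_agent} -> {user_agent}: {message}"

    scenario = gr.Dropdown(choices=["Scenario 1"], value="Scenario 1", label="Scenario")
    user_agent = gr.Dropdown(choices=["Alice"], value="Alice", label="User agent")
    bot_agent = gr.Dropdown(choices=["Bob"], value="Bob", label="Bot agent")
    model = gr.Dropdown(
        choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR"],
        value="cmu-lti/sotopia-pi-mistral-7b-BC_SR",
        label="Model Selection",
    )

    demo = gr.ChatInterface(fn=run_chat, additional_inputs=[scenario, user_agent, bot_agent, model])
    demo.launch()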
sotopia_pi_generate.py
ADDED
@@ -0,0 +1,248 @@
+import re
+
+import torch
+from peft import PeftModel
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+)
+
+from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain_community.chat_models import ChatLiteLLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+from langchain.chains import LLMChain
+from langchain.output_parsers import PydanticOutputParser
+from langchain.prompts import (
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    PromptTemplate,
+)
+from langchain.schema import BaseOutputParser, OutputParserException
+from typing import TypeVar
+
+from sotopia.messages import ActionType, AgentAction
+from sotopia.utils import format_docstring
+from functools import cache
+import logging
+
+OutputType = TypeVar("OutputType", bound=object)
+
+log = logging.getLogger("generate")
+# logging_handler = LoggingCallbackHandler("langchain")
+
+def generate_action(
+    model_name: str,
+    history: str,
+    turn_number: int,
+    action_types: list[ActionType],
+    agent: str,
+    temperature: float = 0.7,
+) -> tuple[AgentAction, str]:
+    """
+    Using langchain to generate an example episode
+    """
+    try:
+        # Normal case, model as agent
+        template = """
+            Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
+            You can find {agent}'s goal (or background) in the 'Here is the context of the interaction' field.
+            Note that {agent}'s goal is only visible to you.
+            You should try your best to achieve {agent}'s goal in a way that align with their character traits.
+            Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).
+            {history}.
+            You are at Turn #{turn_number}. Your available action types are
+            {action_list}.
+            Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
+
+            Please only generate a JSON string including the action type and the argument.
+            Your action should follow the given format:
+            {format_instructions}
+        """
+        return generate(
+            model_name=model_name,
+            template=template,
+            input_values=dict(
+                agent=agent,
+                turn_number=str(turn_number),
+                history=history,
+                action_list=" ".join(action_types),
+            ),
+            output_parser=PydanticOutputParser(pydantic_object=AgentAction),
+            temperature=temperature,
+        )
+    except Exception:
+        return AgentAction(action_type="none", argument=""), ""
+
+@cache
+def prepare_model(model_name):
+    compute_type = torch.float16
+
+    if 'cmu-lti/sotopia-pi-mistral-7b-BC_SR'in model_name:
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token="REDACTED")
+        model = AutoModelForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            cache_dir="./.cache",
+            device_map='cuda',
+            quantization_config=BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=compute_type,
+            ),
+            token="REDACTED"
+        )
+        model = PeftModel.from_pretrained(model, model_name).to("cuda")
+    else:
+        raise RuntimeError(f"Model {model_name} not supported")
+    return model, tokenizer
+
+def obtain_chain_hf(
+    model_name: str,
+    template: str,
+    input_variables: list[str],
+    temperature: float = 0.7,
+    max_retries: int = 6,
+    max_tokens: int = 2700
+) -> LLMChain:
+    human_message_prompt = HumanMessagePromptTemplate(
+        prompt=PromptTemplate(template=template, input_variables=input_variables)
+    )
+    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
+    model, tokenizer = prepare_model(model_name)
+    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
+    hf = HuggingFacePipeline(pipeline=pipe)
+    import pdb; pdb.set_trace()
+    chain = LLMChain(llm=hf, prompt=chat_prompt_template)
+    return chain
+
+def generate(
+    model_name: str,
+    template: str,
+    input_values: dict[str, str],
+    output_parser: BaseOutputParser[OutputType],
+    temperature: float = 0.7,
+) -> tuple[OutputType, str]:
+    import pdb; pdb.set_trace()
+    input_variables = re.findall(r"{(.*?)}", template)
+    assert (
+        set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
+        or set(input_variables) == set(list(input_values.keys()))
+    ), f"The variables in the template must match input_values except for format_instructions. Got {sorted(input_values.keys())}, expect {sorted(input_variables)}"
+    # process template
+    template = format_docstring(template)
+    chain = obtain_chain(model_name, template, input_variables, temperature)
+    if "format_instructions" not in input_values:
+        input_values["format_instructions"] = output_parser.get_format_instructions()
+    result = chain.predict([], **input_values)
+    import pdb; pdb.set_trace()
+    try:
+        parsed_result = output_parser.parse(result)
+    except KeyboardInterrupt:
+        raise KeyboardInterrupt
+    except Exception as e:
+        log.debug(
+            f"[red] Failed to parse result: {result}\nEncounter Exception {e}\nstart to reparse",
+            extra={"markup": True},
+        )
+        reformat_parsed_result = format_bad_output(
+            result, format_instructions=output_parser.get_format_instructions()
+        )
+        parsed_result = output_parser.parse(reformat_parsed_result)
+    log.info(f"Generated result: {parsed_result}")
+    return parsed_result
+
+def format_bad_output(
+    ill_formed_output: str,
+    format_instructions: str,
+    model_name: str = "gpt-3.5-turbo",
+) -> str:
+    template = """
+    Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
+    Original string: {ill_formed_output}
+
+    Format instructions: {format_instructions}
+
+    Please only generate the JSON:
+    """
+    chain = obtain_chain(
+        model_name=model_name,
+        template=template,
+        input_variables=re.findall(r"{(.*?)}", template),
+    )
+    input_values = {
+        "ill_formed_output": ill_formed_output,
+        "format_instructions": format_instructions,
+    }
+    reformat = chain.predict([], **input_values)
+    log.info(f"Reformated output: {reformat}")
+    return reformat
+
+def obtain_chain(
+    model_name: str,
+    template: str,
+    input_variables: list[str],
+    temperature: float = 0.7,
+    max_retries: int = 6,
+) -> LLMChain:
+    """
+    Using langchain to sample profiles for participants
+    """
+    if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR"]:
+        return obtain_chain_hf(
+            model_name=model_name,
+            template=template,
+            input_variables=input_variables,
+            temperature=temperature,
+            max_retries=max_retries,
+        )
+
+    model_name = _return_fixed_model_version(model_name)
+    chat = ChatLiteLLM(
+        model=model_name,
+        temperature=temperature,
+        max_tokens=2700, # tweak as needed
+        max_retries=max_retries,
+    )
+    human_message_prompt = HumanMessagePromptTemplate(
+        prompt=PromptTemplate(template=template, input_variables=input_variables)
+    )
+    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
+    chain = LLMChain(llm=chat, prompt=chat_prompt_template)
+    return chain
+
+def format_bad_output(
+    ill_formed_output: str,
+    format_instructions: str,
+    model_name: str = "gpt-3.5-turbo",
+) -> str:
+    template = """
+    Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
+    Original string: {ill_formed_output}
+
+    Format instructions: {format_instructions}
+
+    Please only generate the JSON:
+    """
+    chain = obtain_chain(
+        model_name=model_name,
+        template=template,
+        input_variables=re.findall(r"{(.*?)}", template),
+    )
+    input_values = {
+        "ill_formed_output": ill_formed_output,
+        "format_instructions": format_instructions,
+    }
+    reformat = chain.predict([], **input_values)
+    log.info(f"Reformated output: {reformat}")
+    return reformat
+
+def _return_fixed_model_version(model_name: str) -> str:
+    return {
+        "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
+        "gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
+        "gpt-4": "gpt-4-0613",
+        "gpt-4-turbo": "gpt-4-1106-preview",
+    }[model_name]
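Note on the generation path in this new file: generate_action fills the prompt template, obtains a chain (a local 4-bit Mistral + PEFT pipeline for the sotopia-pi checkpoint, ChatLiteLLM otherwise), and relies on PydanticOutputParser to turn the model's raw JSON reply into a sotopia AgentAction, whose to_natural_language() is what app.py displays. If parsing fails, format_bad_output asks gpt-3.5-turbo to rewrite the string and the parser is tried again. A minimal sketch of just the parsing step, assuming sotopia and langchain are installed; the raw string is an illustrative example, not output captured from the commit.

    from langchain.output_parsers import PydanticOutputParser
    from sotopia.messages import AgentAction

    parser = PydanticOutputParser(pydantic_object=AgentAction)

    # This text is what gets substituted into the {format_instructions} slot of the template.
    print(parser.get_format_instructions())

    # A well-formed reply is a JSON object with the two AgentAction fields.
    raw = '{"action_type": "speak", "argument": "Hi, do you have a minute to talk about the project?"}'
    action = parser.parse(raw)
    print(action.action_type, action.argument)
    print(action.to_natural_language())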
utils.py
CHANGED
@@ -44,7 +44,10 @@ def get_format_guide():
     """
 
 def get_starter_prompt(machine_agent, human_agent, environment):
-    return f"
+    return f"Imagine you are {machine_agent.name}, your task is to act/speak as {machine_agent.name} would, keeping in mind {machine_agent.name}'s social goal.\nYou can find {machine_agent.name}'s background and goal in the 'Here is the context of the interaction' field.\nNote that {machine_agent.name}'s secret and goal is only visible to you.\nYou should try your best to achieve {machine_agent.name}'s goal in a way that align with their character traits.\nAdditionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).\n\nHere is the context of this interaction:\n Scenario: {environment.scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secret}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {environment.agent_goals[1]}\nConversation Starts:"
+
+def get_context_prompt(machine_agent, human_agent, environment):
+    return f"Here is the context of this interaction:\n Scenario: {environment.scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secret}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {environment.agent_goals[1]}\nConversation Starts:"
 
 
 # we define history as
@@ -102,6 +105,20 @@ def dialogue_history_creation(history, user_name, bot_name):
     last_turn_idx = len(history) * 2
     return dialogue_history, last_turn_idx
 
+def dialogue_history_prompt(message, history, user_agent, bot_agent):
+    dialogue_history = ""
+    for idx, turn in enumerate(history):
+        user_message, bot_message = turn
+        # TODOTODO (haofeiyu): we first assume that human talks first
+        user_turn_idx = idx * 2
+        bot_turn_idx = idx * 2 + 1
+        if not bot_message.startswith("["): # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
+            bot_message = "said :" + bot_message
+        dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}: {user_agent.name}: {user_message}\n\nTurn #{bot_turn_idx}: {bot_agent.name}: {bot_message}"
+    last_turn_idx = len(history) * 2
+    dialogue_history = f"{dialogue_history}\n\nTurn #{last_turn_idx+1}: {user_agent.name}: {message}\n."
+    return dialogue_history, last_turn_idx+2
+
 
 def dialogue_history_truncation(dialogue_history, max_token_num, tokenizer):
     surpass_num = dialogue_history_length_check(
@@ -114,15 +131,12 @@ def dialogue_history_truncation(dialogue_history, max_token_num, tokenizer):
     return dialogue_history
 
 
-def
+def format_hostory_prompt(
     message: str,
     history: List[Tuple[str, str]],
     instructions: str,
     user_name: str,
     bot_name: str,
-    include_all_chat_history: bool = True,
-    index: int = 1,
-    use_format_guide: bool = True,
 ) -> str:
     prompt = instructions.strip()
     dialogue_history, last_turn_idx = dialogue_history_creation(
@@ -130,4 +144,4 @@ def format_sotopia_prompt(
     )
     prompt = f"{prompt}\n{dialogue_history}"
     prompt = f"{prompt}\n\nTurn #{last_turn_idx+1}: {user_name}: {message}\n.\nYou are at Turn #{last_turn_idx+2}."
-    return prompt
+    return prompt
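Note on dialogue_history_prompt above: completed (user, bot) pairs are numbered Turn #0, #1, #2, #3, and so on, the incoming user message is appended as Turn #{last_turn_idx+1}, and the returned index (last_turn_idx + 2) is the turn the bot is asked to produce next. A small illustration under those assumptions; the SimpleNamespace agents are hypothetical stand-ins, since only .name is read here.

    from types import SimpleNamespace
    from utils import dialogue_history_prompt

    # Hypothetical agent stand-ins; the real app passes Agent profile objects.
    alice = SimpleNamespace(name="Alice")
    bob = SimpleNamespace(name="Bob")

    history = [
        ("Hi Bob!", "Hi Alice, good to see you."),
        ("How is the project going?", "[action] Bob pulls up the latest draft."),
    ]

    dialogue_history, next_turn_idx = dialogue_history_prompt("Want to grab lunch?", history, alice, bob)
    print(next_turn_idx)    # 6: the two finished pairs occupy Turns #0-#3 and the new message sits at Turn #5
    print(dialogue_history) # bot messages without a leading "[" get "said :" prepended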