Spaces:
Sleeping
Sleeping
File size: 13,412 Bytes
2649124 54d66e1 c634ddd c61089d 700f11b 2649124 700f11b c634ddd 700f11b c634ddd 700f11b 2649124 700f11b 2649124 c634ddd 700f11b 03f43d0 700f11b 2649124 700f11b 2649124 c634ddd 700f11b c634ddd 2649124 700f11b 2649124 700f11b 2649124 c634ddd 700f11b f371bf3 c634ddd 700f11b c634ddd 2649124 c61089d 42b2eba c61089d 2649124 c61089d 8e58f6e c61089d 2649124 54d66e1 2649124 54d66e1 15af633 2649124 700f11b c634ddd 700f11b c634ddd 54d66e1 2649124 3a8cc44 0342ce4 2649124 700f11b c634ddd 700f11b c634ddd 2649124 700f11b c634ddd 700f11b 2649124 c634ddd e3ac915 c634ddd e3ac915 56e14a4 e3ac915 c634ddd 29ff717 15af633 c634ddd 700f11b 15af633 3f815a2 2649124 c634ddd 2649124 c885d38 e1f0dec c634ddd 2649124 700f11b 2649124 7164d93 2649124 c634ddd 6f2d9ff c634ddd 6f2d9ff c634ddd 6f2d9ff c634ddd 6f2d9ff c634ddd cb638a7 c634ddd 6f2d9ff c634ddd 7da64e0 c634ddd 6f2d9ff c634ddd 6f2d9ff c634ddd 7da64e0 c634ddd 7da64e0 c634ddd 7da64e0 c634ddd 7da64e0 c634ddd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 |
from openai import OpenAI
import json_repair
from transformers import AutoTokenizer
from prompts import *
import re
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type
from openai import RateLimitError
from difflib import get_close_matches
class ChatbotSimulation:
    """Text-based simulation of an app, driven by an OpenAI chat model.

    Each round the model is asked twice: once to update the user state
    (current page, task completion, back navigation) from the user's input,
    and once to render the instructions/options for the resulting page.

    Attributes set in ``__init__``:
        user_state: dict holding at least ``current_page``, ``task_completed``
            and ``back``; replaced by the model's state update every round.
        page_history: stack of visited page ids; the last item is the page
            currently shown.
        conversation: rolling message list sent to the model (trimmed to the
            token budget).
        trajectory: full log of user inputs and assistant replies (never
            trimmed).
        actions: textual record of every valid user action.
        prompt_count: number of user inputs processed so far.
    """

    # Upper bound on follow-up requests asking the model to reformat a state
    # update that did not parse into a dict.
    _MAX_STATE_REPAIR_ATTEMPTS = 5

    def __init__(self, app_name, app_description, site_map, relevant_tables_per_page,
                 database, jinjia_prerender_page, task, solution,
                 log_location, openai_api_key, agent='human',
                 max_steps=30, max_tokens=8192, buffer_tokens=500):
        """Store the app definition and initialise the simulation state.

        Raises:
            ValueError: if ``agent`` (case-insensitive) is neither 'human'
                nor 'llm'.
        """
        self.app_name = app_name
        self.app_description = app_description
        self.sitemap = site_map
        self.relevant_tables_per_page = relevant_tables_per_page
        self.database = database
        self.jinjia_prerender_page = jinjia_prerender_page
        self.task = task
        self.solution = solution

        # Flags are kept as the strings 'True'/'False'; readers compare them
        # via str(...).lower().
        self.user_state = {
            'current_page': self.sitemap['pages'][0]['id'],
            'task_completed': 'False',
            'logged_in': 'False',
            'back': 'False',
        }

        self.log_location = log_location
        self.agent = agent.lower()
        if self.agent not in ['human', 'llm']:
            raise ValueError("Invalid agent type. Expected 'Human' or 'llm'.")

        self.max_steps = max_steps
        self.max_tokens = max_tokens
        self.buffer_tokens = buffer_tokens

        self.conversation = []  # Recent messages sent to the model
        self.trajectory = [{"role": "system", "content": f"Welcome to {app_name} simulator! Your task is: {task}"}]
        self.prompt_count = 0
        self.client = OpenAI(api_key=openai_api_key)
        self.actions = []
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)

        # Back-button support: start the history on the same page as
        # user_state['current_page'] instead of a hard-coded 'Home'.
        self.page_history = [self.user_state['current_page']]

    def _start_page(self):
        """Return the id of the first page listed in the site map."""
        return self.sitemap['pages'][0]['id']

    def _current_page(self):
        """Return the page on top of the history, or the start page if the history is empty."""
        return self.page_history[-1] if self.page_history else self._start_page()

    def _get_relevant_data(self, current_page):
        """Return the database tables associated with ``current_page``.

        Looks the page up in ``relevant_tables_per_page``; when the exact key
        is missing, the closest key (difflib ratio >= 0.5) is used. If nothing
        is close enough the whole database is returned. Table names that are
        not present in the database are skipped.
        """
        if current_page in self.relevant_tables_per_page:
            relevant_tables = self.relevant_tables_per_page[current_page]
        else:
            closest_match = get_close_matches(current_page, self.relevant_tables_per_page.keys(), n=1, cutoff=0.5)
            if not closest_match:
                return self.database
            relevant_tables = self.relevant_tables_per_page[closest_match[0]]
        return {table: self.database[table] for table in relevant_tables if table in self.database}

    def _get_prerender_page(self, current_page):
        """Return the pre-rendered template for ``current_page``.

        Falls back to the most similar key; ``cutoff=0`` means the best
        candidate is always accepted, however dissimilar.

        Raises:
            IndexError: if ``jinjia_prerender_page`` is empty, since there is
                then no candidate to fall back to.
        """
        if current_page in self.jinjia_prerender_page:
            return self.jinjia_prerender_page[current_page]
        closest_match = get_close_matches(current_page, self.jinjia_prerender_page.keys(), n=1, cutoff=0)
        return self.jinjia_prerender_page[closest_match[0]]

    def _generate_system_prompt(self):
        """Create a dynamic system prompt based on the current state."""
        current_page = self._current_page()
        last_page = self.page_history[-2] if len(self.page_history) > 1 else self._start_page()
        # When the page id is not in the site map, the full page list is
        # passed instead of a single page entry.
        relevant_sitemap = next(
            (page for page in self.sitemap["pages"] if page["id"] == current_page),
            self.sitemap["pages"],
        )
        return get_system_prompt(app_name=self.app_name,
                                 app_description=self.app_description,
                                 relevant_database=self._get_relevant_data(current_page),
                                 user_state=self.user_state,
                                 task=self.task,
                                 current_page=current_page,
                                 last_page=last_page,
                                 actions=self.actions,
                                 sitemap_page=relevant_sitemap,
                                 jinjia_prerender=self._get_prerender_page(current_page),
                                 )

    @retry(
        retry=retry_if_exception_type(RateLimitError),
        wait=wait_fixed(5),  # Waits for 5 seconds between retries
        stop=stop_after_attempt(50000)  # Gives up after 50000 attempts
    )
    def _get_openai_response(self, prompt):
        """Send ``prompt`` (a list of chat messages) to the model and return the reply text.

        ``self.conversation`` is trimmed to the token budget before every
        attempt. Rate-limit errors are retried by tenacity according to the
        decorator above; other exceptions propagate immediately.
        """
        self._trim_conversation()
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=prompt,
            max_tokens=self.buffer_tokens,  # Completion length is capped at the buffer size
            temperature=0.7,
        )
        return response.choices[0].message.content

    def _calculate_token_count(self, conversation):
        """Estimate the number of tokens in ``conversation``.

        Counts are taken with the GPT-2 tokenizer loaded in ``__init__``, so
        they are an estimate rather than the exact count for the model that
        is called.
        """
        return sum(
            len(self.tokenizer.encode(entry['content'], truncation=False, add_special_tokens=False))
            for entry in conversation
        )

    def _trim_conversation(self):
        """Drop the oldest messages until the conversation fits the token budget.

        The budget is ``max_tokens - 2 * buffer_tokens``. The emptiness guard
        keeps the loop from popping an empty list when that budget is zero or
        negative.
        """
        limit = self.max_tokens - self.buffer_tokens * 2
        while self.conversation and self._calculate_token_count(self.conversation) >= limit:
            self.conversation.pop(0)

    def _render_current_page(self):
        """Ask the model for the current page's instructions and record the exchange.

        Appends a freshly generated system prompt and the model's reply to
        ``self.conversation`` and returns the reply.
        """
        self.conversation.append({"role": "system", "content": self._generate_system_prompt()})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def _request_state_update(self, user_input):
        """Ask the model for the new user state after ``user_input`` and return it as a dict."""
        update_prompt = get_user_state_update_prompt(user_input=user_input,
                                                     current_page=self._current_page(),
                                                     task=self.task,
                                                     database=self.database,
                                                     solution=self.solution,
                                                     user_state=self.user_state,
                                                     sitemap=self.sitemap)
        update_message = {"role": "user", "content": update_prompt}
        self.conversation.append(update_message)
        try:
            raw_response = self._get_openai_response(self.conversation)
        finally:
            # The update prompt must not stay in the history, even when the
            # request fails. Trimming may already have dropped it, so check
            # identity before popping.
            if self.conversation and self.conversation[-1] is update_message:
                self.conversation.pop()

        # Only the text after the first "UPDATED" marker is parsed. If the
        # marker is missing, parse the whole reply rather than raising.
        _, marker, tail = raw_response.partition("UPDATED")
        state_text = (tail if marker else raw_response).strip()
        return self._coerce_state(json_repair.loads(state_text))

    def _coerce_state(self, updated_state):
        """Force ``updated_state`` into a dict containing every required key.

        A non-dict value is sent back to the model for reformatting, at most
        ``_MAX_STATE_REPAIR_ATTEMPTS`` times; if that still fails an empty
        dict is used so the defaults below apply (stay on the current page,
        task not completed, no back navigation).
        """
        attempts = 0
        while not isinstance(updated_state, dict) and attempts < self._MAX_STATE_REPAIR_ATTEMPTS:
            attempts += 1
            transform_prompt = (
                "\n"
                f"Transform {updated_state} to a properly formatted JSON file.\n"
                "Example Output Format:\n"
                "{\n"
                "'current_page': 'Home',\n"
                "'task_completed': False,\n"
                "'back': False\n"
                "}\n"
            )
            reply = self._get_openai_response([{"role": "system", "content": transform_prompt}])
            updated_state = json_repair.loads(reply)
        if not isinstance(updated_state, dict):
            updated_state = {}

        # Fill in any required key the model left out.
        updated_state.setdefault('current_page', self._current_page())
        updated_state.setdefault('task_completed', False)
        updated_state.setdefault('back', False)
        return updated_state

    def one_conversation_round(self, user_input):
        """Conduct one round of conversation between the user and the assistant.

        Returns the text to show the user: an "Invalid input" message, the
        "Task completed" message, or the model's instructions for the page
        the user is now on. Every call increments ``prompt_count`` once.
        """
        self.trajectory.append({"role": "user", "content": f'Human: {user_input}'})

        is_valid, reason = self._is_valid_input(user_input)
        self.prompt_count += 1
        if not is_valid:
            invalid_input_message = f"\n{self.app_name}: Invalid input. {reason}"
            self.trajectory.append({"role": "assistant", "content": invalid_input_message})
            return invalid_input_message

        self.actions.append(f'{user_input} on {self.user_state["current_page"]} page')
        self.conversation.append({"role": "user", "content": user_input})

        # Update user state using the model's response
        updated_state = self._request_state_update(user_input)

        if str(updated_state['task_completed']).lower() == 'true':
            complete_message = f"{self.app_name}: Task completed! You took {self.prompt_count} steps."
            self.trajectory.append({"role": "assistant", "content": complete_message})
            return complete_message

        self.user_state = updated_state
        if str(updated_state['back']).lower() == 'false':
            self.page_history.append(updated_state['current_page'])
        elif self.page_history:
            self.page_history.pop()

        # No need to keep the old system prompt once a new one is generated
        self.conversation = [entry for entry in self.conversation if entry["role"] != "system"]

        gpt_instruction = self._render_current_page()
        self.trajectory.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def start_conversation(self):
        """Return the greeting followed by the model's instructions for the current page."""
        greeting = f'\nWelcome to {self.app_name} simulator! Your task is: {self.task} \n'
        return greeting + self._render_current_page()

    def _extract_buttons(self):
        """Map button numbers to action types from the latest assistant message.

        Returns an empty dict when the conversation is empty, when the last
        message is not from the assistant, or when it has no
        "you have the following options:" section. Keys are the button
        numbers as strings; values are lower-cased action types.
        """
        if not self.conversation:
            return {}
        last_message = self.conversation[-1]
        if last_message.get("role") != "assistant":
            return {}

        message_content = last_message.get("content", "")
        # Everything after the (case-insensitive) options header holds the buttons
        options_split = re.split(r"you have the following options:", message_content, flags=re.IGNORECASE)
        if len(options_split) < 2:
            return {}

        # Each button looks like `number. button name: action_type`
        pattern = r"(\d+)\.\s+(.*?):\s+([a-zA-Z_]+)"
        buttons = re.findall(pattern, options_split[1])
        return {number: action_type.strip().lower() for number, _, action_type in buttons}

    def _is_valid_input(self, user_input):
        """Validate ``user_input`` against the buttons offered in the last assistant message.

        Returns a two-item list ``[is_valid, message]``. When no buttons can
        be extracted, any input is accepted.
        """
        valid_buttons = self._extract_buttons()
        if not valid_buttons:
            return [True, "Enter Anything is empty"]

        # Expected shape: action_type(button number) or action_type(button number, query)
        pattern = r"^(?P<action_type>\w+)\((?P<button_number>[^,]+)(?:,\s*(?P<query>.+))?\)$"
        match = re.match(pattern, user_input)
        if not match:
            return [False,
                    "Your input doesn't match the format: action_type(button number), OR if text_box, use text_box(button number, query), eg. noop(12). No indent before input and No extra input before or after action_type(button number)!"]

        action_type = match.group("action_type").lower()
        button_name = match.group("button_number").strip().lower()
        query = match.group("query")  # Optional; only `text_box` takes one

        # The button number must be one of the offered buttons
        if button_name not in valid_buttons:
            return [False,
                    "Invalid Button number! Recall: Each button is in the format: `number. button name: action_type`. Correct example: link(3), text_box(2, query)"]
        # The action type must match the type declared for that button
        if action_type != valid_buttons[button_name]:
            return [False,
                    "Invalid action type! Recall: Each button is in the format: `number. button name: action_type`"]
        if action_type == "text_box" and query is None:
            return [False,
                    "Missing Query for action type 'text_box'! Recall: use the format: `text_box(button number, query)`"]
        if action_type != "text_box" and query is not None:
            return [False,
                    "Non-`text_box` action_type cannot take query!"]
        return [True, 'Pass']
|