# App_Simulator / chatbot_simulator.py
# Author: jjz5463
# Commit f371bf3: "update invalid input check"
from openai import OpenAI
import json_repair
from transformers import AutoTokenizer
from prompts import *
import re
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type
from openai import RateLimitError
from difflib import get_close_matches
class ChatbotSimulation:
    """Simulate an app's UI as a text conversation driven by an OpenAI chat model.

    Each round the user (a human or an LLM agent) submits an action such as
    ``link(3)`` or ``text_box(2, query)``. The model is first asked for an updated
    ``user_state`` and then for a description of the resulting page.
    ``conversation`` is the token-trimmed history sent to the model;
    ``trajectory`` records every user/assistant message of the session.
    """

    def __init__(self, app_name, app_description, site_map, relevant_tables_per_page,
                 database, jinjia_prerender_page, task, solution,
                 log_location, openai_api_key, agent='human',
                 max_steps=30, max_tokens=8192, buffer_tokens=500):
        """Initialise simulator state, the OpenAI client and the tokenizer.

        Args:
            app_name: Display name of the simulated app.
            app_description: Free-text description passed into the system prompt.
            site_map: Dict with a ``'pages'`` list of page dicts, each with an
                ``'id'``; the first page is the starting page.
            relevant_tables_per_page: Mapping page id -> iterable of table names.
            database: Mapping table name -> table data.
            jinjia_prerender_page: Mapping page id -> pre-rendered page content.
            task: Task description shown to the user.
            solution: Reference solution forwarded to the state-update prompt.
            log_location: Stored as ``self.log_location``; not read in this class.
            openai_api_key: API key for the OpenAI client.
            agent: ``'human'`` or ``'llm'`` (case-insensitive).
            max_steps: Stored as ``self.max_steps``; not enforced in this class.
            max_tokens: Token budget for ``self.conversation``.
            buffer_tokens: Completion-token cap per request; twice this amount is
                kept free when trimming the conversation.

        Raises:
            ValueError: If ``agent`` is neither ``'human'`` nor ``'llm'``.
        """
        self.app_name = app_name
        self.app_description = app_description
        self.sitemap = site_map
        self.relevant_tables_per_page = relevant_tables_per_page
        self.database = database
        self.jinjia_prerender_page = jinjia_prerender_page
        self.task = task
        self.solution = solution
        self.user_state = dict()
        self.user_state['current_page'] = self.sitemap['pages'][0]['id']  # start on the first sitemap page
        self.user_state['task_completed'] = 'False'
        self.user_state['logged_in'] = 'False'
        self.user_state['back'] = 'False'
        self.log_location = log_location
        self.agent = agent.lower()
        if self.agent not in ['human', 'llm']:
            raise ValueError("Invalid agent type. Expected 'Human' or 'llm'.")
        self.max_steps = max_steps
        self.max_tokens = max_tokens
        self.buffer_tokens = buffer_tokens
        self.conversation = []  # messages sent to the model; trimmed to the token budget
        self.trajectory = [{"role": "system", "content": f"Welcome to {app_name} simulator! Your task is: {task}"}]
        self.prompt_count = 0
        self.client = OpenAI(api_key=openai_api_key)
        self.actions = []  # human-readable log of accepted actions
        # NOTE: the GPT-2 tokenizer only approximates the chat model's token counts.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)
        # Navigation stack backing the "back" action. Seed it with the same starting
        # page as user_state so both agree even if the first page is not 'Home'.
        self.page_history = [self.user_state['current_page']]

    def _default_page(self):
        """Return the id of the first sitemap page (used when history is empty)."""
        return self.sitemap['pages'][0]['id']

    def _get_relevant_data(self, current_page):
        """Return the subset of ``self.database`` relevant to ``current_page``.

        Uses the exact page entry if present, otherwise the closest page id with a
        difflib similarity of at least 0.5. When nothing is close enough, the whole
        database is returned. Tables not present in the database are skipped.
        """
        if current_page in self.relevant_tables_per_page:
            relevant_tables = self.relevant_tables_per_page[current_page]
        else:
            closest_match = get_close_matches(current_page, self.relevant_tables_per_page.keys(), n=1, cutoff=0.5)
            if not closest_match:
                return self.database
            relevant_tables = self.relevant_tables_per_page[closest_match[0]]
        return {table: self.database[table] for table in relevant_tables if table in self.database}

    def _get_prerender_page(self, current_page):
        """Return the pre-rendered page for ``current_page`` or the closest page id.

        ``cutoff=0`` accepts any candidate, so a page is always found when the
        mapping is non-empty; an empty mapping raises ``IndexError``.
        """
        if current_page in self.jinjia_prerender_page:
            return self.jinjia_prerender_page[current_page]
        closest_match = get_close_matches(current_page, self.jinjia_prerender_page.keys(), n=1, cutoff=0)
        return self.jinjia_prerender_page[closest_match[0]]

    def _generate_system_prompt(self):
        """Create a system prompt reflecting the current page, state and history."""
        current_page = self.page_history[-1] if self.page_history else self._default_page()
        last_page = self.page_history[-2] if len(self.page_history) > 1 else self._default_page()
        relevant_database = self._get_relevant_data(current_page)
        # Fall back to the whole page list when the current page id is not in the sitemap.
        relevant_sitemap = next((page for page in self.sitemap["pages"] if page["id"] == current_page),
                                self.sitemap["pages"])
        prerender_page = self._get_prerender_page(current_page)
        return get_system_prompt(app_name=self.app_name,
                                 app_description=self.app_description,
                                 relevant_database=relevant_database,
                                 user_state=self.user_state,
                                 task=self.task,
                                 current_page=current_page,
                                 last_page=last_page,
                                 actions=self.actions,
                                 sitemap_page=relevant_sitemap,
                                 jinjia_prerender=prerender_page,
                                 )

    @retry(
        retry=retry_if_exception_type(RateLimitError),
        wait=wait_fixed(5),  # wait 5 seconds between retries
        stop=stop_after_attempt(50000),  # effectively: keep retrying rate limits (~69 hours max)
    )
    def _get_openai_response(self, prompt):
        """Send ``prompt`` to the chat model and return the reply text.

        ``self.conversation`` is trimmed first; when ``prompt`` is that same list
        the trim applies to the request too. Only ``RateLimitError`` is retried.
        """
        self._trim_conversation()
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=prompt,
            max_tokens=self.buffer_tokens,
            temperature=0.7,
        )
        return response.choices[0].message.content

    def _calculate_token_count(self, conversation):
        """Return the total tokenizer token count of all message contents."""
        total_tokens = 0
        for entry in conversation:
            tokens = self.tokenizer.encode(entry['content'], truncation=False, add_special_tokens=False)
            total_tokens += len(tokens)
        return total_tokens

    def _trim_conversation(self):
        """Drop the oldest messages (in place) until the history fits the token budget.

        The budget is ``max_tokens - 2 * buffer_tokens``. Each message is tokenised
        once, and trimming stops when the history is empty.
        """
        limit = self.max_tokens - self.buffer_tokens * 2
        counts = [self._calculate_token_count([entry]) for entry in self.conversation]
        total = sum(counts)
        dropped = 0
        while dropped < len(counts) and total >= limit:
            total -= counts[dropped]
            dropped += 1
        if dropped:
            # Mutate in place: callers may hold a reference to this exact list.
            del self.conversation[:dropped]

    def one_conversation_round(self, user_input):
        """Process one user action and return the simulator's reply text.

        Invalid input returns an error message without calling the model. On a
        reported task completion, a completion message is returned and the state
        is not updated. Otherwise the user state and page history are updated,
        and the model's description of the new page is returned.
        """
        self.trajectory.append({"role": "user", "content": f'Human: {user_input}'})
        is_valid, validation_message = self._is_valid_input(user_input)
        if not is_valid:
            self.prompt_count += 1
            invalid_input_message = f"\n{self.app_name}: Invalid input. {validation_message}"
            self.trajectory.append({"role": "assistant", "content": invalid_input_message})
            return invalid_input_message

        self.actions.append(f'{user_input} on {self.user_state["current_page"]} page')
        self.conversation.append({"role": "user", "content": user_input})
        self.prompt_count += 1

        current_page = self.page_history[-1] if self.page_history else self._default_page()
        updated_state = self._request_state_update(user_input, current_page)

        if str(updated_state['task_completed']).lower() == 'true':
            complete_message = f"{self.app_name}: Task completed! You took {self.prompt_count} steps."
            self.trajectory.append({"role": "assistant", "content": complete_message})
            return complete_message

        self.user_state = updated_state
        if str(updated_state['back']).lower() == 'false':
            self.page_history.append(updated_state['current_page'])
        elif self.page_history:
            self.page_history.pop()

        # Only the newest system prompt is kept in the history.
        self.conversation = [entry for entry in self.conversation if entry["role"] != "system"]
        system_prompt = self._generate_system_prompt()
        self.conversation.append({"role": "system", "content": system_prompt})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        self.trajectory.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def _request_state_update(self, user_input, current_page):
        """Ask the model for the post-action user state; return a normalised dict."""
        update_prompt = get_user_state_update_prompt(user_input=user_input,
                                                     current_page=current_page,
                                                     task=self.task,
                                                     database=self.database,
                                                     solution=self.solution,
                                                     user_state=self.user_state,
                                                     sitemap=self.sitemap)
        update_message = {"role": "user", "content": update_prompt}
        self.conversation.append(update_message)
        try:
            response = self._get_openai_response(self.conversation) or ""
        finally:
            # The update prompt is one-shot. Remove it by identity even if the call
            # failed or trimming already dropped or moved it.
            self.conversation = [entry for entry in self.conversation if entry is not update_message]
        # The reply is expected to carry the state after an "UPDATED" marker; without
        # the marker, fall back to parsing the whole reply.
        state_text = response.split("UPDATED", 1)[-1].strip()
        return self._coerce_state(json_repair.loads(state_text))

    def _coerce_state(self, updated_state, max_repairs=3):
        """Return ``updated_state`` as a dict containing the required keys.

        Non-dict values are sent back to the model for reformatting up to
        ``max_repairs`` times, then replaced by an empty dict. Missing or invalid
        ``current_page`` falls back to the current page; missing ``task_completed``
        and ``back`` default to ``False``.
        """
        repairs = 0
        while not isinstance(updated_state, dict) and repairs < max_repairs:
            transform_prompt = (
                "\n"
                f"Transform {updated_state} to a properly formatted JSON file.\n"
                "Example Output Format:\n"
                "{\n"
                "'current_page': 'Home',\n"
                "'task_completed': False,\n"
                "'back': False\n"
                "}\n"
            )
            reply = self._get_openai_response([{"role": "system", "content": transform_prompt}])
            updated_state = json_repair.loads(reply or "")
            repairs += 1
        if not isinstance(updated_state, dict):
            updated_state = {}

        page = updated_state.get('current_page')
        if not isinstance(page, str) or not page:
            updated_state['current_page'] = self.page_history[-1] if self.page_history else self._default_page()
        for key in ('task_completed', 'back'):
            if key not in updated_state:
                updated_state[key] = False
        return updated_state

    def start_conversation(self):
        """Generate the first page description and return it with a greeting."""
        greeting = f'\nWelcome to {self.app_name} simulator! Your task is: {self.task} \n'
        system_prompt = self._generate_system_prompt()
        self.conversation.append({"role": "system", "content": system_prompt})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        return greeting + gpt_instruction

    def _extract_buttons(self):
        """Map button numbers to action types from the latest assistant message.

        Parses text after "you have the following options:" (case-insensitive) for
        entries like ``3. Search: text_box``. Returns ``{}`` when the conversation
        is empty, the last entry is not from the assistant, or there is no options
        section. Keys are number strings; values are lower-cased action types.
        """
        if not self.conversation:
            return {}
        last_message = self.conversation[-1]
        if last_message.get("role") != "assistant":
            return {}
        message_content = last_message.get("content", "") or ""
        options_split = re.split(r"you have the following options:", message_content, flags=re.IGNORECASE)
        if len(options_split) < 2:
            return {}
        pattern = r"(\d+)\.\s+(.*?):\s+([a-zA-Z_]+)"  # number, button name, action type
        buttons = re.findall(pattern, options_split[1])
        return {number: action_type.strip().lower() for number, _, action_type in buttons}

    def _is_valid_input(self, user_input):
        """Validate ``user_input`` against the buttons offered by the last reply.

        Returns:
            ``[True, message]`` when the input is acceptable (including when no
            buttons could be extracted), otherwise ``[False, reason]``.
        """
        valid_buttons = self._extract_buttons()
        if not valid_buttons:
            return [True, "Enter Anything is empty"]
        pattern = r"^(?P<action_type>\w+)\((?P<button_number>[^,]+)(?:,\s*(?P<query>.+))?\)$"
        match = re.match(pattern, user_input)
        if not match:
            return [False,
                    "Your input doesn't match the format: action_type(button number), OR if text_box, use text_box(button number, query), eg. noop(12). No indent before input and No extra input before or after action_type(button number)!"]
        action_type = match.group("action_type").lower()
        button_name = match.group("button_number").strip().lower()
        query = match.group("query")  # only text_box actions carry a query
        if button_name not in valid_buttons:
            return [False,
                    "Invalid Button number! Recall: Each button is in the format: `number. button name: action_type`. Correct example: link(3), text_box(2, query)"]
        if action_type != valid_buttons[button_name]:
            return [False,
                    "Invalid action type! Recall: Each button is in the format: `number. button name: action_type`"]
        if action_type == "text_box" and query is None:
            return [False,
                    "Missing Query for action type 'text_box'! Recall: use the format: `text_box(button number, query)`"]
        if action_type != "text_box" and query is not None:
            return [False,
                    "Non-`text_box` action_type cannot take query!"]
        return [True, 'Pass']