File size: 13,412 Bytes
2649124
 
54d66e1
c634ddd
 
c61089d
 
700f11b
2649124
 
 
700f11b
 
c634ddd
700f11b
c634ddd
700f11b
2649124
700f11b
 
 
2649124
c634ddd
 
700f11b
 
 
03f43d0
700f11b
 
2649124
 
 
 
 
 
 
 
700f11b
2649124
 
 
c634ddd
 
700f11b
c634ddd
2649124
700f11b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2649124
 
 
700f11b
 
 
 
 
2649124
c634ddd
700f11b
 
f371bf3
c634ddd
 
 
 
700f11b
 
c634ddd
2649124
c61089d
 
 
42b2eba
c61089d
2649124
c61089d
8e58f6e
c61089d
 
 
 
 
 
 
2649124
 
54d66e1
 
 
 
 
 
 
2649124
 
 
54d66e1
15af633
2649124
 
 
 
700f11b
c634ddd
 
 
 
700f11b
 
 
 
c634ddd
54d66e1
2649124
3a8cc44
0342ce4
2649124
700f11b
c634ddd
 
 
700f11b
c634ddd
 
 
2649124
700f11b
c634ddd
700f11b
2649124
 
 
 
c634ddd
 
 
 
 
 
e3ac915
 
c634ddd
 
 
e3ac915
 
56e14a4
e3ac915
c634ddd
 
 
 
 
 
 
29ff717
15af633
c634ddd
700f11b
 
 
15af633
3f815a2
2649124
 
c634ddd
 
 
 
2649124
c885d38
e1f0dec
c634ddd
 
2649124
 
 
700f11b
2649124
 
 
7164d93
2649124
 
 
 
 
 
c634ddd
 
6f2d9ff
c634ddd
 
 
 
 
 
 
 
 
 
6f2d9ff
c634ddd
 
 
 
 
 
 
 
6f2d9ff
c634ddd
 
6f2d9ff
 
c634ddd
 
 
 
 
cb638a7
 
 
c634ddd
6f2d9ff
c634ddd
 
 
7da64e0
 
c634ddd
 
 
6f2d9ff
c634ddd
 
6f2d9ff
c634ddd
 
7da64e0
c634ddd
 
 
7da64e0
c634ddd
7da64e0
 
c634ddd
7da64e0
c634ddd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
from openai import OpenAI
import json_repair
from transformers import AutoTokenizer
from prompts import *
import re
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type
from openai import RateLimitError
from difflib import get_close_matches


class ChatbotSimulation:
    """Simulate an application as a turn-based, text-only chatbot backed by an OpenAI chat model.

    Three records are kept side by side:

    * ``conversation`` -- the message list actually sent to the model; it is
      trimmed from the front to stay within the token budget.
    * ``trajectory`` -- an untrimmed log of every user input and assistant reply.
    * ``page_history`` -- a stack of visited page ids, popped when the model
      reports a "back" navigation.

    Each round validates the user's command against the buttons listed in the
    last assistant message, asks the model for an updated ``user_state``, and
    then asks it to render the resulting page from a freshly built system prompt.
    """

    # Upper bound on how many times a malformed state update is sent back to the
    # model for reformatting before falling back to an empty state.
    _MAX_STATE_FORMAT_RETRIES = 5

    def __init__(self, app_name, app_description, site_map, relevant_tables_per_page,
                 database, jinjia_prerender_page, task, solution,
                 log_location, openai_api_key, agent='human',
                 max_steps=30, max_tokens=8192, buffer_tokens=500):
        """Store the app definition and initialise the simulator state.

        Args:
            app_name: Display name used in greetings and status messages.
            app_description: Free-text description passed into the system prompt.
            site_map: Mapping with a ``'pages'`` list whose entries have an
                ``'id'``; the first page is the starting page.
            relevant_tables_per_page: Mapping page id -> iterable of table names
                (keys of ``database``) relevant to that page.
            database: Mapping table name -> table data.
            jinjia_prerender_page: Mapping page id -> pre-rendered page passed
                to the system prompt.
            task: Task description shown to the user and the model.
            solution: Reference solution passed to the state-update prompt.
            log_location: Stored but not otherwise used within this class.
            openai_api_key: API key for the OpenAI client.
            agent: ``'human'`` or ``'llm'`` (case-insensitive).
            max_steps: Stored but not otherwise used within this class.
            max_tokens: Token budget for ``conversation``.
            buffer_tokens: Completion token limit per request; twice this value
                is reserved from ``max_tokens`` when trimming.

        Raises:
            ValueError: If ``agent`` is not ``'human'`` or ``'llm'``.
        """
        self.app_name = app_name
        self.app_description = app_description
        self.sitemap = site_map
        self.relevant_tables_per_page = relevant_tables_per_page
        self.database = database
        self.jinjia_prerender_page = jinjia_prerender_page
        self.task = task
        self.solution = solution

        start_page = self.sitemap['pages'][0]['id']
        self.user_state = {
            'current_page': start_page,
            'task_completed': 'False',
            'logged_in': 'False',
            'back': 'False',
        }

        self.log_location = log_location
        self.agent = agent.lower()
        if self.agent not in ['human', 'llm']:
            raise ValueError("Invalid agent type. Expected 'Human' or 'llm'.")
        self.max_steps = max_steps
        self.max_tokens = max_tokens
        self.buffer_tokens = buffer_tokens
        self.conversation = []  # Recent messages sent to the model (trimmed to the token budget)
        self.trajectory = [{"role": "system", "content": f"Welcome to {app_name} simulator! Your task is: {task}"}]
        self.prompt_count = 0
        self.client = OpenAI(api_key=openai_api_key)
        self.actions = []
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)

        # Stack of visited pages for the back button. It starts on the same page
        # as user_state['current_page'] so the two never disagree.
        self.page_history = [start_page]

    def _get_relevant_data(self, current_page):
        """Return the subset of the database relevant to ``current_page``.

        An exact page-id match is tried first, then the most similar page id with
        a similarity ratio of at least 0.5. If neither matches, the whole database
        is returned. Table names not present in the database are skipped.
        """
        if current_page in self.relevant_tables_per_page:
            relevant_tables = self.relevant_tables_per_page[current_page]
        else:
            closest_match = get_close_matches(current_page, self.relevant_tables_per_page.keys(), n=1, cutoff=0.5)
            if not closest_match:
                return self.database
            relevant_tables = self.relevant_tables_per_page[closest_match[0]]

        return {table: self.database[table] for table in relevant_tables if table in self.database}

    def _get_prerender_page(self, current_page):
        """Return the pre-rendered page for ``current_page``, or for the most similar page id.

        With ``cutoff=0`` every key qualifies, so some page is always returned as
        long as ``jinjia_prerender_page`` is non-empty. An empty mapping raises
        IndexError.
        """
        if current_page in self.jinjia_prerender_page:
            return self.jinjia_prerender_page[current_page]
        closest_match = get_close_matches(current_page, self.jinjia_prerender_page.keys(), n=1, cutoff=0)
        return self.jinjia_prerender_page[closest_match[0]]

    def _current_page(self):
        """Return the page on top of the history stack, or the sitemap's first page if empty."""
        if self.page_history:
            return self.page_history[-1]
        return self.sitemap['pages'][0]['id']

    def _generate_system_prompt(self):
        """Build the system prompt for the current page from the simulator state."""
        current_page = self._current_page()
        last_page = self.page_history[-2] if len(self.page_history) > 1 else self.sitemap['pages'][0]['id']
        relevant_database = self._get_relevant_data(current_page)
        # Fall back to the full page list when the current page id is not in the sitemap.
        relevant_sitemap = next((page for page in self.sitemap["pages"] if page["id"] == current_page), self.sitemap["pages"])
        prerender_page = self._get_prerender_page(current_page)

        return get_system_prompt(app_name=self.app_name,
                                 app_description=self.app_description,
                                 relevant_database=relevant_database,
                                 user_state=self.user_state,
                                 task=self.task,
                                 current_page=current_page,
                                 last_page=last_page,
                                 actions=self.actions,
                                 sitemap_page=relevant_sitemap,
                                 jinjia_prerender=prerender_page,
                                 )

    @retry(
        retry=retry_if_exception_type(RateLimitError),
        wait=wait_fixed(5),  # Wait 5 seconds between retries
        stop=stop_after_attempt(50000)  # Effectively unbounded: up to ~69 hours of rate-limit retries
    )
    def _get_openai_response(self, prompt):
        """Send ``prompt`` to the chat model and return the reply text.

        ``self.conversation`` is trimmed first; when ``prompt`` is that same list,
        the trimmed version is what gets sent. Rate-limit errors are retried by
        the tenacity decorator. A reply with no content is returned as ``""`` so
        callers can always treat the result as a string.
        """
        self._trim_conversation()
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=prompt,
            max_tokens=self.buffer_tokens,
            temperature=0.7,
        )
        return response.choices[0].message.content or ""

    def _calculate_token_count(self, conversation):
        """Estimate the total number of tokens in the messages' ``content`` fields.

        Uses the GPT-2 tokenizer loaded in ``__init__``. That is not the
        tokenizer of the 'gpt-4' model used for requests, so the result is an
        approximation.
        """
        return sum(
            len(self.tokenizer.encode(entry['content'], truncation=False, add_special_tokens=False))
            for entry in conversation
        )

    def _trim_conversation(self):
        """Drop the oldest messages until the conversation fits the token budget.

        The budget is ``max_tokens - 2 * buffer_tokens``. The newest message is
        always kept, so a request is never sent with an empty message list, even
        if that single message exceeds the budget.
        """
        limit = self.max_tokens - self.buffer_tokens * 2
        # Count each message once instead of re-tokenizing everything per removal.
        counts = [self._calculate_token_count([entry]) for entry in self.conversation]
        total = sum(counts)
        drop = 0
        while total >= limit and drop < len(counts) - 1:
            total -= counts[drop]
            drop += 1
        # In-place deletion so callers holding a reference to the list see the trim.
        del self.conversation[:drop]

    def _request_state_update(self, user_input):
        """Ask the model how ``user_input`` changes the user state and return it as a dict.

        The state-update prompt is appended to ``conversation`` for this single
        request and is always removed afterwards. The reply is parsed with
        json_repair. If parsing does not yield a dict, the model is asked to
        reformat it, up to ``_MAX_STATE_FORMAT_RETRIES`` times, after which an
        empty dict is used. Missing 'current_page', 'task_completed' and 'back'
        keys are filled with the current page, False and False respectively.
        """
        update_prompt = get_user_state_update_prompt(user_input=user_input,
                                                     current_page=self._current_page(),
                                                     task=self.task,
                                                     database=self.database,
                                                     solution=self.solution,
                                                     user_state=self.user_state,
                                                     sitemap=self.sitemap)

        self.conversation.append({"role": "user", "content": update_prompt})
        try:
            raw_reply = self._get_openai_response(self.conversation)
        finally:
            # The update prompt is one-shot; it must not stay in the history,
            # even when the request fails.
            self.conversation.pop(-1)

        # The state JSON is expected after an "UPDATED" marker. If the model left
        # the marker out, parse the whole reply instead of raising IndexError.
        _, marker, after_marker = raw_reply.partition("UPDATED")
        updated_state = json_repair.loads((after_marker if marker else raw_reply).strip())

        # Format forcing: ask the model to reformat until we get a dict (bounded).
        for _ in range(self._MAX_STATE_FORMAT_RETRIES):
            if isinstance(updated_state, dict):
                break
            transform_prompt = f"""
            Transform {updated_state} to a properly formatted JSON file.
            Example Output Format:
            {{
               'current_page': 'Home',
               'task_completed': False,
               'back': False
            }}
            """
            updated_state = self._get_openai_response([{"role": "system", "content": transform_prompt}])
            updated_state = json_repair.loads(updated_state)
        if not isinstance(updated_state, dict):
            updated_state = {}

        updated_state.setdefault('current_page', self._current_page())
        updated_state.setdefault('task_completed', False)
        updated_state.setdefault('back', False)
        return updated_state

    def one_conversation_round(self, user_input):
        """Process one user command and return the simulator's reply text.

        Invalid commands produce an error message without contacting the model
        for a state update. Otherwise the model updates the user state; if it
        reports the task as completed, a completion message is returned.
        Otherwise the page history is updated and the model renders the new page.
        Both the input and the reply are recorded in ``trajectory``.
        """
        self.trajectory.append({"role": "user", "content": f'Human: {user_input}'})
        is_valid, reason = self._is_valid_input(user_input)
        if not is_valid:
            self.prompt_count += 1
            invalid_input_message = f"\n{self.app_name}: Invalid input. {reason}"
            self.trajectory.append({"role": "assistant", "content": invalid_input_message})
            return invalid_input_message

        self.actions.append(f'{user_input} on {self.user_state["current_page"]} page')
        self.conversation.append({"role": "user", "content": user_input})
        self.prompt_count += 1

        updated_state = self._request_state_update(user_input)

        if str(updated_state['task_completed']).lower() == 'true':
            complete_message = f"{self.app_name}: Task completed! You took {self.prompt_count} steps."
            self.trajectory.append({"role": "assistant", "content": complete_message})
            return complete_message

        self.user_state = updated_state
        # Any 'back' value other than 'false' counts as going back one page.
        if str(updated_state['back']).lower() == 'false':
            self.page_history.append(updated_state['current_page'])
        elif self.page_history:
            self.page_history.pop()

        # Only the newest system prompt is kept; drop the old one before adding it.
        self.conversation = [entry for entry in self.conversation if entry["role"] != "system"]
        system_prompt = self._generate_system_prompt()
        self.conversation.append({"role": "system", "content": system_prompt})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        self.trajectory.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def start_conversation(self):
        """Render the starting page and return it prefixed with a greeting."""
        greeting = f'\nWelcome to {self.app_name} simulator! Your task is: {self.task} \n'
        system_prompt = self._generate_system_prompt()
        self.conversation.append({"role": "system", "content": system_prompt})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        return greeting + gpt_instruction

    def _extract_buttons(self):
        """Map button numbers to action types from the latest assistant message.

        Looks at the text after "you have the following options:"
        (case-insensitive) and collects entries shaped like
        ``<number>. <name>: <action_type>``. Keys are the number strings and
        values are lower-cased action types. Returns an empty dict if the
        conversation is empty, the last message is not from the assistant, or no
        options section is present.
        """
        if not self.conversation:
            return {}
        last_message = self.conversation[-1]
        if last_message.get("role") != "assistant":
            return {}

        message_content = last_message.get("content", "")
        options_split = re.split(r"you have the following options:", message_content, flags=re.IGNORECASE)
        if len(options_split) < 2:
            return {}

        button_section = options_split[1]
        pattern = r"(\d+)\.\s+(.*?):\s+([a-zA-Z_]+)"  # number, button name, action type
        buttons = re.findall(pattern, button_section)
        return {number: action_type.strip().lower() for number, _, action_type in buttons}

    def _is_valid_input(self, user_input):
        """Validate a command of the form ``action_type(button)`` or ``text_box(button, query)``.

        Returns a two-element list ``[ok, message]``. When the last assistant
        message lists no buttons, any input is accepted.
        """
        valid_buttons = self._extract_buttons()

        if not valid_buttons:
            return [True, "Enter Anything is empty"]

        pattern = r"^(?P<action_type>\w+)\((?P<button_number>[^,]+)(?:,\s*(?P<query>.+))?\)$"
        match = re.match(pattern, user_input)

        if not match:
            return [False,
                    "Your input doesn't match the format: action_type(button number), OR if text_box, use text_box(button number, query), eg. noop(12). No indent before input and No extra input before or after action_type(button number)!"]

        action_type = match.group("action_type").lower()
        button_name = match.group("button_number").strip().lower()
        query = match.group("query")  # Only allowed for text_box

        if button_name not in valid_buttons:
            return [False,
                    "Invalid Button number! Recall: Each button is in the format: `number. button name: action_type`. Correct example: link(3), text_box(2, query)"]
        if action_type != valid_buttons[button_name]:
            return [False,
                    "Invalid action type! Recall: Each button is in the format: `number. button name: action_type`"]
        if action_type == "text_box" and query is None:
            return [False,
                    "Missing Query for action type 'text_box'! Recall: use the format: `text_box(button number, query)`"]
        if action_type != "text_box" and query is not None:
            return [False,
                    "Non-`text_box` action_type cannot take query!"]
        return [True, 'Pass']