Wonderplex commited on
Commit
237ffdd
·
1 Parent(s): a3f9a2a

added parsing error check (#41)

Browse files
Files changed (2) hide show
  1. app.py +28 -10
  2. utils.py +7 -1
app.py CHANGED
@@ -16,7 +16,7 @@ from utils import Agent, format_sotopia_prompt, get_starter_prompt, format_bot_m
16
  from functools import cache
17
 
18
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
19
- DEFAULT_MODEL_SELECTION = "cmu-lti/sotopia-pi-mistral-7b-BC_SR"
20
 
21
  def prepare_sotopia_info():
22
  human_agent = Agent(
@@ -59,6 +59,13 @@ def prepare(model_name):
59
  )
60
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
61
  model = PeftModel.from_pretrained(model, model_name).to("cuda")
 
 
 
 
 
 
 
62
  else:
63
  raise RuntimeError(f"Model {model_name} not supported")
64
  return model, tokenizer
@@ -107,13 +114,15 @@ def param_accordion(according_visible=True):
107
  interactive=True,
108
  label="Max Tokens",
109
  )
110
- session_id = gr.Textbox(
111
- value=uuid4,
112
- interactive=False,
113
- visible=False,
114
- label="Session ID",
 
 
115
  )
116
- return temperature, session_id, max_tokens, model_name
117
 
118
 
119
  def sotopia_info_accordion(human_agent, machine_agent, scenario, accordion_visible=True):
@@ -192,11 +201,20 @@ def chat_tab():
192
  text_output = tokenizer.decode(
193
  output_tokens[0], skip_special_tokens=True
194
  )
195
- return format_bot_message(text_output)
 
 
 
 
 
 
 
 
 
196
 
197
  with gr.Column():
198
  with gr.Row():
199
- temperature, session_id, max_tokens, model = param_accordion()
200
  user_name, bot_name, scenario = sotopia_info_accordion(human_agent, machine_agent, scenario)
201
 
202
  instructions = instructions_accordion(instructions)
@@ -226,7 +244,7 @@ def chat_tab():
226
  user_name,
227
  bot_name,
228
  temperature,
229
- session_id,
230
  max_tokens,
231
  model,
232
  ],
 
16
  from functools import cache
17
 
18
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
19
+ DEFAULT_MODEL_SELECTION = "cmu-lti/sotopia-pi-mistral-7b-BC_SR" # "mistralai/Mistral-7B-Instruct-v0.1"
20
 
21
  def prepare_sotopia_info():
22
  human_agent = Agent(
 
59
  )
60
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
61
  model = PeftModel.from_pretrained(model, model_name).to("cuda")
62
+ elif 'mistralai/Mistral-7B-Instruct-v0.1' in model_name:
63
+ model = AutoModelForCausalLM.from_pretrained(
64
+ "mistralai/Mistral-7B-Instruct-v0.1",
65
+ cache_dir="./.cache",
66
+ device_map='cuda',
67
+ )
68
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
69
  else:
70
  raise RuntimeError(f"Model {model_name} not supported")
71
  return model, tokenizer
 
114
  interactive=True,
115
  label="Max Tokens",
116
  )
117
+ top_p = gr.Slider(
118
+ minimum=1,
119
+ maximum=3,
120
+ value=1,
121
+ interactive=True,
122
+ visible=True,
123
+ label="Top p",
124
  )
125
+ return temperature, top_p, max_tokens, model_name
126
 
127
 
128
  def sotopia_info_accordion(human_agent, machine_agent, scenario, accordion_visible=True):
 
201
  text_output = tokenizer.decode(
202
  output_tokens[0], skip_special_tokens=True
203
  )
204
+ # import pdb; pdb.set_trace()
205
+ output = ""
206
+ for _ in range(5):
207
+ try:
208
+ output = format_bot_message(text_output)
209
+ break
210
+ except Exception as e:
211
+ print(e)
212
+ print("Retrying...")
213
+ return output
214
 
215
  with gr.Column():
216
  with gr.Row():
217
+ temperature, top_p, max_tokens, model = param_accordion()
218
  user_name, bot_name, scenario = sotopia_info_accordion(human_agent, machine_agent, scenario)
219
 
220
  instructions = instructions_accordion(instructions)
 
244
  user_name,
245
  bot_name,
246
  temperature,
247
+ top_p,
248
  max_tokens,
249
  model,
250
  ],
utils.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import List, Tuple
2
  import ast
 
3
 
4
  FORMAT_TEMPLATE = """ Your available action types are
5
  "none action speak non-verbal communication leave".
@@ -49,7 +50,12 @@ def truncate_dialogue_history_to_length(dia_his, surpass_num, tokenizer):
49
 
50
 
51
  def format_bot_message(bot_message) -> str:
52
- json_response = ast.literal_eval(bot_message)
 
 
 
 
 
53
  match json_response["action_type"]:
54
  case "none":
55
  return 'did nothing'
 
1
  from typing import List, Tuple
2
  import ast
3
+ import re
4
 
5
  FORMAT_TEMPLATE = """ Your available action types are
6
  "none action speak non-verbal communication leave".
 
50
 
51
 
52
  def format_bot_message(bot_message) -> str:
53
+ # import pdb; pdb.set_trace()
54
+ start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
55
+ if end_idx == -1:
56
+ bot_message += "'}"
57
+ end_idx = len(bot_message)
58
+ json_response = ast.literal_eval(bot_message[start_idx:end_idx+1])
59
  match json_response["action_type"]:
60
  case "none":
61
  return 'did nothing'