Spaces:
Runtime error
Runtime error
Commit
·
f325d08
1
Parent(s):
1880ac6
feature/add-Model-Selection-UI-part (#27)
Browse files* some minor updates
* modified: app.py
* modified: app.py
* modified: app.py
* modified: app.py
* modified: app.py
* modified: app.py
* modified: app.py
* Delete leaderboard.py
will commit later
* This version is purely local and runnable
* delete pycache
---------
Co-authored-by: Haofei Yu <1125027232@qq.com>
app.py
CHANGED
@@ -41,23 +41,24 @@ def prepare_sotopia_info():
|
|
41 |
return human_agent, machine_agent, scenario, instructions
|
42 |
|
43 |
|
44 |
-
|
45 |
-
|
|
|
|
|
46 |
compute_type = torch.float16
|
47 |
config_dict = PeftConfig.from_json_file("peft_config.json")
|
48 |
config = PeftConfig.from_peft_type(**config_dict)
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
"cuda"
|
57 |
-
)
|
58 |
return model, tokenizer
|
59 |
|
60 |
|
|
|
61 |
def introduction():
|
62 |
with gr.Column(scale=2):
|
63 |
gr.Image(
|
@@ -79,6 +80,12 @@ def introduction():
|
|
79 |
|
80 |
def param_accordion(according_visible=True):
|
81 |
with gr.Accordion("Parameters", open=False, visible=according_visible):
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
temperature = gr.Slider(
|
83 |
minimum=0.1,
|
84 |
maximum=1.0,
|
@@ -101,7 +108,7 @@ def param_accordion(according_visible=True):
|
|
101 |
visible=False,
|
102 |
label="Session ID",
|
103 |
)
|
104 |
-
return temperature, session_id, max_tokens
|
105 |
|
106 |
|
107 |
def sotopia_info_accordion(
|
@@ -168,7 +175,10 @@ def run_chat(
|
|
168 |
temperature: float,
|
169 |
top_p: float,
|
170 |
max_tokens: int,
|
|
|
|
|
171 |
):
|
|
|
172 |
prompt = format_sotopia_prompt(
|
173 |
message, history, instructions, user_name, bot_name
|
174 |
)
|
@@ -190,7 +200,7 @@ def run_chat(
|
|
190 |
|
191 |
|
192 |
def chat_tab():
|
193 |
-
model, tokenizer = prepare()
|
194 |
human_agent, machine_agent, scenario, instructions = prepare_sotopia_info()
|
195 |
|
196 |
# history are input output pairs
|
@@ -203,7 +213,9 @@ def chat_tab():
|
|
203 |
temperature: float,
|
204 |
top_p: float,
|
205 |
max_tokens: int,
|
|
|
206 |
):
|
|
|
207 |
prompt = format_sotopia_prompt(
|
208 |
message, history, instructions, user_name, bot_name
|
209 |
)
|
@@ -227,10 +239,9 @@ def chat_tab():
|
|
227 |
|
228 |
with gr.Column():
|
229 |
with gr.Row():
|
230 |
-
temperature, session_id, max_tokens = param_accordion()
|
231 |
-
user_name, bot_name, scenario = sotopia_info_accordion(
|
232 |
-
|
233 |
-
)
|
234 |
instructions = instructions_accordion(instructions)
|
235 |
|
236 |
with gr.Column():
|
@@ -260,6 +271,7 @@ def chat_tab():
|
|
260 |
temperature,
|
261 |
session_id,
|
262 |
max_tokens,
|
|
|
263 |
],
|
264 |
submit_btn="Send",
|
265 |
stop_btn="Stop",
|
|
|
41 |
return human_agent, machine_agent, scenario, instructions
|
42 |
|
43 |
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
def prepare(model_name):
|
48 |
compute_type = torch.float16
|
49 |
config_dict = PeftConfig.from_json_file("peft_config.json")
|
50 |
config = PeftConfig.from_peft_type(**config_dict)
|
51 |
+
|
52 |
+
if 'mistral'in model_name:
|
53 |
+
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1").to("cuda")
|
54 |
+
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
|
55 |
+
model = PeftModel.from_pretrained(model, model_name, config=config).to(compute_type).to("cuda")
|
56 |
+
else:
|
57 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
58 |
return model, tokenizer
|
59 |
|
60 |
|
61 |
+
|
62 |
def introduction():
|
63 |
with gr.Column(scale=2):
|
64 |
gr.Image(
|
|
|
80 |
|
81 |
def param_accordion(according_visible=True):
|
82 |
with gr.Accordion("Parameters", open=False, visible=according_visible):
|
83 |
+
model_name = gr.Dropdown(
|
84 |
+
choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "GPT3.5"], # Example model choices
|
85 |
+
value="cmu-lti/sotopia-pi-mistral-7b-BC_SR", # Default value
|
86 |
+
interactive=True,
|
87 |
+
label="Model Selection",
|
88 |
+
)
|
89 |
temperature = gr.Slider(
|
90 |
minimum=0.1,
|
91 |
maximum=1.0,
|
|
|
108 |
visible=False,
|
109 |
label="Session ID",
|
110 |
)
|
111 |
+
return temperature, session_id, max_tokens, model_name
|
112 |
|
113 |
|
114 |
def sotopia_info_accordion(
|
|
|
175 |
temperature: float,
|
176 |
top_p: float,
|
177 |
max_tokens: int,
|
178 |
+
model_selection:str
|
179 |
+
|
180 |
):
|
181 |
+
model, tokenizer = prepare(model_selection)
|
182 |
prompt = format_sotopia_prompt(
|
183 |
message, history, instructions, user_name, bot_name
|
184 |
)
|
|
|
200 |
|
201 |
|
202 |
def chat_tab():
|
203 |
+
#model, tokenizer = prepare()
|
204 |
human_agent, machine_agent, scenario, instructions = prepare_sotopia_info()
|
205 |
|
206 |
# history are input output pairs
|
|
|
213 |
temperature: float,
|
214 |
top_p: float,
|
215 |
max_tokens: int,
|
216 |
+
model_selection:str
|
217 |
):
|
218 |
+
model, tokenizer = prepare(model_selection)
|
219 |
prompt = format_sotopia_prompt(
|
220 |
message, history, instructions, user_name, bot_name
|
221 |
)
|
|
|
239 |
|
240 |
with gr.Column():
|
241 |
with gr.Row():
|
242 |
+
temperature, session_id, max_tokens, model = param_accordion()
|
243 |
+
user_name, bot_name, scenario = sotopia_info_accordion(human_agent, machine_agent, scenario)
|
244 |
+
|
|
|
245 |
instructions = instructions_accordion(instructions)
|
246 |
|
247 |
with gr.Column():
|
|
|
271 |
temperature,
|
272 |
session_id,
|
273 |
max_tokens,
|
274 |
+
model,
|
275 |
],
|
276 |
submit_btn="Send",
|
277 |
stop_btn="Stop",
|