zetavg committed
Commit 517781a
1 Parent(s): fb9b56d

auto select prompt template for model

llama_lora/ui/inference_ui.py CHANGED
@@ -11,7 +11,8 @@ from ..models import get_base_model, get_model_with_lora, get_tokenizer, get_dev
 from ..utils.data import (
     get_available_template_names,
     get_available_lora_model_names,
-    get_path_of_available_lora_model)
+    get_path_of_available_lora_model,
+    get_info_of_available_lora_model)
 from ..utils.prompter import Prompter
 from ..utils.callbacks import Iteratorize, Stream
 
@@ -41,7 +42,9 @@ def do_inference(
     prompter = Prompter(prompt_template)
     prompt = prompter.generate_prompt(variables)
 
-    if lora_model_name is not None and "/" not in lora_model_name and lora_model_name != "None":
+    if not lora_model_name:
+        lora_model_name = "None"
+    if "/" not in lora_model_name and lora_model_name != "None":
         path_of_available_lora_model = get_path_of_available_lora_model(
             lora_model_name)
         if path_of_available_lora_model:
@@ -75,7 +78,7 @@ def do_inference(
         return
 
     model = get_base_model()
-    if not lora_model_name == "None" and lora_model_name is not None:
+    if lora_model_name != "None":
         model = get_model_with_lora(lora_model_name)
     tokenizer = get_tokenizer()
 
@@ -172,7 +175,7 @@ def reload_selections(current_lora_model, current_prompt_template):
         gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
 
 
-def handle_prompt_template_change(prompt_template):
+def handle_prompt_template_change(prompt_template, lora_model):
     prompter = Prompter(prompt_template)
     var_names = prompter.get_variable_names()
     human_var_names = [' '.join(word.capitalize()
@@ -182,7 +185,35 @@ def handle_prompt_template_change(prompt_template):
     while len(gr_updates) < 8:
         gr_updates.append(gr.Textbox.update(
             label="Not Used", visible=False))
-    return gr_updates
+
+    model_prompt_template_message_update = gr.Markdown.update("", visible=False)
+    lora_mode_info = get_info_of_available_lora_model(lora_model)
+    if lora_mode_info and isinstance(lora_mode_info, dict):
+        model_prompt_template = lora_mode_info.get("prompt_template")
+        if model_prompt_template and model_prompt_template != prompt_template:
+            model_prompt_template_message_update = gr.Markdown.update(
+                f"Trained with prompt template `{model_prompt_template}`", visible=True)
+
+    return [model_prompt_template_message_update] + gr_updates
+
+
+def handle_lora_model_change(lora_model, prompt_template):
+    lora_mode_info = get_info_of_available_lora_model(lora_model)
+    if not lora_mode_info:
+        return gr.Markdown.update("", visible=False), prompt_template
+
+    if not isinstance(lora_mode_info, dict):
+        return gr.Markdown.update("", visible=False), prompt_template
+
+    model_prompt_template = lora_mode_info.get("prompt_template")
+    if not model_prompt_template:
+        return gr.Markdown.update("", visible=False), prompt_template
+
+    available_template_names = get_available_template_names()
+    if model_prompt_template in available_template_names:
+        return gr.Markdown.update("", visible=False), model_prompt_template
+
+    return gr.Markdown.update(f"Trained with prompt template `{model_prompt_template}`", visible=True), prompt_template
 
 
 def update_prompt_preview(prompt_template,
@@ -200,12 +231,15 @@ def inference_ui():
 
     with gr.Blocks() as inference_ui_blocks:
         with gr.Row():
-            lora_model = gr.Dropdown(
-                label="LoRA Model",
-                elem_id="inference_lora_model",
-                value="tloen/alpaca-lora-7b",
-                allow_custom_value=True,
-            )
+            with gr.Column(elem_id="inference_lora_model_group"):
+                model_prompt_template_message = gr.Markdown(
+                    "", visible=False, elem_id="inference_lora_model_prompt_template_message")
+                lora_model = gr.Dropdown(
+                    label="LoRA Model",
+                    elem_id="inference_lora_model",
+                    value="tloen/alpaca-lora-7b",
+                    allow_custom_value=True,
+                )
             prompt_template = gr.Dropdown(
                 label="Prompt Template",
                 elem_id="inference_prompt_template",
@@ -346,10 +380,20 @@ def inference_ui():
             )
             things_that_might_timeout.append(reload_selections_event)
 
-            prompt_template_change_event = prompt_template.change(fn=handle_prompt_template_change, inputs=[prompt_template], outputs=[
-                variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
+            prompt_template_change_event = prompt_template.change(
+                fn=handle_prompt_template_change,
+                inputs=[prompt_template, lora_model],
+                outputs=[
+                    model_prompt_template_message,
+                    variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
             things_that_might_timeout.append(prompt_template_change_event)
 
+            lora_model_change_event = lora_model.change(
+                fn=handle_lora_model_change,
+                inputs=[lora_model, prompt_template],
+                outputs=[model_prompt_template_message, prompt_template])
+            things_that_might_timeout.append(lora_model_change_event)
+
             generate_event = generate_btn.click(
                 fn=do_inference,
                 inputs=[
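
For reference, the template auto-selection rule that `handle_lora_model_change` wires into the UI above can be sketched independently of Gradio as follows. This is a minimal illustration, not code from the commit: `choose_prompt_template` is a hypothetical helper name, while `get_info_of_available_lora_model` and `get_available_template_names` are the helpers used in the diff.

from llama_lora.utils.data import (
    get_available_template_names,
    get_info_of_available_lora_model,
)


def choose_prompt_template(lora_model, current_template):
    # Hypothetical helper mirroring handle_lora_model_change:
    # returns (template_to_select, hint_message_or_None).
    info = get_info_of_available_lora_model(lora_model)
    if not isinstance(info, dict):
        return current_template, None

    trained_template = info.get("prompt_template")
    if not trained_template:
        return current_template, None

    if trained_template in get_available_template_names():
        # The template the model was trained with exists locally: select it.
        return trained_template, None

    # Otherwise keep the user's selection and surface a hint instead.
    return current_template, f"Trained with prompt template `{trained_template}`"
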
llama_lora/ui/main_page.py CHANGED
@@ -134,6 +134,41 @@ def main_page_custom_css():
         /* text-transform: uppercase; */
     }
 
+    #inference_lora_model_group {
+        border-radius: var(--block-radius);
+        background: var(--block-background-fill);
+    }
+    #inference_lora_model_group #inference_lora_model {
+        background: transparent;
+    }
+    #inference_lora_model_prompt_template_message:not(.hidden) + #inference_lora_model {
+        padding-bottom: 28px;
+    }
+    #inference_lora_model_group > #inference_lora_model_prompt_template_message {
+        position: absolute;
+        bottom: 8px;
+        left: 20px;
+        z-index: 1;
+        font-size: 12px;
+        opacity: 0.7;
+    }
+    #inference_lora_model_group > #inference_lora_model_prompt_template_message p {
+        font-size: 12px;
+    }
+    #inference_lora_model_prompt_template_message > .wrap {
+        display: none;
+    }
+    #inference_lora_model > .wrap:first-child:not(.hide),
+    #inference_prompt_template > .wrap:first-child:not(.hide) {
+        opacity: 0.5;
+    }
+    #inference_lora_model_group, #inference_lora_model {
+        z-index: 60;
+    }
+    #inference_prompt_template {
+        z-index: 55;
+    }
+
     #inference_prompt_box > *:first-child {
         border-bottom-left-radius: 0;
         border-bottom-right-radius: 0;
@@ -266,12 +301,16 @@ def main_page_custom_css():
     }
 
     @media screen and (min-width: 640px) {
-        #inference_lora_model, #finetune_template {
+        #inference_lora_model, #inference_lora_model_group,
+        #finetune_template {
             border-top-right-radius: 0;
             border-bottom-right-radius: 0;
             border-right: 0;
             margin-right: -16px;
         }
+        #inference_lora_model_group #inference_lora_model {
+            box-shadow: var(--block-shadow);
+        }
 
         #inference_prompt_template {
             border-top-left-radius: 0;
@@ -301,7 +340,7 @@ def main_page_custom_css():
             height: 42px !important;
             min-width: 42px !important;
             width: 42px !important;
-            z-index: 1;
+            z-index: 61;
         }
     }
 
llama_lora/utils/data.py CHANGED
@@ -52,6 +52,22 @@ def get_path_of_available_lora_model(name):
     return None
 
 
+def get_info_of_available_lora_model(name):
+    try:
+        if "/" in name:
+            return None
+        path_of_available_lora_model = get_path_of_available_lora_model(
+            name)
+        if not path_of_available_lora_model:
+            return None
+
+        with open(os.path.join(path_of_available_lora_model, "info.json"), "r") as json_file:
+            return json.load(json_file)
+
+    except Exception as e:
+        return None
+
+
 def get_dataset_content(name):
     file_name = os.path.join(Global.data_dir, "datasets", name)
     if not os.path.exists(file_name):
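
As added in llama_lora/utils/data.py above, `get_info_of_available_lora_model` loads the `info.json` stored next to a locally available LoRA model and returns None for anything else (names containing "/", unknown models, or unreadable files); the only key the inference UI reads from it in this commit is `prompt_template`. A rough usage sketch, with a made-up model name and file contents for illustration:

from llama_lora.utils.data import get_info_of_available_lora_model

# Assuming a locally trained model directory whose info.json looks like
#   {"prompt_template": "alpaca"}        (contents made up for illustration)
info = get_info_of_available_lora_model("my-alpaca-lora")
if info:
    print(info.get("prompt_template"))  # e.g. "alpaca"

# Hugging Face Hub IDs (containing "/") and unknown names return None.
print(get_info_of_available_lora_model("tloen/alpaca-lora-7b"))  # None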